From ed78fdaf9b08551599bc79736729f7cf3fc6eb06 Mon Sep 17 00:00:00 2001 From: Matthew Maurer Date: Mon, 25 Oct 2021 13:04:44 -0700 Subject: DoH: Modularize main event loop * Connection now provides HTTP/3. * Network has the logic for resolving DNS and maintaining a Connection. * Dispatcher routes requests to the appropriate Network or creates one if needed. * IO and maintenance is performed via tasks rather than manually pushing the futures in the main event loop. Bug: 202081046 Test: resolv_integration_test Test: resolv_stress_test + I682678b84b35c575a3eb88c2c1c67aefd195616c Change-Id: I4296d0c7a7852951f41418b18686794d8df781bd --- Android.bp | 1 + doh/config.rs | 13 +- doh/connection/driver.rs | 369 ++++++++++++++++ doh/connection/mod.rs | 150 +++++++ doh/dispatcher/driver.rs | 127 ++++++ doh/dispatcher/mod.rs | 110 +++++ doh/doh.rs | 1041 +--------------------------------------------- doh/ffi.rs | 18 +- doh/network/driver.rs | 190 +++++++++ doh/network/mod.rs | 112 +++++ 10 files changed, 1078 insertions(+), 1053 deletions(-) create mode 100644 doh/connection/driver.rs create mode 100644 doh/connection/mod.rs create mode 100644 doh/dispatcher/driver.rs create mode 100644 doh/dispatcher/mod.rs create mode 100644 doh/network/driver.rs create mode 100644 doh/network/mod.rs diff --git a/Android.bp b/Android.bp index ce9a3dcb..823af852 100644 --- a/Android.bp +++ b/Android.bp @@ -334,6 +334,7 @@ doh_rust_deps = [ "liblibc", "liblog_rust", "libring", + "libthiserror", "libtokio", "liburl", ] diff --git a/doh/config.rs b/doh/config.rs index 04d07c50..91284052 100644 --- a/doh/config.rs +++ b/doh/config.rs @@ -27,7 +27,8 @@ use quiche::{h3, Result}; use std::collections::HashMap; use std::ops::DerefMut; -use std::sync::{Arc, Mutex, RwLock, Weak}; +use std::sync::{Arc, RwLock, Weak}; +use tokio::sync::Mutex; type WeakConfig = Weak>; @@ -80,8 +81,8 @@ impl Config { /// Take the underlying config, usable as `&mut quiche::Config` for use /// with `quiche::connect`. - pub fn take(&mut self) -> impl DerefMut + '_ { - self.0.lock().unwrap() + pub async fn take(&mut self) -> impl DerefMut + '_ { + self.0.lock().await } } @@ -229,11 +230,11 @@ fn lifetimes() { assert_eq!(cache.state.read().unwrap().path_to_config.len(), 2); } -#[test] -fn quiche_connect() { +#[tokio::test] +async fn quiche_connect() { use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; let mut config = Config::from_cert_path(None).unwrap(); let socket_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42)); let conn_id = quiche::ConnectionId::from_ref(&[]); - quiche::connect(None, &conn_id, socket_addr, &mut config.take()).unwrap(); + quiche::connect(None, &conn_id, socket_addr, config.take().await.deref_mut()).unwrap(); } diff --git a/doh/connection/driver.rs b/doh/connection/driver.rs new file mode 100644 index 00000000..4fd1e266 --- /dev/null +++ b/doh/connection/driver.rs @@ -0,0 +1,369 @@ +/* +* Copyright (C) 2021 The Android Open Source Project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +//! Defines a backing task to keep a HTTP/3 connection running + +use crate::boot_time; +use crate::boot_time::BootTime; +use log::warn; +use quiche::h3; +use std::collections::HashMap; +use std::default::Default; +use std::future; +use std::io; +use std::pin::Pin; +use thiserror::Error; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::{mpsc, oneshot}; + +#[derive(Error, Debug)] +pub enum Error { + #[error("network IO error: {0}")] + Network(#[from] io::Error), + #[error("QUIC error: {0}")] + Quic(#[from] quiche::Error), + #[error("HTTP/3 error: {0}")] + H3(#[from] h3::Error), + #[error("Response delivery error: {0}")] + StreamSend(#[from] mpsc::error::SendError), + #[error("Connection closed")] + Closed, +} + +pub type Result = std::result::Result; + +#[derive(Debug)] +/// HTTP/3 Request to be sent on the connection +pub struct Request { + /// Request headers + pub headers: Vec, + /// Expiry time for the request, relative to `CLOCK_BOOTTIME` + pub expiry: Option, + /// Channel to send the response to + pub response_tx: oneshot::Sender, +} + +#[derive(Debug)] +/// HTTP/3 Response +pub struct Stream { + /// Response headers + pub headers: Vec, + /// Response body + pub data: Vec, + /// Error code if stream was reset + pub error: Option, +} + +impl Stream { + fn new(headers: Vec) -> Self { + Self { headers, data: Vec::new(), error: None } + } +} + +const MAX_UDP_PACKET_SIZE: usize = 65536; + +struct Driver { + request_rx: mpsc::Receiver, + quiche_conn: Pin>, + socket: UdpSocket, + // This buffer is large, boxing it will keep it + // off the stack and prevent it being copied during + // moves of the driver. + buffer: Box<[u8; MAX_UDP_PACKET_SIZE]>, +} + +struct H3Driver { + driver: Driver, + // h3_conn sometimes can't "fit" a request in its available windows. + // This value holds a peeked request in that case, waiting for + // transmission to become possible. + buffered_request: Option, + // We can't check if a receiver is dead without potentially receiving a message, and if we poll + // on a dead receiver in a select! it will immediately return None. As a result, we need this + // to gate whether or not to include .recv() in our select! + closing: bool, + h3_conn: h3::Connection, + requests: HashMap, + streams: HashMap, +} + +async fn optional_timeout(timeout: Option) { + match timeout { + Some(timeout) => boot_time::sleep(timeout).await, + None => future::pending().await, + } +} + +/// Creates a future which when polled will handle events related to a HTTP/3 connection. +/// The returned error code will explain why the connection terminated. +pub async fn drive( + request_rx: mpsc::Receiver, + quiche_conn: Pin>, + socket: UdpSocket, +) -> Result<()> { + Driver::new(request_rx, quiche_conn, socket).drive().await +} + +impl Driver { + fn new( + request_rx: mpsc::Receiver, + quiche_conn: Pin>, + socket: UdpSocket, + ) -> Self { + Self { request_rx, quiche_conn, socket, buffer: Box::new([0; MAX_UDP_PACKET_SIZE]) } + } + + async fn drive(mut self) -> Result<()> { + // Prime connection + self.flush_tx().await?; + loop { + self = self.drive_once().await? + } + } + + fn handle_closed(&self) -> Result<()> { + if self.quiche_conn.is_closed() { + Err(Error::Closed) + } else { + Ok(()) + } + } + + async fn drive_once(mut self) -> Result { + let timer = optional_timeout(self.quiche_conn.timeout()); + select! { + // If a quiche timer would fire, call their callback + _ = timer => self.quiche_conn.on_timeout(), + // If we got packets from our peer, pass them to quiche + Ok((size, from)) = self.socket.recv_from(self.buffer.as_mut()) => { + self.quiche_conn.recv(&mut self.buffer[..size], quiche::RecvInfo { from })?; + } + }; + // Any of the actions in the select could require us to send packets to the peer + self.flush_tx().await?; + + // If the QUIC connection is live, but the HTTP/3 is not, try to bring it up + if self.quiche_conn.is_established() { + let h3_config = h3::Config::new()?; + let h3_conn = h3::Connection::with_transport(&mut self.quiche_conn, &h3_config)?; + return H3Driver::new(self, h3_conn).drive().await; + } + + // If the connection has closed, tear down + self.handle_closed()?; + + Ok(self) + } + + async fn flush_tx(&mut self) -> Result<()> { + let send_buf = self.buffer.as_mut(); + loop { + match self.quiche_conn.send(send_buf) { + Err(quiche::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + Ok((valid_len, send_info)) => { + self.socket.send_to(&send_buf[..valid_len], send_info.to).await?; + } + } + } + } +} + +impl H3Driver { + fn new(driver: Driver, h3_conn: h3::Connection) -> Self { + Self { + driver, + h3_conn, + closing: false, + requests: HashMap::new(), + streams: HashMap::new(), + buffered_request: None, + } + } + + async fn drive(mut self) -> Result { + loop { + self.drive_once().await?; + } + } + + async fn drive_once(&mut self) -> Result<()> { + // We can't call self.driver.drive_once at the same time as + // self.driver.request_rx.recv() due to ownership + let timer = optional_timeout(self.driver.quiche_conn.timeout()); + // If we've buffered a request (due to the connection being full) + // try to resend that first + if let Some(request) = self.buffered_request.take() { + self.handle_request(request)?; + } + select! { + // Only attempt to enqueue new requests if we have no buffered request and aren't + // closing + msg = self.driver.request_rx.recv(), if !self.closing && self.buffered_request.is_none() => match msg { + Some(request) => self.handle_request(request)?, + None => self.shutdown(true, b"DONE").await?, + }, + // If a quiche timer would fire, call their callback + _ = timer => self.driver.quiche_conn.on_timeout(), + // If we got packets from our peer, pass them to quiche + Ok((size, from)) = self.driver.socket.recv_from(self.driver.buffer.as_mut()) => { + self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from })?; + } + }; + + // Any of the actions in the select could require us to send packets to the peer + self.driver.flush_tx().await?; + + // Process any incoming HTTP/3 events + self.flush_h3().await?; + + // If the connection has closed, tear down + self.driver.handle_closed() + } + + fn handle_request(&mut self, request: Request) -> Result<()> { + // If the request has already timed out, don't issue it to the server. + if let Some(expiry) = request.expiry { + if BootTime::now() > expiry { + return Ok(()); + } + } + let stream_id = + // If h3_conn says the stream is blocked, this error is recoverable just by trying + // again once the stream has made progress. Buffer the request for a later retry. + match self.h3_conn.send_request(&mut self.driver.quiche_conn, &request.headers, true) { + Err(h3::Error::StreamBlocked) | Err(h3::Error::TransportError(quiche::Error::StreamLimit)) => { + // We only call handle_request on a value that has just come out of + // buffered_request, or when buffered_request is empty. This assert just + // validates that we don't break that assumption later, as it could result in + // requests being dropped on the floor under high load. + assert!(self.buffered_request.is_none()); + self.buffered_request = Some(request); + return Ok(()) + } + result => result?, + }; + self.requests.insert(stream_id, request); + Ok(()) + } + + async fn recv_body(&mut self, stream_id: u64) -> Result<()> { + const STREAM_READ_CHUNK: usize = 4096; + if let Some(stream) = self.streams.get_mut(&stream_id) { + loop { + let base_len = stream.data.len(); + stream.data.resize(base_len + STREAM_READ_CHUNK, 0); + match self.h3_conn.recv_body( + &mut self.driver.quiche_conn, + stream_id, + &mut stream.data[base_len..], + ) { + Err(h3::Error::Done) => { + stream.data.truncate(base_len); + return Ok(()); + } + Err(e) => { + stream.data.truncate(base_len); + return Err(e.into()); + } + Ok(recvd) => stream.data.truncate(base_len + recvd), + } + } + } else { + warn!("Received body for untracked stream ID {}", stream_id); + } + Ok(()) + } + + fn discard_datagram(&mut self, _flow_id: u64) -> Result<()> { + loop { + match self.h3_conn.recv_dgram(&mut self.driver.quiche_conn, self.driver.buffer.as_mut()) + { + Err(h3::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + _ => (), + } + } + } + + async fn flush_h3(&mut self) -> Result<()> { + loop { + match self.h3_conn.poll(&mut self.driver.quiche_conn) { + Err(h3::Error::Done) => return Ok(()), + Err(e) => return Err(e.into()), + Ok((stream_id, event)) => self.process_h3_event(stream_id, event).await?, + } + } + } + + async fn process_h3_event(&mut self, stream_id: u64, event: h3::Event) -> Result<()> { + if !self.requests.contains_key(&stream_id) { + warn!("Received event {:?} for stream_id {} without a request.", event, stream_id); + } + match event { + h3::Event::Headers { list, has_body } => { + let stream = Stream::new(list); + if self.streams.insert(stream_id, stream).is_some() { + warn!("Re-using stream ID {} before it was completed.", stream_id) + } + if !has_body { + self.respond(stream_id); + } + } + h3::Event::Data => { + self.recv_body(stream_id).await?; + } + h3::Event::Finished => self.respond(stream_id), + // This clause is for quiche 0.10.x, we're still on 0.9.x + //h3::Event::Reset(e) => { + // self.streams.get_mut(&stream_id).map(|stream| stream.error = Some(e)); + // self.respond(stream_id); + //} + h3::Event::Datagram => { + warn!("Unexpected Datagram received"); + // We don't care if something went wrong with the datagram, we didn't + // want it anyways. + let _ = self.discard_datagram(stream_id); + } + h3::Event::GoAway => self.shutdown(false, b"SERVER GOAWAY").await?, + } + Ok(()) + } + + async fn shutdown(&mut self, send_goaway: bool, msg: &[u8]) -> Result<()> { + self.driver.request_rx.close(); + while self.driver.request_rx.recv().await.is_some() {} + self.closing = true; + if send_goaway { + self.h3_conn.send_goaway(&mut self.driver.quiche_conn, 0)?; + } + if self.driver.quiche_conn.close(true, 0, msg).is_err() { + warn!("Trying to close already closed QUIC connection"); + } + Ok(()) + } + + fn respond(&mut self, stream_id: u64) { + match (self.streams.remove(&stream_id), self.requests.remove(&stream_id)) { + (Some(stream), Some(request)) => { + // We don't care about the error, because it means the requestor has left. + let _ = request.response_tx.send(stream); + } + (None, _) => warn!("Tried to deliver untracked stream {}", stream_id), + (_, None) => warn!("Tried to deliver stream {} to untracked requestor", stream_id), + } + } +} diff --git a/doh/connection/mod.rs b/doh/connection/mod.rs new file mode 100644 index 00000000..bc5b75c5 --- /dev/null +++ b/doh/connection/mod.rs @@ -0,0 +1,150 @@ +/* +* Copyright (C) 2021 The Android Open Source Project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +//! Module providing an async abstraction around a quiche HTTP/3 connection + +use crate::boot_time::BootTime; +use crate::network::SocketTagger; +use log::error; +use quiche::h3; +use std::future::Future; +use std::io; +use std::net::SocketAddr; +use thiserror::Error; +use tokio::net::UdpSocket; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task; + +mod driver; + +pub use driver::Stream; +use driver::{drive, Request}; + +/// Quiche HTTP/3 connection +pub struct Connection { + request_tx: mpsc::Sender, +} + +fn new_scid() -> [u8; quiche::MAX_CONN_ID_LEN] { + use ring::rand::{SecureRandom, SystemRandom}; + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + SystemRandom::new().fill(&mut scid).unwrap(); + scid +} + +fn mark_socket(socket: &std::net::UdpSocket, socket_mark: u32) -> io::Result<()> { + use std::os::unix::io::AsRawFd; + let fd = socket.as_raw_fd(); + // libc::setsockopt is a wrapper function calling into bionic setsockopt. + // The only pointer being passed in is &socket_mark, which is valid by virtue of being a + // reference, and the foreign function doesn't take ownership or a reference to that memory + // after completion. + if unsafe { + libc::setsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_MARK, + &socket_mark as *const _ as *const libc::c_void, + std::mem::size_of::() as libc::socklen_t, + ) + } == 0 + { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } +} + +async fn build_socket( + peer_addr: SocketAddr, + socket_mark: u32, + tag_socket: &SocketTagger, +) -> io::Result { + let bind_addr = match peer_addr { + SocketAddr::V4(_) => "0.0.0.0:0", + SocketAddr::V6(_) => "[::]:0", + }; + + let socket = UdpSocket::bind(bind_addr).await?; + let std_socket = socket.into_std()?; + mark_socket(&std_socket, socket_mark) + .unwrap_or_else(|e| error!("Unable to mark socket : {:?}", e)); + tag_socket(&std_socket).await; + let socket = UdpSocket::from_std(std_socket)?; + socket.connect(peer_addr).await?; + Ok(socket) +} + +/// Error type for HTTP/3 connection +#[derive(Debug, Error)] +pub enum Error { + /// QUIC protocol error + #[error("QUIC error: {0}")] + Quic(#[from] quiche::Error), + /// HTTP/3 protocol error + #[error("HTTP/3 error: {0}")] + H3(#[from] h3::Error), + /// Unable to send the request to the driver. This likely means the + /// backing task has died. + #[error("Unable to send request")] + SendRequest(#[from] mpsc::error::SendError), + /// IO failed. This is most likely to occur while trying to set up the + /// UDP socket for use by the connection. + #[error("IO error: {0}")] + Io(#[from] io::Error), + /// The request is no longer being serviced. This could mean that the + /// request was dropped for an unspecified reason, or that the connection + /// was closed prematurely and it can no longer be serviced. + #[error("Driver dropped request")] + RecvResponse(#[from] oneshot::error::RecvError), +} + +/// Common result type for working with a HTTP/3 connection +pub type Result = std::result::Result; + +impl Connection { + const MAX_PENDING_REQUESTS: usize = 10; + /// Create a new connection with a background task handling IO. + pub async fn new( + server_name: Option<&str>, + to: SocketAddr, + socket_mark: u32, + tag_socket: &SocketTagger, + config: &mut quiche::Config, + ) -> Result { + let (request_tx, request_rx) = mpsc::channel(Self::MAX_PENDING_REQUESTS); + let scid = new_scid(); + let quiche_conn = + quiche::connect(server_name, &quiche::ConnectionId::from_ref(&scid), to, config)?; + let socket = build_socket(to, socket_mark, tag_socket).await?; + let driver = drive(request_rx, quiche_conn, socket); + task::spawn(driver); + Ok(Self { request_tx }) + } + + /// Send a query, produce a future which will provide a response. + /// The future is separately returned rather than awaited to allow it to be waited on without + /// keeping the `Connection` itself borrowed. + pub async fn query( + &self, + headers: Vec, + expiry: Option, + ) -> Result>> { + let (response_tx, response_rx) = oneshot::channel(); + self.request_tx.send(Request { headers, response_tx, expiry }).await?; + Ok(async move { response_rx.await.ok() }) + } +} diff --git a/doh/dispatcher/driver.rs b/doh/dispatcher/driver.rs new file mode 100644 index 00000000..75f456b6 --- /dev/null +++ b/doh/dispatcher/driver.rs @@ -0,0 +1,127 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Provides a backing task to implement a Dispatcher + +use crate::boot_time::{BootTime, Duration}; +use anyhow::{bail, Result}; +use log::{debug, trace, warn}; +use std::collections::HashMap; +use tokio::sync::{mpsc, oneshot}; + +use super::{Command, QueryError, Response}; +use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter}; +use crate::{config, network}; + +pub struct Driver { + command_rx: mpsc::Receiver, + networks: HashMap, + validation: ValidationReporter, + tagger: SocketTagger, + config_cache: config::Cache, +} + +fn debug_err(r: Result<()>) { + if let Err(e) = r { + debug!("Dispatcher loop got {:?}", e); + } +} + +impl Driver { + pub fn new( + command_rx: mpsc::Receiver, + validation: ValidationReporter, + tagger: SocketTagger, + ) -> Self { + Self { + command_rx, + networks: HashMap::new(), + validation, + tagger, + config_cache: config::Cache::new(), + } + } + + pub async fn drive(mut self) -> Result<()> { + loop { + self.drive_once().await? + } + } + + async fn drive_once(&mut self) -> Result<()> { + if let Some(command) = self.command_rx.recv().await { + trace!("dispatch command: {:?}", command); + match command { + Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await), + Command::Query { net_id, base64_query, expired_time, resp } => { + debug_err(self.query(net_id, base64_query, expired_time, resp).await) + } + Command::Clear { net_id } => { + self.networks.remove(&net_id); + self.config_cache.garbage_collect(); + } + Command::Exit => { + bail!("Death due to Exit") + } + } + Ok(()) + } else { + bail!("Death due to command_tx dying") + } + } + + async fn query( + &mut self, + net_id: u32, + query: String, + expiry: BootTime, + response: oneshot::Sender, + ) -> Result<()> { + if let Some(network) = self.networks.get_mut(&net_id) { + network.query(network::Query { query, response, expiry }).await?; + } else { + warn!("Tried to send a query to non-existent network net_id={}", net_id); + response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| { + warn!("Unable to send reply for non-existent network net_id={}", net_id); + }) + } + Ok(()) + } + + async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> { + use std::collections::hash_map::Entry; + if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) { + // If we have a network registered to the provided net_id, but the server info doesn't + // match, our API has been used incorrectly. Attempt to recover by deleting the old + // network and recreating it according to the probe request. + warn!("Probing net_id={} with mismatched server info", info.net_id); + self.networks.remove(&info.net_id); + } + // Can't use or_insert_with because creating a network may fail + let net = match self.networks.entry(info.net_id) { + Entry::Occupied(network) => network.into_mut(), + Entry::Vacant(vacant) => { + let config = self.config_cache.from_cert_path(&info.cert_path)?; + vacant.insert( + Network::new(info, config, self.validation.clone(), self.tagger.clone()) + .await?, + ) + } + }; + net.probe(timeout).await?; + Ok(()) + } +} diff --git a/doh/dispatcher/mod.rs b/doh/dispatcher/mod.rs new file mode 100644 index 00000000..66e2f3d5 --- /dev/null +++ b/doh/dispatcher/mod.rs @@ -0,0 +1,110 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use crate::boot_time::{BootTime, Duration}; +use anyhow::Result; +use log::error; +use tokio::runtime::{Builder, Runtime}; +use tokio::sync::{mpsc, oneshot}; +use tokio::task; + +pub use crate::network::{ServerInfo, SocketTagger, ValidationReporter}; + +const MAX_BUFFERED_CMD_COUNT: usize = 400; + +mod driver; +use driver::Driver; + +#[derive(Eq, PartialEq, Debug)] +/// Error response to a query +pub enum QueryError { + /// Network failed probing + BrokenServer, + /// HTTP/3 connection died + ConnectionError, + /// Network not probed yet + ServerNotReady, + /// Server reset HTTP/3 stream + Reset(u64), + /// Tried to query non-existent network + Unexpected, +} + +#[derive(Eq, PartialEq, Debug)] +pub enum Response { + Error { error: QueryError }, + Success { answer: Vec }, +} + +#[derive(Debug)] +pub enum Command { + Probe { + info: ServerInfo, + timeout: Duration, + }, + Query { + net_id: u32, + base64_query: String, + expired_time: BootTime, + resp: oneshot::Sender, + }, + Clear { + net_id: u32, + }, + Exit, +} + +/// Context for a running DoH engine. +pub struct Dispatcher { + /// Used to submit cmds to the I/O task. + cmd_sender: mpsc::Sender, + join_handle: task::JoinHandle>, + runtime: Runtime, +} + +impl Dispatcher { + const DOH_THREADS: usize = 2; + + pub fn new(validation: ValidationReporter, tagger: SocketTagger) -> Result { + let (cmd_sender, cmd_receiver) = mpsc::channel::(MAX_BUFFERED_CMD_COUNT); + let runtime = Builder::new_multi_thread() + .worker_threads(Self::DOH_THREADS) + .enable_all() + .thread_name("doh-handler") + .build()?; + let join_handle = runtime.spawn(async { + let result = Driver::new(cmd_receiver, validation, tagger).drive().await; + match result { + Err(ref e) => error!("Dispatcher driver exited due to {:?}", e), + Ok(()) => (), + } + result + }); + Ok(Dispatcher { cmd_sender, join_handle, runtime }) + } + + pub fn send_cmd(&self, cmd: Command) -> Result<()> { + self.cmd_sender.blocking_send(cmd)?; + Ok(()) + } + + pub fn exit_handler(&mut self) { + if self.cmd_sender.blocking_send(Command::Exit).is_err() { + return; + } + let _ = self.runtime.block_on(&mut self.join_handle); + } +} diff --git a/doh/doh.rs b/doh/doh.rs index 0e55d422..5ceb5e81 100644 --- a/doh/doh.rs +++ b/doh/doh.rs @@ -16,1045 +16,10 @@ //! DoH backend for the Android DnsResolver module. -use anyhow::{anyhow, bail, Context, Result}; -use futures::future::{join_all, BoxFuture}; -use futures::stream::FuturesUnordered; -use futures::StreamExt; -use log::{debug, error, info, trace, warn}; -use quiche::h3; -use ring::rand::SecureRandom; -use std::collections::HashMap; -use std::net::SocketAddr; -use std::os::unix::io::{AsRawFd, RawFd}; -use std::pin::Pin; -use std::sync::Arc; -use tokio::net::UdpSocket; -use tokio::runtime::{Builder, Runtime}; -use tokio::sync::{mpsc, oneshot}; -use tokio::task; -use url::Url; - pub mod boot_time; mod config; +mod connection; +mod dispatcher; mod encoding; mod ffi; - -use boot_time::{timeout, BootTime, Duration}; -use config::Config; - -const MAX_BUFFERED_CMD_SIZE: usize = 400; -const DOH_PORT: u16 = 443; - -type ValidationReporter = Box BoxFuture<()> + Send + Sync>; -type SocketTagger = Arc BoxFuture<()> + Send + Sync>; - -type SCID = [u8; quiche::MAX_CONN_ID_LEN]; -type Base64Query = String; -type CmdSender = mpsc::Sender; -type CmdReceiver = mpsc::Receiver; -type QueryResponder = oneshot::Sender; -type DnsRequest = Vec; - -#[derive(Eq, PartialEq, Debug)] -enum QueryError { - BrokenServer, - ConnectionError, - ServerNotReady, - Unexpected, -} - -#[derive(Eq, PartialEq, Debug, Clone)] -struct ServerInfo { - net_id: u32, - url: Url, - peer_addr: SocketAddr, - domain: Option, - sk_mark: u32, - cert_path: Option, -} - -#[derive(Eq, PartialEq, Debug)] -enum Response { - Error { error: QueryError }, - Success { answer: Vec }, -} - -#[derive(Debug)] -enum DohCommand { - Probe { info: ServerInfo, timeout: Duration }, - Query { net_id: u32, base64_query: Base64Query, expired_time: BootTime, resp: QueryResponder }, - Clear { net_id: u32 }, - Exit, -} - -#[allow(clippy::large_enum_variant)] -enum ConnectionState { - Idle, - Connecting { - quic_conn: Option>>, - udp_sk: Option, - expired_time: Option, - }, - Connected { - quic_conn: Pin>, - udp_sk: UdpSocket, - h3_conn: Option, - query_map: HashMap, QueryResponder)>, - expired_time: Option, - }, - /// Indicate that the Connection can't be used due to - /// network or unexpected reasons. - Error, -} - -impl ConnectionState { - fn is_connected(&self) -> bool { - matches!(*self, Self::Connected { .. }) - } - fn is_error(&self) -> bool { - matches!(*self, Self::Error) - } -} - -enum H3Result { - Data { data: Vec }, - Finished, - Ignore, -} - -/// Context for a running DoH engine. -pub struct DohDispatcher { - /// Used to submit cmds to the I/O task. - cmd_sender: CmdSender, - join_handle: task::JoinHandle>, - runtime: Runtime, -} - -// DoH dispatcher -impl DohDispatcher { - fn new(validation: ValidationReporter, tag_socket: SocketTagger) -> Result { - let (cmd_sender, cmd_receiver) = mpsc::channel::(MAX_BUFFERED_CMD_SIZE); - let runtime = Builder::new_multi_thread() - .worker_threads(2) - .enable_all() - .thread_name("doh-handler") - .build() - .expect("Failed to create tokio runtime"); - let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation, tag_socket)); - Ok(DohDispatcher { cmd_sender, join_handle, runtime }) - } - - fn send_cmd(&self, cmd: DohCommand) -> Result<()> { - self.cmd_sender.blocking_send(cmd)?; - Ok(()) - } - - fn exit_handler(&mut self) { - if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() { - return; - } - let _ = self.runtime.block_on(&mut self.join_handle); - } -} - -struct DohConnection { - info: ServerInfo, - config: Config, - scid: SCID, - state: ConnectionState, - pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>, - cached_session: Option>, - tag_socket: SocketTagger, -} - -impl DohConnection { - fn new(info: &ServerInfo, config: Config, tag_socket: SocketTagger) -> Result { - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; - Ok(DohConnection { - info: info.clone(), - config, - scid, - state: ConnectionState::Idle, - pending_queries: Vec::new(), - cached_session: None, - tag_socket, - }) - } - - async fn state_to_connecting(&mut self) -> Result<()> { - if self.state.is_error() { - self.state_to_idle(); - } - self.state = match self.state { - ConnectionState::Idle => { - let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; - (self.tag_socket)(&udp_sk_std).await; - let udp_sk = UdpSocket::from_std(udp_sk_std)?; - let connid = quiche::ConnectionId::from_ref(&self.scid); - debug!("init the connection for Network {}", self.info.net_id); - let mut quic_conn = quiche::connect( - self.info.domain.as_deref(), - &connid, - self.info.peer_addr, - &mut self.config.take(), - )?; - if let Some(session) = &self.cached_session { - if quic_conn.set_session(session).is_err() { - warn!("can't restore session for network {}", self.info.net_id); - } - } - ConnectionState::Connecting { - quic_conn: Some(quic_conn), - udp_sk: Some(udp_sk), - expired_time: None, - } - } - ConnectionState::Error => panic!("state_to_idle did not transition"), - ConnectionState::Connecting { .. } => return Ok(()), - ConnectionState::Connected { .. } => { - panic!("Invalid state transition to Connecting state!") - } - }; - Ok(()) - } - - fn state_to_connected(&mut self) -> Result<()> { - self.state = match &mut self.state { - // Only Connecting -> Connected is valid. - ConnectionState::Connecting { quic_conn, udp_sk, .. } => { - if let (Some(mut quic_conn), Some(udp_sk)) = (quic_conn.take(), udp_sk.take()) { - let h3_config = h3::Config::new()?; - let h3_conn = - quiche::h3::Connection::with_transport(&mut quic_conn, &h3_config)?; - ConnectionState::Connected { - quic_conn, - udp_sk, - h3_conn: Some(h3_conn), - query_map: HashMap::new(), - expired_time: None, - } - } else { - bail!("state transition fail!"); - } - } - // The rest should fail. - _ => panic!("Invalid state transition to Connected state!"), - }; - Ok(()) - } - - fn state_to_idle(&mut self) { - self.state = match self.state { - // Only either Connected or Error -> Idle is valid. - // TODO: Error -> Idle is the re-probing case, add the relevant statistic. - ConnectionState::Connected { .. } | ConnectionState::Error => ConnectionState::Idle, - // The rest should fail. - _ => panic!("Invalid state transition to Idle state!"), - } - } - - fn state_to_error(&mut self) { - self.pending_queries.clear(); - self.state = ConnectionState::Error - } - - fn is_reprobe_required(&self) -> bool { - matches!(self.state, ConnectionState::Error) - } - - fn has_not_handled_queries(&self) -> bool { - match &self.state { - ConnectionState::Connecting { .. } | ConnectionState::Idle => { - !self.pending_queries.is_empty() - } - ConnectionState::Connected { query_map, .. } => { - !query_map.is_empty() || !self.pending_queries.is_empty() - } - _ => false, - } - } - - fn handle_if_connection_expired(&mut self) { - let expired_time = match &mut self.state { - ConnectionState::Connecting { expired_time, .. } => expired_time, - ConnectionState::Connected { expired_time, .. } => expired_time, - // ignore - _ => return, - }; - - if let Some(expired_time) = expired_time { - if let Some(elapsed) = BootTime::now().checked_duration_since(*expired_time) { - warn!( - "Change the state to Idle due to connection timeout, {:?}, {}", - elapsed, self.info.net_id - ); - self.state_to_idle(); - } - } - } - - async fn probe(&mut self, t: Duration) -> Result<()> { - match timeout(t, async { - self.try_connect().await?; - info!("probe start for {}", self.info.net_id); - if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, expired_time, .. } = - &mut self.state - { - let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; - let req = match encoding::probe_query() { - Ok(q) => match encoding::dns_request(&q, &self.info.url) { - Ok(req) => req, - Err(e) => bail!(e), - }, - Err(e) => bail!(e), - }; - // Send the probe query. - let req_id = h3_conn.send_request(quic_conn, &req, true /*fin*/)?; - loop { - flush_tx(quic_conn, udp_sk).await?; - recv_rx(quic_conn, udp_sk, expired_time).await?; - loop { - match recv_h3(quic_conn, h3_conn) { - Ok((stream_id, H3Result::Finished)) => { - if stream_id == req_id { - return Ok(()); - } - } - // TODO: Verify the answer - Ok((_stream_id, H3Result::Data { .. })) => {} - Ok((_stream_id, H3Result::Ignore)) => {} - Err(_) => break, - } - } - } - } else { - bail!("state error while performing probe()"); - } - }) - .await - { - Ok(v) => match v { - Ok(_) => Ok(()), - Err(e) => { - self.state_to_error(); - bail!(e); - } - }, - Err(e) => { - self.state_to_error(); - bail!(e); - } - } - } - - async fn try_connect(&mut self) -> Result<()> { - if matches!(self.state, ConnectionState::Connected { .. }) { - return Ok(()); - } - self.state_to_connecting().await?; - debug!("connecting to Network {}", self.info.net_id); - - let (quic_conn, udp_sk, expired_time) = match &mut self.state { - ConnectionState::Connecting { quic_conn, udp_sk, expired_time, .. } => { - if let (Some(quic_conn), Some(udp_sk)) = (quic_conn.as_mut(), udp_sk.as_mut()) { - (quic_conn, udp_sk, expired_time) - } else { - bail!("unexpected error while performing connect()"); - } - } - _ => bail!("state error while performing try_connect()"), - }; - - while !quic_conn.is_established() { - flush_tx(quic_conn, udp_sk).await?; - recv_rx(quic_conn, udp_sk, expired_time).await?; - } - self.cached_session = quic_conn.session(); - self.state_to_connected()?; - info!("connected to Network {}", self.info.net_id); - Ok(()) - } - - async fn try_send_doh_query( - &mut self, - req: DnsRequest, - resp: QueryResponder, - expired_time: BootTime, - ) -> Result<()> { - self.handle_if_connection_expired(); - match &mut self.state { - ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, .. } => { - let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; - send_dns_query( - quic_conn, - udp_sk, - h3_conn, - query_map, - &mut self.pending_queries, - resp, - expired_time, - req, - ) - .await? - } - ConnectionState::Connecting { .. } | ConnectionState::Idle => { - self.pending_queries.push((req, resp, expired_time)) - } - ConnectionState::Error => { - error!( - "state is error while performing try_send_doh_query(), network: {}", - self.info.net_id - ); - let _ = resp.send(Response::Error { error: QueryError::BrokenServer }); - } - } - Ok(()) - } - - async fn process_queries(&mut self) -> Result<()> { - debug!("process_queries entry, Network {}", self.info.net_id); - self.try_connect().await?; - if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, expired_time } = - &mut self.state - { - let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; - loop { - while !self.pending_queries.is_empty() { - if let Some((req, resp, exp_time)) = self.pending_queries.pop() { - // Ignore the expired queries. - if BootTime::now().checked_duration_since(exp_time).is_some() { - warn!("Drop the obsolete query for network {}", self.info.net_id); - continue; - } - send_dns_query( - quic_conn, - udp_sk, - h3_conn, - query_map, - &mut self.pending_queries, - resp, - exp_time, - req, - ) - .await?; - } - } - flush_tx(quic_conn, udp_sk).await?; - recv_rx(quic_conn, udp_sk, expired_time).await?; - loop { - match recv_h3(quic_conn, h3_conn) { - Ok((stream_id, H3Result::Data { mut data })) => { - if let Some((answer, _)) = query_map.get_mut(&stream_id) { - answer.append(&mut data); - } else { - // Should not happen - warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.info.net_id, stream_id); - } - } - Ok((stream_id, H3Result::Finished)) => { - if let Some((answer, resp)) = query_map.remove(&stream_id) { - debug!( - "sending answer back to resolv, Network {}, stream id: {}", - self.info.net_id, stream_id - ); - resp.send(Response::Success { answer }).unwrap_or_else(|e| { - trace!( - "the receiver dropped {:?}, stream id: {}", - e, - stream_id - ); - }); - } else { - // Should not happen - warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.info.net_id, stream_id); - } - } - Ok((_stream_id, H3Result::Ignore)) => {} - Err(_) => break, - } - } - if quic_conn.is_closed() || !quic_conn.is_established() { - self.state_to_idle(); - bail!("connection become idle"); - } - } - } else { - self.state_to_error(); - bail!("state error while performing process_queries(), network: {}", self.info.net_id); - } - } -} - -fn recv_h3( - quic_conn: &mut Pin>, - h3_conn: &mut h3::Connection, -) -> Result<(u64, H3Result)> { - match h3_conn.poll(quic_conn) { - // Process HTTP/3 events. - Ok((stream_id, quiche::h3::Event::Data)) => { - debug!("quiche::h3::Event::Data"); - let mut buf = vec![0; config::MAX_DATAGRAM_SIZE]; - match h3_conn.recv_body(quic_conn, stream_id, &mut buf) { - Ok(read) => { - trace!( - "got {} bytes of response data on stream {}: {:x?}", - read, - stream_id, - &buf[..read] - ); - buf.truncate(read); - Ok((stream_id, H3Result::Data { data: buf })) - } - Err(e) => { - warn!("recv_h3::recv_body {:?}", e); - bail!(e); - } - } - } - Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => { - trace!( - "got response headers {:?} on stream id {} has_body {}", - list, - stream_id, - has_body - ); - Ok((stream_id, H3Result::Ignore)) - } - Ok((stream_id, quiche::h3::Event::Finished)) => { - debug!("quiche::h3::Event::Finished on stream id {}", stream_id); - Ok((stream_id, H3Result::Finished)) - } - Ok((stream_id, quiche::h3::Event::Datagram)) => { - debug!("quiche::h3::Event::Datagram on stream id {}", stream_id); - Ok((stream_id, H3Result::Ignore)) - } - // TODO: Check if it's necessary to handle GoAway event. - Ok((stream_id, quiche::h3::Event::GoAway)) => { - debug!("quiche::h3::Event::GoAway on stream id {}", stream_id); - Ok((stream_id, H3Result::Ignore)) - } - Err(e) => { - debug!("recv_h3 {:?}", e); - bail!(e); - } - } -} - -#[allow(clippy::too_many_arguments)] -async fn send_dns_query( - quic_conn: &mut Pin>, - udp_sk: &mut UdpSocket, - h3_conn: &mut h3::Connection, - query_map: &mut HashMap, QueryResponder)>, - pending_queries: &mut Vec<(DnsRequest, QueryResponder, BootTime)>, - resp: QueryResponder, - expired_time: BootTime, - req: DnsRequest, -) -> Result<()> { - if !quic_conn.is_established() { - bail!("quic connection is not ready"); - } - match h3_conn.send_request(quic_conn, &req, true /*fin*/) { - Ok(stream_id) => { - query_map.insert(stream_id, (Vec::new(), resp)); - flush_tx(quic_conn, udp_sk).await?; - debug!("send dns query successfully stream id: {}", stream_id); - Ok(()) - } - Err(quiche::h3::Error::StreamBlocked) => { - warn!("try to send query but error on StreamBlocked"); - pending_queries.push((req, resp, expired_time)); - Ok(()) - } - Err(e) => { - resp.send(Response::Error { error: QueryError::ConnectionError }).ok(); - bail!(e); - } - } -} - -async fn recv_rx( - quic_conn: &mut Pin>, - udp_sk: &mut UdpSocket, - expired_time: &mut Option, -) -> Result<()> { - // TODO: Evaluate if we could make the buffer smaller. - let mut buf = [0; 65535]; - let quic_idle_timeout_ms = Duration::from_millis(config::QUICHE_IDLE_TIMEOUT_MS); - let ts = quic_conn.timeout().unwrap_or(quic_idle_timeout_ms); - - if let Some(next_expired) = BootTime::now().checked_add(quic_idle_timeout_ms) { - expired_time.replace(next_expired); - } else { - expired_time.take(); - } - debug!("recv_rx entry next timeout {:?} {:?}", ts, expired_time); - match timeout(ts, udp_sk.recv_from(&mut buf)).await { - Ok(v) => match v { - Ok((size, from)) => { - let recv_info = quiche::RecvInfo { from }; - let processed = match quic_conn.recv(&mut buf[..size], recv_info) { - Ok(l) => l, - Err(e) => { - debug!("recv_rx error {:?}", e); - bail!("quic recv failed: {:?}", e); - } - }; - debug!("processed {} bytes", processed); - Ok(()) - } - Err(e) => bail!("socket recv failed: {:?}", e), - }, - Err(_) => { - warn!("timeout did not receive value within {:?}", ts); - quic_conn.on_timeout(); - Ok(()) - } - } -} - -async fn flush_tx( - quic_conn: &mut Pin>, - udp_sk: &mut UdpSocket, -) -> Result<()> { - let mut out = [0; config::MAX_DATAGRAM_SIZE]; - loop { - let (write, _) = match quic_conn.send(&mut out) { - Ok(v) => v, - Err(quiche::Error::Done) => { - debug!("done writing"); - break; - } - Err(e) => { - quic_conn.close(false, 0x1, b"fail").ok(); - bail!(e); - } - }; - udp_sk.send(&out[..write]).await?; - debug!("written {}", write); - } - Ok(()) -} - -async fn handle_probe_result( - result: (ServerInfo, Result), - doh_conn_map: &mut HashMap)>, - validation: &ValidationReporter, -) { - let (info, doh_conn) = match result { - (info, Ok(doh_conn)) => { - info!("probing_task success on net_id: {}", info.net_id); - (info, doh_conn) - } - (info, Err((e, doh_conn))) => { - error!("probe failed on network {}, {:?}", e, info.net_id); - (info, doh_conn) - // TODO: Retry probe? - } - }; - // If the network is removed or the server is replaced before probing, - // ignore the probe result. - match doh_conn_map.get(&info.net_id) { - Some((server_info, _)) => { - if *server_info != info { - warn!( - "The previous configuration for network {} was replaced before probe finished", - info.net_id - ); - return; - } - } - _ => { - warn!("network {} was removed before probe finished", info.net_id); - return; - } - } - validation(&info, doh_conn.state.is_connected()).await; - doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); -} - -async fn probe_task( - info: ServerInfo, - mut doh: DohConnection, - t: Duration, -) -> (ServerInfo, Result) { - match doh.probe(t).await { - Ok(_) => (info, Ok(doh)), - Err(e) => (info, Err((anyhow!(e), doh))), - } -} - -fn make_connection_if_needed( - info: &ServerInfo, - doh_conn_map: &mut HashMap)>, - config_cache: &config::Cache, - tag_socket: SocketTagger, -) -> Result> { - // Check if connection exists. - match doh_conn_map.get(&info.net_id) { - // The connection exists but has failed. Re-probe. - Some((server_info, Some(doh))) if *server_info == *info && doh.is_reprobe_required() => { - let (_, doh) = doh_conn_map - .insert(info.net_id, (info.clone(), None)) - .ok_or_else(|| anyhow!("unexpected error, missing connection"))?; - return Ok(doh); - } - // The connection exists or the connection is under probing, ignore. - Some((server_info, _)) if *server_info == *info => return Ok(None), - // TODO: change the inner connection instead of removing? - _ => doh_conn_map.remove(&info.net_id), - }; - let config = config_cache.from_cert_path(&info.cert_path)?; - let doh = DohConnection::new(info, config, tag_socket)?; - doh_conn_map.insert(info.net_id, (info.clone(), None)); - Ok(Some(doh)) -} - -async fn handle_query_cmd( - net_id: u32, - base64_query: Base64Query, - expired_time: BootTime, - resp: QueryResponder, - doh_conn_map: &mut HashMap)>, -) { - if let Some((info, quic_conn)) = doh_conn_map.get_mut(&net_id) { - match (&info.domain, quic_conn) { - // Connection is not ready, strict mode - (Some(_), None) => { - let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); - } - // Connection is not ready, Opportunistic mode - (None, None) => { - let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); - } - // Connection is ready - (_, Some(quic_conn)) => { - if let Ok(req) = encoding::dns_request(&base64_query, &info.url) { - let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await; - } else { - let _ = resp.send(Response::Error { error: QueryError::Unexpected }); - } - } - } - } else { - error!("No connection is associated with the given net id {}", net_id); - let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); - } -} -fn need_process_queries(doh_conn_map: &HashMap)>) -> bool { - if doh_conn_map.is_empty() { - return false; - } - for (_, doh_conn) in doh_conn_map.values() { - if let Some(doh_conn) = doh_conn { - if doh_conn.has_not_handled_queries() { - return true; - } - } - } - false -} - -async fn doh_handler( - mut cmd_rx: CmdReceiver, - validation: ValidationReporter, - tag_socket: SocketTagger, -) -> Result<()> { - info!("doh_dispatcher entry"); - let config_cache = config::Cache::new(); - - // Currently, only support 1 server per network. - let mut doh_conn_map: HashMap)> = HashMap::new(); - let mut probe_futures = FuturesUnordered::new(); - loop { - tokio::select! { - _ = async { - let mut futures = vec![]; - for (_, doh_conn) in doh_conn_map.values_mut() { - if let Some(doh_conn) = doh_conn { - futures.push(doh_conn.process_queries()); - } - } - join_all(futures).await - }, if need_process_queries(&doh_conn_map) => {}, - Some(result) = probe_futures.next() => { - handle_probe_result(result, &mut doh_conn_map, &validation).await; - info!("probe_futures remaining size: {}", probe_futures.len()); - }, - Some(cmd) = cmd_rx.recv() => { - trace!("recv {:?}", cmd); - match cmd { - DohCommand::Probe { info, timeout: t } => { - match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket.clone()) { - Ok(Some(doh)) => { - // Create a new async task associated to the DoH connection. - probe_futures.push(probe_task(info, doh, t)); - debug!("probe_futures size: {}", probe_futures.len()); - } - Ok(None) => { - // No further probe is needed. - warn!("connection for network {} already exists", info.net_id); - // TODO: Report the status again? - } - Err(e) => { - error!("create connection for network {} error {:?}", info.net_id, e); - validation(&info, false).await - } - } - }, - DohCommand::Query { net_id, base64_query, expired_time, resp } => { - handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map).await; - }, - DohCommand::Clear { net_id } => { - doh_conn_map.remove(&net_id); - info!("Doh Clear server for netid: {}", net_id); - config_cache.garbage_collect(); - }, - DohCommand::Exit => return Ok(()), - } - } - } - } -} - -fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result { - let bind_addr = match peer_addr { - std::net::SocketAddr::V4(_) => "0.0.0.0:0", - std::net::SocketAddr::V6(_) => "[::]:0", - }; - let udp_sk = std::net::UdpSocket::bind(bind_addr)?; - udp_sk.set_nonblocking(true)?; - if mark_socket(udp_sk.as_raw_fd(), mark).is_err() { - warn!("Mark socket failed, is it a test?"); - } - udp_sk.connect(peer_addr)?; - - trace!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?); - Ok(udp_sk) -} - -fn mark_socket(fd: RawFd, mark: u32) -> Result<()> { - // libc::setsockopt is a wrapper function calling into bionic setsockopt. - // Both fd and mark are valid, which makes the function call mostly safe. - if unsafe { - libc::setsockopt( - fd, - libc::SOL_SOCKET, - libc::SO_MARK, - &mark as *const _ as *const libc::c_void, - std::mem::size_of::() as libc::socklen_t, - ) - } == 0 - { - Ok(()) - } else { - Err(anyhow::Error::new(std::io::Error::last_os_error())) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use futures::FutureExt; - use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; - - const TEST_NET_ID: u32 = 50; - const TEST_MARK: u32 = 0xD0033; - const LOOPBACK_ADDR: &str = "127.0.0.1:443"; - const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; - - // TODO: Make some tests for DohConnection. - - fn make_testing_variables( - ) -> (ServerInfo, HashMap)>, config::Cache, Runtime) - { - let test_map: HashMap)> = HashMap::new(); - let info = ServerInfo { - net_id: TEST_NET_ID, - url: Url::parse(LOCALHOST_URL).unwrap(), - peer_addr: LOOPBACK_ADDR.parse().unwrap(), - domain: None, - sk_mark: 0, - cert_path: None, - }; - let config_cache = config::Cache::new(); - let rt = Builder::new_current_thread() - .thread_name("test-runtime") - .enable_all() - .build() - .expect("Failed to create testing tokio runtime"); - (info, test_map, config_cache, rt) - } - - fn build_socket_tagger() -> SocketTagger { - Arc::new(|_| async {}.boxed()) - } - - #[test] - fn make_connection_if_needed() { - let (info, mut test_map, config_cache, rt) = make_testing_variables(); - rt.block_on(async { - // Expect to make a new connection. - let mut doh = super::make_connection_if_needed( - &info, - &mut test_map, - &config_cache, - build_socket_tagger(), - ) - .unwrap() - .unwrap(); - assert_eq!(doh.info.net_id, info.net_id); - assert!(matches!(doh.state, ConnectionState::Idle)); - doh.state = ConnectionState::Error; - test_map.insert(info.net_id, (info.clone(), Some(doh))); - // Expect that we will get a connection with fail status that we added to the map before. - let mut doh = super::make_connection_if_needed( - &info, - &mut test_map, - &config_cache, - build_socket_tagger(), - ) - .unwrap() - .unwrap(); - assert_eq!(doh.info.net_id, info.net_id); - assert!(matches!(doh.state, ConnectionState::Error)); - doh.state = make_dummy_connected_state(); - test_map.insert(info.net_id, (info.clone(), Some(doh))); - // Expect that we will get None because the map contains a connection with ready status. - assert!(super::make_connection_if_needed( - &info, - &mut test_map, - &config_cache, - build_socket_tagger() - ) - .unwrap() - .is_none()); - }); - } - - #[test] - fn handle_query_cmd() { - let (info, mut test_map, config_cache, rt) = make_testing_variables(); - let t = Duration::from_millis(100); - rt.block_on(async { - // Test no available server cases. - let (resp_tx, resp_rx) = oneshot::channel(); - let query = encoding::probe_query().unwrap(); - super::handle_query_cmd( - info.net_id, - query.clone(), - BootTime::now().checked_add(t).unwrap(), - resp_tx, - &mut test_map, - ) - .await; - assert_eq!( - timeout(t, resp_rx).await.unwrap().unwrap(), - Response::Error { error: QueryError::ServerNotReady } - ); - - let (resp_tx, resp_rx) = oneshot::channel(); - test_map.insert(info.net_id, (info.clone(), None)); - super::handle_query_cmd( - info.net_id, - query.clone(), - BootTime::now().checked_add(t).unwrap(), - resp_tx, - &mut test_map, - ) - .await; - assert_eq!( - timeout(t, resp_rx).await.unwrap().unwrap(), - Response::Error { error: QueryError::ServerNotReady } - ); - - // Test the connection broken case. - test_map.clear(); - let (resp_tx, resp_rx) = oneshot::channel(); - let mut doh = super::make_connection_if_needed( - &info, - &mut test_map, - &config_cache, - build_socket_tagger(), - ) - .unwrap() - .unwrap(); - doh.state = ConnectionState::Error; - test_map.insert(info.net_id, (info.clone(), Some(doh))); - super::handle_query_cmd( - info.net_id, - query.clone(), - BootTime::now().checked_add(t).unwrap(), - resp_tx, - &mut test_map, - ) - .await; - assert_eq!( - timeout(t, resp_rx).await.unwrap().unwrap(), - Response::Error { error: QueryError::BrokenServer } - ); - }); - } - - fn make_testing_connection_variables() -> (Pin>, UdpSocket) { - let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); - let udp_sk = UdpSocket::from_std(sk).unwrap(); - let mut scid = [0; quiche::MAX_CONN_ID_LEN]; - ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid").unwrap(); - let connid = quiche::ConnectionId::from_ref(&scid); - let mut config = Config::from_cert_path(None).unwrap(); - let quic_conn = - quiche::connect(None, &connid, LOOPBACK_ADDR.parse().unwrap(), &mut config.take()) - .unwrap(); - (quic_conn, udp_sk) - } - - fn make_dummy_connected_state() -> super::ConnectionState { - let (quic_conn, udp_sk) = make_testing_connection_variables(); - ConnectionState::Connected { - quic_conn, - udp_sk, - h3_conn: None, - query_map: HashMap::new(), - expired_time: None, - } - } - - #[test] - fn make_doh_udp_socket() { - // Make a socket connecting to loopback with a test mark. - let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); - // Check if the socket is connected to loopback. - assert_eq!( - sk.peer_addr().unwrap(), - SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT)) - ); - - // Check if the socket mark is correct. - let fd: RawFd = sk.as_raw_fd(); - - let mut mark: u32 = 50; - let mut size = std::mem::size_of::() as libc::socklen_t; - unsafe { - // Safety: It's fine since the fd belongs to this test. - assert_eq!( - libc::getsockopt( - fd, - libc::SOL_SOCKET, - libc::SO_MARK, - &mut mark as *mut _ as *mut libc::c_void, - &mut size as *mut libc::socklen_t, - ), - 0 - ); - } - assert_eq!(mark, TEST_MARK); - - // Check if the socket is non-blocking. - unsafe { - // Safety: It's fine since the fd belongs to this test. - assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK); - } - } -} +mod network; diff --git a/doh/ffi.rs b/doh/ffi.rs index ce2f5a64..2df5be68 100644 --- a/doh/ffi.rs +++ b/doh/ffi.rs @@ -17,6 +17,8 @@ //! C API for the DoH backend for the Android DnsResolver module. use crate::boot_time::{timeout, BootTime, Duration}; +use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo}; +use crate::network::{SocketTagger, ValidationReporter}; use futures::FutureExt; use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t}; use log::{error, warn}; @@ -25,22 +27,19 @@ use std::net::{IpAddr, SocketAddr}; use std::ops::DerefMut; use std::os::unix::io::RawFd; use std::str::FromStr; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use std::{ptr, slice}; use tokio::runtime::Runtime; use tokio::sync::oneshot; use tokio::task; use url::Url; -use super::DohDispatcher as Dispatcher; -use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT}; - pub type ValidationCallback = extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); pub type TagSocketCallback = extern "C" fn(sock: RawFd); fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter { - Box::new(move |info: &ServerInfo, success: bool| { + Arc::new(move |info: &ServerInfo, success: bool| { async move { let (ip_addr, domain) = match ( CString::new(info.peer_addr.ip().to_string()), @@ -65,7 +64,6 @@ fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationRepo fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger { use std::os::unix::io::AsRawFd; - use std::sync::Arc; Arc::new(move |udp_socket: &std::net::UdpSocket| { let fd = udp_socket.as_raw_fd(); async move { @@ -107,6 +105,8 @@ pub const DOH_LOG_LEVEL_DEBUG: u32 = 3; /// The trace log level. pub const DOH_LOG_LEVEL_TRACE: u32 = 4; +const DOH_PORT: u16 = 443; + fn level_from_u32(level: u32) -> Option { use log::Level::*; match level { @@ -216,7 +216,7 @@ pub unsafe extern "C" fn doh_net_new( return -libc::EINVAL; } }; - let cmd = DohCommand::Probe { + let cmd = Command::Probe { info: ServerInfo { net_id, url, @@ -257,7 +257,7 @@ pub unsafe extern "C" fn doh_query( let (resp_tx, resp_rx) = oneshot::channel(); let t = Duration::from_millis(timeout_ms); if let Some(expired_time) = BootTime::now().checked_add(t) { - let cmd = DohCommand::Query { + let cmd = Command::Query { net_id, base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD), expired_time, @@ -309,7 +309,7 @@ pub unsafe extern "C" fn doh_query( /// and not yet deleted by `doh_dispatcher_delete()`. #[no_mangle] pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) { - if let Err(e) = doh.lock().send_cmd(DohCommand::Clear { net_id }) { + if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) { error!("Failed to send the query: {:?}", e); } } diff --git a/doh/network/driver.rs b/doh/network/driver.rs new file mode 100644 index 00000000..6c17f35e --- /dev/null +++ b/doh/network/driver.rs @@ -0,0 +1,190 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Provides a backing task to implement a network + +use crate::boot_time::{timeout, BootTime, Duration}; +use crate::config::Config; +use crate::connection::Connection; +use crate::dispatcher::{QueryError, Response}; +use crate::encoding; +use anyhow::{anyhow, Result}; +use std::sync::Arc; +use tokio::sync::{mpsc, watch}; +use tokio::task; + +use super::{Query, ServerInfo, SocketTagger, ValidationReporter}; + +pub struct Driver { + info: ServerInfo, + config: Config, + connection: Connection, + command_rx: mpsc::Receiver, + status_tx: watch::Sender, + validation: ValidationReporter, + tag_socket: SocketTagger, +} + +#[derive(Debug)] +/// Requests the network can handle +pub enum Command { + /// Send a DNS query to the network + Query(Query), + /// Run a probe to check the health of the network. Argument is timeout. + Probe(Duration), +} + +#[derive(Clone, Debug)] +/// Current Network Status +/// +/// (Unprobed or Failed) can go to (Live or Failed) via Probe. +/// Currently, there is no way to go from Live to Failed - probing a live network will short-circuit to returning valid, and query failures do not declare the network failed. +pub enum Status { + /// Network has not been probed, it may or may not work + Unprobed, + /// Network is believed to be working + Live, + /// Network is broken, reason as argument + Failed(Arc), +} + +impl Status { + pub fn is_live(&self) -> bool { + matches!(self, Self::Live) + } + pub fn is_failed(&self) -> bool { + matches!(self, Self::Failed(_)) + } +} + +async fn build_connection( + info: &ServerInfo, + tag_socket: &SocketTagger, + config: &mut Config, +) -> Result { + use std::ops::DerefMut; + Ok(Connection::new( + info.domain.as_deref(), + info.peer_addr, + info.sk_mark, + tag_socket, + config.take().await.deref_mut(), + ) + .await?) +} + +impl Driver { + const MAX_BUFFERED_COMMANDS: usize = 10; + + pub async fn new( + info: ServerInfo, + mut config: Config, + validation: ValidationReporter, + tag_socket: SocketTagger, + ) -> Result<(Self, mpsc::Sender, watch::Receiver)> { + let (command_tx, command_rx) = mpsc::channel(Self::MAX_BUFFERED_COMMANDS); + let (status_tx, status_rx) = watch::channel(Status::Unprobed); + let connection = build_connection(&info, &tag_socket, &mut config).await?; + Ok(( + Self { info, config, connection, status_tx, command_rx, validation, tag_socket }, + command_tx, + status_rx, + )) + } + + pub async fn drive(mut self) -> Result<()> { + while let Some(cmd) = self.command_rx.recv().await { + if let Err(e) = match cmd { + Command::Probe(duration) => self.probe(duration).await, + Command::Query(query) => self.send_query(query).await, + } { + self.status_tx.send(Status::Failed(Arc::new(e)))? + }; + } + Ok(()) + } + + async fn probe(&mut self, probe_timeout: Duration) -> Result<()> { + if self.status_tx.borrow().is_failed() { + // If our network is currently failed, it may be due to issues with the connection. + // Re-establish before re-probing + self.connection = + build_connection(&self.info, &self.tag_socket, &mut self.config).await?; + self.status_tx.send(Status::Unprobed)?; + } + if self.status_tx.borrow().is_live() { + // If we're already validated, short circuit + (self.validation)(&self.info, true).await; + return Ok(()); + } + self.force_probe(probe_timeout).await + } + + async fn force_probe(&mut self, probe_timeout: Duration) -> Result<()> { + let probe = encoding::probe_query()?; + let dns_request = encoding::dns_request(&probe, &self.info.url)?; + let expiry = BootTime::now().checked_add(probe_timeout); + let request = async { + match self.connection.query(dns_request, expiry).await { + Err(e) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(e)))), + Ok(rsp) => { + if let Some(_stream) = rsp.await { + // TODO verify stream contents + self.status_tx.send(Status::Live) + } else { + self.status_tx.send(Status::Failed(Arc::new(anyhow!("Empty response")))) + } + } + } + }; + match timeout(probe_timeout, request).await { + // Timed out + Err(time) => self.status_tx.send(Status::Failed(Arc::new(anyhow!( + "Probe timed out after {:?} (timeout={:?})", + time, + probe_timeout + )))), + // Query completed + Ok(r) => r, + }?; + let valid = self.status_tx.borrow().is_live(); + (self.validation)(&self.info, valid).await; + Ok(()) + } + + async fn send_query(&mut self, query: Query) -> Result<()> { + let request = encoding::dns_request(&query.query, &self.info.url)?; + let stream_fut = self.connection.query(request, Some(query.expiry)).await?; + task::spawn(async move { + let stream = match stream_fut.await { + Some(stream) => stream, + None => { + // We don't care if the response is gone + let _ = + query.response.send(Response::Error { error: QueryError::ConnectionError }); + return; + } + }; + // We don't care if the response is gone. + let _ = if let Some(err) = stream.error { + query.response.send(Response::Error { error: QueryError::Reset(err) }) + } else { + query.response.send(Response::Success { answer: stream.data }) + }; + }); + Ok(()) + } +} diff --git a/doh/network/mod.rs b/doh/network/mod.rs new file mode 100644 index 00000000..5d26688c --- /dev/null +++ b/doh/network/mod.rs @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Provides the ability to query DNS for a specific network configuration + +use crate::boot_time::{BootTime, Duration}; +use crate::config::Config; +use crate::dispatcher::{QueryError, Response}; +use anyhow::Result; +use futures::future::BoxFuture; +use log::warn; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, watch}; +use tokio::task; +use url::Url; + +mod driver; + +use driver::{Command, Driver}; + +pub use driver::Status; + +/// Closure to signal validation status to outside world +pub type ValidationReporter = Arc BoxFuture<()> + Send + Sync>; +/// Closure to tag socket during connection construction +pub type SocketTagger = Arc BoxFuture<()> + Send + Sync>; + +#[derive(Eq, PartialEq, Debug, Clone)] +pub struct ServerInfo { + pub net_id: u32, + pub url: Url, + pub peer_addr: SocketAddr, + pub domain: Option, + pub sk_mark: u32, + pub cert_path: Option, +} + +#[derive(Debug)] +/// DNS resolution query +pub struct Query { + /// Raw DNS query, base64 encoded + pub query: String, + /// Place to send the answer + pub response: oneshot::Sender, + /// When this request is considered stale (will be ignored if not serviced by that point) + pub expiry: BootTime, +} + +/// Handle to a particular network's DNS resolution +pub struct Network { + info: ServerInfo, + status_rx: watch::Receiver, + command_tx: mpsc::Sender, +} + +impl Network { + pub async fn new( + info: ServerInfo, + config: Config, + validation: ValidationReporter, + tagger: SocketTagger, + ) -> Result { + let (driver, command_tx, status_rx) = + Driver::new(info.clone(), config, validation, tagger).await?; + task::spawn(driver.drive()); + Ok(Network { info, command_tx, status_rx }) + } + + pub async fn probe(&mut self, timeout: Duration) -> Result<()> { + self.command_tx.send(Command::Probe(timeout)).await?; + Ok(()) + } + + pub async fn query(&mut self, query: Query) -> Result<()> { + // The clone is used to prevent status_rx from being held across an await + let status: Status = self.status_rx.borrow().clone(); + match status { + Status::Failed(_) => query + .response + .send(Response::Error { error: QueryError::BrokenServer }) + .unwrap_or_else(|_| { + warn!("Query result listener went away before receiving a response") + }), + Status::Unprobed => query + .response + .send(Response::Error { error: QueryError::ServerNotReady }) + .unwrap_or_else(|_| { + warn!("Query result listener went away before receiving a response") + }), + Status::Live => self.command_tx.send(Command::Query(query)).await?, + } + Ok(()) + } + + pub fn get_info(&self) -> &ServerInfo { + &self.info + } +} -- cgit v1.2.3