aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Maurer <mmaurer@google.com>2021-10-25 13:04:44 -0700
committerMike Yu <yumike@google.com>2021-11-09 21:00:59 +0800
commited78fdaf9b08551599bc79736729f7cf3fc6eb06 (patch)
tree8f17c8e4ab91793d5520a619394000f565fb6039
parentf42426ec380f66ccdf28e463c5386c3b611a6b94 (diff)
downloadDnsResolver-ed78fdaf9b08551599bc79736729f7cf3fc6eb06.tar.gz
DoH: Modularize main event loop
* Connection now provides HTTP/3. * Network has the logic for resolving DNS and maintaining a Connection. * Dispatcher routes requests to the appropriate Network or creates one if needed. * IO and maintenance is performed via tasks rather than manually pushing the futures in the main event loop. Bug: 202081046 Test: resolv_integration_test Test: resolv_stress_test + I682678b84b35c575a3eb88c2c1c67aefd195616c Change-Id: I4296d0c7a7852951f41418b18686794d8df781bd
-rw-r--r--Android.bp1
-rw-r--r--doh/config.rs13
-rw-r--r--doh/connection/driver.rs369
-rw-r--r--doh/connection/mod.rs150
-rw-r--r--doh/dispatcher/driver.rs127
-rw-r--r--doh/dispatcher/mod.rs110
-rw-r--r--doh/doh.rs1041
-rw-r--r--doh/ffi.rs18
-rw-r--r--doh/network/driver.rs190
-rw-r--r--doh/network/mod.rs112
10 files changed, 1078 insertions, 1053 deletions
diff --git a/Android.bp b/Android.bp
index ce9a3dcb..823af852 100644
--- a/Android.bp
+++ b/Android.bp
@@ -334,6 +334,7 @@ doh_rust_deps = [
"liblibc",
"liblog_rust",
"libring",
+ "libthiserror",
"libtokio",
"liburl",
]
diff --git a/doh/config.rs b/doh/config.rs
index 04d07c50..91284052 100644
--- a/doh/config.rs
+++ b/doh/config.rs
@@ -27,7 +27,8 @@
use quiche::{h3, Result};
use std::collections::HashMap;
use std::ops::DerefMut;
-use std::sync::{Arc, Mutex, RwLock, Weak};
+use std::sync::{Arc, RwLock, Weak};
+use tokio::sync::Mutex;
type WeakConfig = Weak<Mutex<quiche::Config>>;
@@ -80,8 +81,8 @@ impl Config {
/// Take the underlying config, usable as `&mut quiche::Config` for use
/// with `quiche::connect`.
- pub fn take(&mut self) -> impl DerefMut<Target = quiche::Config> + '_ {
- self.0.lock().unwrap()
+ pub async fn take(&mut self) -> impl DerefMut<Target = quiche::Config> + '_ {
+ self.0.lock().await
}
}
@@ -229,11 +230,11 @@ fn lifetimes() {
assert_eq!(cache.state.read().unwrap().path_to_config.len(), 2);
}
-#[test]
-fn quiche_connect() {
+#[tokio::test]
+async fn quiche_connect() {
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
let mut config = Config::from_cert_path(None).unwrap();
let socket_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 42));
let conn_id = quiche::ConnectionId::from_ref(&[]);
- quiche::connect(None, &conn_id, socket_addr, &mut config.take()).unwrap();
+ quiche::connect(None, &conn_id, socket_addr, config.take().await.deref_mut()).unwrap();
}
diff --git a/doh/connection/driver.rs b/doh/connection/driver.rs
new file mode 100644
index 00000000..4fd1e266
--- /dev/null
+++ b/doh/connection/driver.rs
@@ -0,0 +1,369 @@
+/*
+* Copyright (C) 2021 The Android Open Source Project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+//! Defines a backing task to keep a HTTP/3 connection running
+
+use crate::boot_time;
+use crate::boot_time::BootTime;
+use log::warn;
+use quiche::h3;
+use std::collections::HashMap;
+use std::default::Default;
+use std::future;
+use std::io;
+use std::pin::Pin;
+use thiserror::Error;
+use tokio::net::UdpSocket;
+use tokio::select;
+use tokio::sync::{mpsc, oneshot};
+
+#[derive(Error, Debug)]
+pub enum Error {
+ #[error("network IO error: {0}")]
+ Network(#[from] io::Error),
+ #[error("QUIC error: {0}")]
+ Quic(#[from] quiche::Error),
+ #[error("HTTP/3 error: {0}")]
+ H3(#[from] h3::Error),
+ #[error("Response delivery error: {0}")]
+ StreamSend(#[from] mpsc::error::SendError<Stream>),
+ #[error("Connection closed")]
+ Closed,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug)]
+/// HTTP/3 Request to be sent on the connection
+pub struct Request {
+ /// Request headers
+ pub headers: Vec<h3::Header>,
+ /// Expiry time for the request, relative to `CLOCK_BOOTTIME`
+ pub expiry: Option<BootTime>,
+ /// Channel to send the response to
+ pub response_tx: oneshot::Sender<Stream>,
+}
+
+#[derive(Debug)]
+/// HTTP/3 Response
+pub struct Stream {
+ /// Response headers
+ pub headers: Vec<h3::Header>,
+ /// Response body
+ pub data: Vec<u8>,
+ /// Error code if stream was reset
+ pub error: Option<u64>,
+}
+
+impl Stream {
+ fn new(headers: Vec<h3::Header>) -> Self {
+ Self { headers, data: Vec::new(), error: None }
+ }
+}
+
+const MAX_UDP_PACKET_SIZE: usize = 65536;
+
+struct Driver {
+ request_rx: mpsc::Receiver<Request>,
+ quiche_conn: Pin<Box<quiche::Connection>>,
+ socket: UdpSocket,
+ // This buffer is large, boxing it will keep it
+ // off the stack and prevent it being copied during
+ // moves of the driver.
+ buffer: Box<[u8; MAX_UDP_PACKET_SIZE]>,
+}
+
+struct H3Driver {
+ driver: Driver,
+ // h3_conn sometimes can't "fit" a request in its available windows.
+ // This value holds a peeked request in that case, waiting for
+ // transmission to become possible.
+ buffered_request: Option<Request>,
+ // We can't check if a receiver is dead without potentially receiving a message, and if we poll
+ // on a dead receiver in a select! it will immediately return None. As a result, we need this
+ // to gate whether or not to include .recv() in our select!
+ closing: bool,
+ h3_conn: h3::Connection,
+ requests: HashMap<u64, Request>,
+ streams: HashMap<u64, Stream>,
+}
+
+async fn optional_timeout(timeout: Option<boot_time::Duration>) {
+ match timeout {
+ Some(timeout) => boot_time::sleep(timeout).await,
+ None => future::pending().await,
+ }
+}
+
+/// Creates a future which when polled will handle events related to a HTTP/3 connection.
+/// The returned error code will explain why the connection terminated.
+pub async fn drive(
+ request_rx: mpsc::Receiver<Request>,
+ quiche_conn: Pin<Box<quiche::Connection>>,
+ socket: UdpSocket,
+) -> Result<()> {
+ Driver::new(request_rx, quiche_conn, socket).drive().await
+}
+
+impl Driver {
+ fn new(
+ request_rx: mpsc::Receiver<Request>,
+ quiche_conn: Pin<Box<quiche::Connection>>,
+ socket: UdpSocket,
+ ) -> Self {
+ Self { request_rx, quiche_conn, socket, buffer: Box::new([0; MAX_UDP_PACKET_SIZE]) }
+ }
+
+ async fn drive(mut self) -> Result<()> {
+ // Prime connection
+ self.flush_tx().await?;
+ loop {
+ self = self.drive_once().await?
+ }
+ }
+
+ fn handle_closed(&self) -> Result<()> {
+ if self.quiche_conn.is_closed() {
+ Err(Error::Closed)
+ } else {
+ Ok(())
+ }
+ }
+
+ async fn drive_once(mut self) -> Result<Self> {
+ let timer = optional_timeout(self.quiche_conn.timeout());
+ select! {
+ // If a quiche timer would fire, call their callback
+ _ = timer => self.quiche_conn.on_timeout(),
+ // If we got packets from our peer, pass them to quiche
+ Ok((size, from)) = self.socket.recv_from(self.buffer.as_mut()) => {
+ self.quiche_conn.recv(&mut self.buffer[..size], quiche::RecvInfo { from })?;
+ }
+ };
+ // Any of the actions in the select could require us to send packets to the peer
+ self.flush_tx().await?;
+
+ // If the QUIC connection is live, but the HTTP/3 is not, try to bring it up
+ if self.quiche_conn.is_established() {
+ let h3_config = h3::Config::new()?;
+ let h3_conn = h3::Connection::with_transport(&mut self.quiche_conn, &h3_config)?;
+ return H3Driver::new(self, h3_conn).drive().await;
+ }
+
+ // If the connection has closed, tear down
+ self.handle_closed()?;
+
+ Ok(self)
+ }
+
+ async fn flush_tx(&mut self) -> Result<()> {
+ let send_buf = self.buffer.as_mut();
+ loop {
+ match self.quiche_conn.send(send_buf) {
+ Err(quiche::Error::Done) => return Ok(()),
+ Err(e) => return Err(e.into()),
+ Ok((valid_len, send_info)) => {
+ self.socket.send_to(&send_buf[..valid_len], send_info.to).await?;
+ }
+ }
+ }
+ }
+}
+
+impl H3Driver {
+ fn new(driver: Driver, h3_conn: h3::Connection) -> Self {
+ Self {
+ driver,
+ h3_conn,
+ closing: false,
+ requests: HashMap::new(),
+ streams: HashMap::new(),
+ buffered_request: None,
+ }
+ }
+
+ async fn drive(mut self) -> Result<Driver> {
+ loop {
+ self.drive_once().await?;
+ }
+ }
+
+ async fn drive_once(&mut self) -> Result<()> {
+ // We can't call self.driver.drive_once at the same time as
+ // self.driver.request_rx.recv() due to ownership
+ let timer = optional_timeout(self.driver.quiche_conn.timeout());
+ // If we've buffered a request (due to the connection being full)
+ // try to resend that first
+ if let Some(request) = self.buffered_request.take() {
+ self.handle_request(request)?;
+ }
+ select! {
+ // Only attempt to enqueue new requests if we have no buffered request and aren't
+ // closing
+ msg = self.driver.request_rx.recv(), if !self.closing && self.buffered_request.is_none() => match msg {
+ Some(request) => self.handle_request(request)?,
+ None => self.shutdown(true, b"DONE").await?,
+ },
+ // If a quiche timer would fire, call their callback
+ _ = timer => self.driver.quiche_conn.on_timeout(),
+ // If we got packets from our peer, pass them to quiche
+ Ok((size, from)) = self.driver.socket.recv_from(self.driver.buffer.as_mut()) => {
+ self.driver.quiche_conn.recv(&mut self.driver.buffer[..size], quiche::RecvInfo { from })?;
+ }
+ };
+
+ // Any of the actions in the select could require us to send packets to the peer
+ self.driver.flush_tx().await?;
+
+ // Process any incoming HTTP/3 events
+ self.flush_h3().await?;
+
+ // If the connection has closed, tear down
+ self.driver.handle_closed()
+ }
+
+ fn handle_request(&mut self, request: Request) -> Result<()> {
+ // If the request has already timed out, don't issue it to the server.
+ if let Some(expiry) = request.expiry {
+ if BootTime::now() > expiry {
+ return Ok(());
+ }
+ }
+ let stream_id =
+ // If h3_conn says the stream is blocked, this error is recoverable just by trying
+ // again once the stream has made progress. Buffer the request for a later retry.
+ match self.h3_conn.send_request(&mut self.driver.quiche_conn, &request.headers, true) {
+ Err(h3::Error::StreamBlocked) | Err(h3::Error::TransportError(quiche::Error::StreamLimit)) => {
+ // We only call handle_request on a value that has just come out of
+ // buffered_request, or when buffered_request is empty. This assert just
+ // validates that we don't break that assumption later, as it could result in
+ // requests being dropped on the floor under high load.
+ assert!(self.buffered_request.is_none());
+ self.buffered_request = Some(request);
+ return Ok(())
+ }
+ result => result?,
+ };
+ self.requests.insert(stream_id, request);
+ Ok(())
+ }
+
+ async fn recv_body(&mut self, stream_id: u64) -> Result<()> {
+ const STREAM_READ_CHUNK: usize = 4096;
+ if let Some(stream) = self.streams.get_mut(&stream_id) {
+ loop {
+ let base_len = stream.data.len();
+ stream.data.resize(base_len + STREAM_READ_CHUNK, 0);
+ match self.h3_conn.recv_body(
+ &mut self.driver.quiche_conn,
+ stream_id,
+ &mut stream.data[base_len..],
+ ) {
+ Err(h3::Error::Done) => {
+ stream.data.truncate(base_len);
+ return Ok(());
+ }
+ Err(e) => {
+ stream.data.truncate(base_len);
+ return Err(e.into());
+ }
+ Ok(recvd) => stream.data.truncate(base_len + recvd),
+ }
+ }
+ } else {
+ warn!("Received body for untracked stream ID {}", stream_id);
+ }
+ Ok(())
+ }
+
+ fn discard_datagram(&mut self, _flow_id: u64) -> Result<()> {
+ loop {
+ match self.h3_conn.recv_dgram(&mut self.driver.quiche_conn, self.driver.buffer.as_mut())
+ {
+ Err(h3::Error::Done) => return Ok(()),
+ Err(e) => return Err(e.into()),
+ _ => (),
+ }
+ }
+ }
+
+ async fn flush_h3(&mut self) -> Result<()> {
+ loop {
+ match self.h3_conn.poll(&mut self.driver.quiche_conn) {
+ Err(h3::Error::Done) => return Ok(()),
+ Err(e) => return Err(e.into()),
+ Ok((stream_id, event)) => self.process_h3_event(stream_id, event).await?,
+ }
+ }
+ }
+
+ async fn process_h3_event(&mut self, stream_id: u64, event: h3::Event) -> Result<()> {
+ if !self.requests.contains_key(&stream_id) {
+ warn!("Received event {:?} for stream_id {} without a request.", event, stream_id);
+ }
+ match event {
+ h3::Event::Headers { list, has_body } => {
+ let stream = Stream::new(list);
+ if self.streams.insert(stream_id, stream).is_some() {
+ warn!("Re-using stream ID {} before it was completed.", stream_id)
+ }
+ if !has_body {
+ self.respond(stream_id);
+ }
+ }
+ h3::Event::Data => {
+ self.recv_body(stream_id).await?;
+ }
+ h3::Event::Finished => self.respond(stream_id),
+ // This clause is for quiche 0.10.x, we're still on 0.9.x
+ //h3::Event::Reset(e) => {
+ // self.streams.get_mut(&stream_id).map(|stream| stream.error = Some(e));
+ // self.respond(stream_id);
+ //}
+ h3::Event::Datagram => {
+ warn!("Unexpected Datagram received");
+ // We don't care if something went wrong with the datagram, we didn't
+ // want it anyways.
+ let _ = self.discard_datagram(stream_id);
+ }
+ h3::Event::GoAway => self.shutdown(false, b"SERVER GOAWAY").await?,
+ }
+ Ok(())
+ }
+
+ async fn shutdown(&mut self, send_goaway: bool, msg: &[u8]) -> Result<()> {
+ self.driver.request_rx.close();
+ while self.driver.request_rx.recv().await.is_some() {}
+ self.closing = true;
+ if send_goaway {
+ self.h3_conn.send_goaway(&mut self.driver.quiche_conn, 0)?;
+ }
+ if self.driver.quiche_conn.close(true, 0, msg).is_err() {
+ warn!("Trying to close already closed QUIC connection");
+ }
+ Ok(())
+ }
+
+ fn respond(&mut self, stream_id: u64) {
+ match (self.streams.remove(&stream_id), self.requests.remove(&stream_id)) {
+ (Some(stream), Some(request)) => {
+ // We don't care about the error, because it means the requestor has left.
+ let _ = request.response_tx.send(stream);
+ }
+ (None, _) => warn!("Tried to deliver untracked stream {}", stream_id),
+ (_, None) => warn!("Tried to deliver stream {} to untracked requestor", stream_id),
+ }
+ }
+}
diff --git a/doh/connection/mod.rs b/doh/connection/mod.rs
new file mode 100644
index 00000000..bc5b75c5
--- /dev/null
+++ b/doh/connection/mod.rs
@@ -0,0 +1,150 @@
+/*
+* Copyright (C) 2021 The Android Open Source Project
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+//! Module providing an async abstraction around a quiche HTTP/3 connection
+
+use crate::boot_time::BootTime;
+use crate::network::SocketTagger;
+use log::error;
+use quiche::h3;
+use std::future::Future;
+use std::io;
+use std::net::SocketAddr;
+use thiserror::Error;
+use tokio::net::UdpSocket;
+use tokio::sync::mpsc;
+use tokio::sync::oneshot;
+use tokio::task;
+
+mod driver;
+
+pub use driver::Stream;
+use driver::{drive, Request};
+
+/// Quiche HTTP/3 connection
+pub struct Connection {
+ request_tx: mpsc::Sender<Request>,
+}
+
+fn new_scid() -> [u8; quiche::MAX_CONN_ID_LEN] {
+ use ring::rand::{SecureRandom, SystemRandom};
+ let mut scid = [0; quiche::MAX_CONN_ID_LEN];
+ SystemRandom::new().fill(&mut scid).unwrap();
+ scid
+}
+
+fn mark_socket(socket: &std::net::UdpSocket, socket_mark: u32) -> io::Result<()> {
+ use std::os::unix::io::AsRawFd;
+ let fd = socket.as_raw_fd();
+ // libc::setsockopt is a wrapper function calling into bionic setsockopt.
+ // The only pointer being passed in is &socket_mark, which is valid by virtue of being a
+ // reference, and the foreign function doesn't take ownership or a reference to that memory
+ // after completion.
+ if unsafe {
+ libc::setsockopt(
+ fd,
+ libc::SOL_SOCKET,
+ libc::SO_MARK,
+ &socket_mark as *const _ as *const libc::c_void,
+ std::mem::size_of::<u32>() as libc::socklen_t,
+ )
+ } == 0
+ {
+ Ok(())
+ } else {
+ Err(io::Error::last_os_error())
+ }
+}
+
+async fn build_socket(
+ peer_addr: SocketAddr,
+ socket_mark: u32,
+ tag_socket: &SocketTagger,
+) -> io::Result<UdpSocket> {
+ let bind_addr = match peer_addr {
+ SocketAddr::V4(_) => "0.0.0.0:0",
+ SocketAddr::V6(_) => "[::]:0",
+ };
+
+ let socket = UdpSocket::bind(bind_addr).await?;
+ let std_socket = socket.into_std()?;
+ mark_socket(&std_socket, socket_mark)
+ .unwrap_or_else(|e| error!("Unable to mark socket : {:?}", e));
+ tag_socket(&std_socket).await;
+ let socket = UdpSocket::from_std(std_socket)?;
+ socket.connect(peer_addr).await?;
+ Ok(socket)
+}
+
+/// Error type for HTTP/3 connection
+#[derive(Debug, Error)]
+pub enum Error {
+ /// QUIC protocol error
+ #[error("QUIC error: {0}")]
+ Quic(#[from] quiche::Error),
+ /// HTTP/3 protocol error
+ #[error("HTTP/3 error: {0}")]
+ H3(#[from] h3::Error),
+ /// Unable to send the request to the driver. This likely means the
+ /// backing task has died.
+ #[error("Unable to send request")]
+ SendRequest(#[from] mpsc::error::SendError<Request>),
+ /// IO failed. This is most likely to occur while trying to set up the
+ /// UDP socket for use by the connection.
+ #[error("IO error: {0}")]
+ Io(#[from] io::Error),
+ /// The request is no longer being serviced. This could mean that the
+ /// request was dropped for an unspecified reason, or that the connection
+ /// was closed prematurely and it can no longer be serviced.
+ #[error("Driver dropped request")]
+ RecvResponse(#[from] oneshot::error::RecvError),
+}
+
+/// Common result type for working with a HTTP/3 connection
+pub type Result<T> = std::result::Result<T, Error>;
+
+impl Connection {
+ const MAX_PENDING_REQUESTS: usize = 10;
+ /// Create a new connection with a background task handling IO.
+ pub async fn new(
+ server_name: Option<&str>,
+ to: SocketAddr,
+ socket_mark: u32,
+ tag_socket: &SocketTagger,
+ config: &mut quiche::Config,
+ ) -> Result<Self> {
+ let (request_tx, request_rx) = mpsc::channel(Self::MAX_PENDING_REQUESTS);
+ let scid = new_scid();
+ let quiche_conn =
+ quiche::connect(server_name, &quiche::ConnectionId::from_ref(&scid), to, config)?;
+ let socket = build_socket(to, socket_mark, tag_socket).await?;
+ let driver = drive(request_rx, quiche_conn, socket);
+ task::spawn(driver);
+ Ok(Self { request_tx })
+ }
+
+ /// Send a query, produce a future which will provide a response.
+ /// The future is separately returned rather than awaited to allow it to be waited on without
+ /// keeping the `Connection` itself borrowed.
+ pub async fn query(
+ &self,
+ headers: Vec<h3::Header>,
+ expiry: Option<BootTime>,
+ ) -> Result<impl Future<Output = Option<Stream>>> {
+ let (response_tx, response_rx) = oneshot::channel();
+ self.request_tx.send(Request { headers, response_tx, expiry }).await?;
+ Ok(async move { response_rx.await.ok() })
+ }
+}
diff --git a/doh/dispatcher/driver.rs b/doh/dispatcher/driver.rs
new file mode 100644
index 00000000..75f456b6
--- /dev/null
+++ b/doh/dispatcher/driver.rs
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides a backing task to implement a Dispatcher
+
+use crate::boot_time::{BootTime, Duration};
+use anyhow::{bail, Result};
+use log::{debug, trace, warn};
+use std::collections::HashMap;
+use tokio::sync::{mpsc, oneshot};
+
+use super::{Command, QueryError, Response};
+use crate::network::{Network, ServerInfo, SocketTagger, ValidationReporter};
+use crate::{config, network};
+
+pub struct Driver {
+ command_rx: mpsc::Receiver<Command>,
+ networks: HashMap<u32, Network>,
+ validation: ValidationReporter,
+ tagger: SocketTagger,
+ config_cache: config::Cache,
+}
+
+fn debug_err(r: Result<()>) {
+ if let Err(e) = r {
+ debug!("Dispatcher loop got {:?}", e);
+ }
+}
+
+impl Driver {
+ pub fn new(
+ command_rx: mpsc::Receiver<Command>,
+ validation: ValidationReporter,
+ tagger: SocketTagger,
+ ) -> Self {
+ Self {
+ command_rx,
+ networks: HashMap::new(),
+ validation,
+ tagger,
+ config_cache: config::Cache::new(),
+ }
+ }
+
+ pub async fn drive(mut self) -> Result<()> {
+ loop {
+ self.drive_once().await?
+ }
+ }
+
+ async fn drive_once(&mut self) -> Result<()> {
+ if let Some(command) = self.command_rx.recv().await {
+ trace!("dispatch command: {:?}", command);
+ match command {
+ Command::Probe { info, timeout } => debug_err(self.probe(info, timeout).await),
+ Command::Query { net_id, base64_query, expired_time, resp } => {
+ debug_err(self.query(net_id, base64_query, expired_time, resp).await)
+ }
+ Command::Clear { net_id } => {
+ self.networks.remove(&net_id);
+ self.config_cache.garbage_collect();
+ }
+ Command::Exit => {
+ bail!("Death due to Exit")
+ }
+ }
+ Ok(())
+ } else {
+ bail!("Death due to command_tx dying")
+ }
+ }
+
+ async fn query(
+ &mut self,
+ net_id: u32,
+ query: String,
+ expiry: BootTime,
+ response: oneshot::Sender<Response>,
+ ) -> Result<()> {
+ if let Some(network) = self.networks.get_mut(&net_id) {
+ network.query(network::Query { query, response, expiry }).await?;
+ } else {
+ warn!("Tried to send a query to non-existent network net_id={}", net_id);
+ response.send(Response::Error { error: QueryError::Unexpected }).unwrap_or_else(|_| {
+ warn!("Unable to send reply for non-existent network net_id={}", net_id);
+ })
+ }
+ Ok(())
+ }
+
+ async fn probe(&mut self, info: ServerInfo, timeout: Duration) -> Result<()> {
+ use std::collections::hash_map::Entry;
+ if !self.networks.get(&info.net_id).map_or(true, |net| net.get_info() == &info) {
+ // If we have a network registered to the provided net_id, but the server info doesn't
+ // match, our API has been used incorrectly. Attempt to recover by deleting the old
+ // network and recreating it according to the probe request.
+ warn!("Probing net_id={} with mismatched server info", info.net_id);
+ self.networks.remove(&info.net_id);
+ }
+ // Can't use or_insert_with because creating a network may fail
+ let net = match self.networks.entry(info.net_id) {
+ Entry::Occupied(network) => network.into_mut(),
+ Entry::Vacant(vacant) => {
+ let config = self.config_cache.from_cert_path(&info.cert_path)?;
+ vacant.insert(
+ Network::new(info, config, self.validation.clone(), self.tagger.clone())
+ .await?,
+ )
+ }
+ };
+ net.probe(timeout).await?;
+ Ok(())
+ }
+}
diff --git a/doh/dispatcher/mod.rs b/doh/dispatcher/mod.rs
new file mode 100644
index 00000000..66e2f3d5
--- /dev/null
+++ b/doh/dispatcher/mod.rs
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+use crate::boot_time::{BootTime, Duration};
+use anyhow::Result;
+use log::error;
+use tokio::runtime::{Builder, Runtime};
+use tokio::sync::{mpsc, oneshot};
+use tokio::task;
+
+pub use crate::network::{ServerInfo, SocketTagger, ValidationReporter};
+
+const MAX_BUFFERED_CMD_COUNT: usize = 400;
+
+mod driver;
+use driver::Driver;
+
+#[derive(Eq, PartialEq, Debug)]
+/// Error response to a query
+pub enum QueryError {
+ /// Network failed probing
+ BrokenServer,
+ /// HTTP/3 connection died
+ ConnectionError,
+ /// Network not probed yet
+ ServerNotReady,
+ /// Server reset HTTP/3 stream
+ Reset(u64),
+ /// Tried to query non-existent network
+ Unexpected,
+}
+
+#[derive(Eq, PartialEq, Debug)]
+pub enum Response {
+ Error { error: QueryError },
+ Success { answer: Vec<u8> },
+}
+
+#[derive(Debug)]
+pub enum Command {
+ Probe {
+ info: ServerInfo,
+ timeout: Duration,
+ },
+ Query {
+ net_id: u32,
+ base64_query: String,
+ expired_time: BootTime,
+ resp: oneshot::Sender<Response>,
+ },
+ Clear {
+ net_id: u32,
+ },
+ Exit,
+}
+
+/// Context for a running DoH engine.
+pub struct Dispatcher {
+ /// Used to submit cmds to the I/O task.
+ cmd_sender: mpsc::Sender<Command>,
+ join_handle: task::JoinHandle<Result<()>>,
+ runtime: Runtime,
+}
+
+impl Dispatcher {
+ const DOH_THREADS: usize = 2;
+
+ pub fn new(validation: ValidationReporter, tagger: SocketTagger) -> Result<Dispatcher> {
+ let (cmd_sender, cmd_receiver) = mpsc::channel::<Command>(MAX_BUFFERED_CMD_COUNT);
+ let runtime = Builder::new_multi_thread()
+ .worker_threads(Self::DOH_THREADS)
+ .enable_all()
+ .thread_name("doh-handler")
+ .build()?;
+ let join_handle = runtime.spawn(async {
+ let result = Driver::new(cmd_receiver, validation, tagger).drive().await;
+ match result {
+ Err(ref e) => error!("Dispatcher driver exited due to {:?}", e),
+ Ok(()) => (),
+ }
+ result
+ });
+ Ok(Dispatcher { cmd_sender, join_handle, runtime })
+ }
+
+ pub fn send_cmd(&self, cmd: Command) -> Result<()> {
+ self.cmd_sender.blocking_send(cmd)?;
+ Ok(())
+ }
+
+ pub fn exit_handler(&mut self) {
+ if self.cmd_sender.blocking_send(Command::Exit).is_err() {
+ return;
+ }
+ let _ = self.runtime.block_on(&mut self.join_handle);
+ }
+}
diff --git a/doh/doh.rs b/doh/doh.rs
index 0e55d422..5ceb5e81 100644
--- a/doh/doh.rs
+++ b/doh/doh.rs
@@ -16,1045 +16,10 @@
//! DoH backend for the Android DnsResolver module.
-use anyhow::{anyhow, bail, Context, Result};
-use futures::future::{join_all, BoxFuture};
-use futures::stream::FuturesUnordered;
-use futures::StreamExt;
-use log::{debug, error, info, trace, warn};
-use quiche::h3;
-use ring::rand::SecureRandom;
-use std::collections::HashMap;
-use std::net::SocketAddr;
-use std::os::unix::io::{AsRawFd, RawFd};
-use std::pin::Pin;
-use std::sync::Arc;
-use tokio::net::UdpSocket;
-use tokio::runtime::{Builder, Runtime};
-use tokio::sync::{mpsc, oneshot};
-use tokio::task;
-use url::Url;
-
pub mod boot_time;
mod config;
+mod connection;
+mod dispatcher;
mod encoding;
mod ffi;
-
-use boot_time::{timeout, BootTime, Duration};
-use config::Config;
-
-const MAX_BUFFERED_CMD_SIZE: usize = 400;
-const DOH_PORT: u16 = 443;
-
-type ValidationReporter = Box<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>;
-type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>;
-
-type SCID = [u8; quiche::MAX_CONN_ID_LEN];
-type Base64Query = String;
-type CmdSender = mpsc::Sender<DohCommand>;
-type CmdReceiver = mpsc::Receiver<DohCommand>;
-type QueryResponder = oneshot::Sender<Response>;
-type DnsRequest = Vec<quiche::h3::Header>;
-
-#[derive(Eq, PartialEq, Debug)]
-enum QueryError {
- BrokenServer,
- ConnectionError,
- ServerNotReady,
- Unexpected,
-}
-
-#[derive(Eq, PartialEq, Debug, Clone)]
-struct ServerInfo {
- net_id: u32,
- url: Url,
- peer_addr: SocketAddr,
- domain: Option<String>,
- sk_mark: u32,
- cert_path: Option<String>,
-}
-
-#[derive(Eq, PartialEq, Debug)]
-enum Response {
- Error { error: QueryError },
- Success { answer: Vec<u8> },
-}
-
-#[derive(Debug)]
-enum DohCommand {
- Probe { info: ServerInfo, timeout: Duration },
- Query { net_id: u32, base64_query: Base64Query, expired_time: BootTime, resp: QueryResponder },
- Clear { net_id: u32 },
- Exit,
-}
-
-#[allow(clippy::large_enum_variant)]
-enum ConnectionState {
- Idle,
- Connecting {
- quic_conn: Option<Pin<Box<quiche::Connection>>>,
- udp_sk: Option<UdpSocket>,
- expired_time: Option<BootTime>,
- },
- Connected {
- quic_conn: Pin<Box<quiche::Connection>>,
- udp_sk: UdpSocket,
- h3_conn: Option<h3::Connection>,
- query_map: HashMap<u64, (Vec<u8>, QueryResponder)>,
- expired_time: Option<BootTime>,
- },
- /// Indicate that the Connection can't be used due to
- /// network or unexpected reasons.
- Error,
-}
-
-impl ConnectionState {
- fn is_connected(&self) -> bool {
- matches!(*self, Self::Connected { .. })
- }
- fn is_error(&self) -> bool {
- matches!(*self, Self::Error)
- }
-}
-
-enum H3Result {
- Data { data: Vec<u8> },
- Finished,
- Ignore,
-}
-
-/// Context for a running DoH engine.
-pub struct DohDispatcher {
- /// Used to submit cmds to the I/O task.
- cmd_sender: CmdSender,
- join_handle: task::JoinHandle<Result<()>>,
- runtime: Runtime,
-}
-
-// DoH dispatcher
-impl DohDispatcher {
- fn new(validation: ValidationReporter, tag_socket: SocketTagger) -> Result<DohDispatcher> {
- let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
- let runtime = Builder::new_multi_thread()
- .worker_threads(2)
- .enable_all()
- .thread_name("doh-handler")
- .build()
- .expect("Failed to create tokio runtime");
- let join_handle = runtime.spawn(doh_handler(cmd_receiver, validation, tag_socket));
- Ok(DohDispatcher { cmd_sender, join_handle, runtime })
- }
-
- fn send_cmd(&self, cmd: DohCommand) -> Result<()> {
- self.cmd_sender.blocking_send(cmd)?;
- Ok(())
- }
-
- fn exit_handler(&mut self) {
- if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() {
- return;
- }
- let _ = self.runtime.block_on(&mut self.join_handle);
- }
-}
-
-struct DohConnection {
- info: ServerInfo,
- config: Config,
- scid: SCID,
- state: ConnectionState,
- pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>,
- cached_session: Option<Vec<u8>>,
- tag_socket: SocketTagger,
-}
-
-impl DohConnection {
- fn new(info: &ServerInfo, config: Config, tag_socket: SocketTagger) -> Result<DohConnection> {
- let mut scid = [0; quiche::MAX_CONN_ID_LEN];
- ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
- Ok(DohConnection {
- info: info.clone(),
- config,
- scid,
- state: ConnectionState::Idle,
- pending_queries: Vec::new(),
- cached_session: None,
- tag_socket,
- })
- }
-
- async fn state_to_connecting(&mut self) -> Result<()> {
- if self.state.is_error() {
- self.state_to_idle();
- }
- self.state = match self.state {
- ConnectionState::Idle => {
- let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?;
- (self.tag_socket)(&udp_sk_std).await;
- let udp_sk = UdpSocket::from_std(udp_sk_std)?;
- let connid = quiche::ConnectionId::from_ref(&self.scid);
- debug!("init the connection for Network {}", self.info.net_id);
- let mut quic_conn = quiche::connect(
- self.info.domain.as_deref(),
- &connid,
- self.info.peer_addr,
- &mut self.config.take(),
- )?;
- if let Some(session) = &self.cached_session {
- if quic_conn.set_session(session).is_err() {
- warn!("can't restore session for network {}", self.info.net_id);
- }
- }
- ConnectionState::Connecting {
- quic_conn: Some(quic_conn),
- udp_sk: Some(udp_sk),
- expired_time: None,
- }
- }
- ConnectionState::Error => panic!("state_to_idle did not transition"),
- ConnectionState::Connecting { .. } => return Ok(()),
- ConnectionState::Connected { .. } => {
- panic!("Invalid state transition to Connecting state!")
- }
- };
- Ok(())
- }
-
- fn state_to_connected(&mut self) -> Result<()> {
- self.state = match &mut self.state {
- // Only Connecting -> Connected is valid.
- ConnectionState::Connecting { quic_conn, udp_sk, .. } => {
- if let (Some(mut quic_conn), Some(udp_sk)) = (quic_conn.take(), udp_sk.take()) {
- let h3_config = h3::Config::new()?;
- let h3_conn =
- quiche::h3::Connection::with_transport(&mut quic_conn, &h3_config)?;
- ConnectionState::Connected {
- quic_conn,
- udp_sk,
- h3_conn: Some(h3_conn),
- query_map: HashMap::new(),
- expired_time: None,
- }
- } else {
- bail!("state transition fail!");
- }
- }
- // The rest should fail.
- _ => panic!("Invalid state transition to Connected state!"),
- };
- Ok(())
- }
-
- fn state_to_idle(&mut self) {
- self.state = match self.state {
- // Only either Connected or Error -> Idle is valid.
- // TODO: Error -> Idle is the re-probing case, add the relevant statistic.
- ConnectionState::Connected { .. } | ConnectionState::Error => ConnectionState::Idle,
- // The rest should fail.
- _ => panic!("Invalid state transition to Idle state!"),
- }
- }
-
- fn state_to_error(&mut self) {
- self.pending_queries.clear();
- self.state = ConnectionState::Error
- }
-
- fn is_reprobe_required(&self) -> bool {
- matches!(self.state, ConnectionState::Error)
- }
-
- fn has_not_handled_queries(&self) -> bool {
- match &self.state {
- ConnectionState::Connecting { .. } | ConnectionState::Idle => {
- !self.pending_queries.is_empty()
- }
- ConnectionState::Connected { query_map, .. } => {
- !query_map.is_empty() || !self.pending_queries.is_empty()
- }
- _ => false,
- }
- }
-
- fn handle_if_connection_expired(&mut self) {
- let expired_time = match &mut self.state {
- ConnectionState::Connecting { expired_time, .. } => expired_time,
- ConnectionState::Connected { expired_time, .. } => expired_time,
- // ignore
- _ => return,
- };
-
- if let Some(expired_time) = expired_time {
- if let Some(elapsed) = BootTime::now().checked_duration_since(*expired_time) {
- warn!(
- "Change the state to Idle due to connection timeout, {:?}, {}",
- elapsed, self.info.net_id
- );
- self.state_to_idle();
- }
- }
- }
-
- async fn probe(&mut self, t: Duration) -> Result<()> {
- match timeout(t, async {
- self.try_connect().await?;
- info!("probe start for {}", self.info.net_id);
- if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, expired_time, .. } =
- &mut self.state
- {
- let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
- let req = match encoding::probe_query() {
- Ok(q) => match encoding::dns_request(&q, &self.info.url) {
- Ok(req) => req,
- Err(e) => bail!(e),
- },
- Err(e) => bail!(e),
- };
- // Send the probe query.
- let req_id = h3_conn.send_request(quic_conn, &req, true /*fin*/)?;
- loop {
- flush_tx(quic_conn, udp_sk).await?;
- recv_rx(quic_conn, udp_sk, expired_time).await?;
- loop {
- match recv_h3(quic_conn, h3_conn) {
- Ok((stream_id, H3Result::Finished)) => {
- if stream_id == req_id {
- return Ok(());
- }
- }
- // TODO: Verify the answer
- Ok((_stream_id, H3Result::Data { .. })) => {}
- Ok((_stream_id, H3Result::Ignore)) => {}
- Err(_) => break,
- }
- }
- }
- } else {
- bail!("state error while performing probe()");
- }
- })
- .await
- {
- Ok(v) => match v {
- Ok(_) => Ok(()),
- Err(e) => {
- self.state_to_error();
- bail!(e);
- }
- },
- Err(e) => {
- self.state_to_error();
- bail!(e);
- }
- }
- }
-
- async fn try_connect(&mut self) -> Result<()> {
- if matches!(self.state, ConnectionState::Connected { .. }) {
- return Ok(());
- }
- self.state_to_connecting().await?;
- debug!("connecting to Network {}", self.info.net_id);
-
- let (quic_conn, udp_sk, expired_time) = match &mut self.state {
- ConnectionState::Connecting { quic_conn, udp_sk, expired_time, .. } => {
- if let (Some(quic_conn), Some(udp_sk)) = (quic_conn.as_mut(), udp_sk.as_mut()) {
- (quic_conn, udp_sk, expired_time)
- } else {
- bail!("unexpected error while performing connect()");
- }
- }
- _ => bail!("state error while performing try_connect()"),
- };
-
- while !quic_conn.is_established() {
- flush_tx(quic_conn, udp_sk).await?;
- recv_rx(quic_conn, udp_sk, expired_time).await?;
- }
- self.cached_session = quic_conn.session();
- self.state_to_connected()?;
- info!("connected to Network {}", self.info.net_id);
- Ok(())
- }
-
- async fn try_send_doh_query(
- &mut self,
- req: DnsRequest,
- resp: QueryResponder,
- expired_time: BootTime,
- ) -> Result<()> {
- self.handle_if_connection_expired();
- match &mut self.state {
- ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, .. } => {
- let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
- send_dns_query(
- quic_conn,
- udp_sk,
- h3_conn,
- query_map,
- &mut self.pending_queries,
- resp,
- expired_time,
- req,
- )
- .await?
- }
- ConnectionState::Connecting { .. } | ConnectionState::Idle => {
- self.pending_queries.push((req, resp, expired_time))
- }
- ConnectionState::Error => {
- error!(
- "state is error while performing try_send_doh_query(), network: {}",
- self.info.net_id
- );
- let _ = resp.send(Response::Error { error: QueryError::BrokenServer });
- }
- }
- Ok(())
- }
-
- async fn process_queries(&mut self) -> Result<()> {
- debug!("process_queries entry, Network {}", self.info.net_id);
- self.try_connect().await?;
- if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, expired_time } =
- &mut self.state
- {
- let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
- loop {
- while !self.pending_queries.is_empty() {
- if let Some((req, resp, exp_time)) = self.pending_queries.pop() {
- // Ignore the expired queries.
- if BootTime::now().checked_duration_since(exp_time).is_some() {
- warn!("Drop the obsolete query for network {}", self.info.net_id);
- continue;
- }
- send_dns_query(
- quic_conn,
- udp_sk,
- h3_conn,
- query_map,
- &mut self.pending_queries,
- resp,
- exp_time,
- req,
- )
- .await?;
- }
- }
- flush_tx(quic_conn, udp_sk).await?;
- recv_rx(quic_conn, udp_sk, expired_time).await?;
- loop {
- match recv_h3(quic_conn, h3_conn) {
- Ok((stream_id, H3Result::Data { mut data })) => {
- if let Some((answer, _)) = query_map.get_mut(&stream_id) {
- answer.append(&mut data);
- } else {
- // Should not happen
- warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.info.net_id, stream_id);
- }
- }
- Ok((stream_id, H3Result::Finished)) => {
- if let Some((answer, resp)) = query_map.remove(&stream_id) {
- debug!(
- "sending answer back to resolv, Network {}, stream id: {}",
- self.info.net_id, stream_id
- );
- resp.send(Response::Success { answer }).unwrap_or_else(|e| {
- trace!(
- "the receiver dropped {:?}, stream id: {}",
- e,
- stream_id
- );
- });
- } else {
- // Should not happen
- warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.info.net_id, stream_id);
- }
- }
- Ok((_stream_id, H3Result::Ignore)) => {}
- Err(_) => break,
- }
- }
- if quic_conn.is_closed() || !quic_conn.is_established() {
- self.state_to_idle();
- bail!("connection become idle");
- }
- }
- } else {
- self.state_to_error();
- bail!("state error while performing process_queries(), network: {}", self.info.net_id);
- }
- }
-}
-
-fn recv_h3(
- quic_conn: &mut Pin<Box<quiche::Connection>>,
- h3_conn: &mut h3::Connection,
-) -> Result<(u64, H3Result)> {
- match h3_conn.poll(quic_conn) {
- // Process HTTP/3 events.
- Ok((stream_id, quiche::h3::Event::Data)) => {
- debug!("quiche::h3::Event::Data");
- let mut buf = vec![0; config::MAX_DATAGRAM_SIZE];
- match h3_conn.recv_body(quic_conn, stream_id, &mut buf) {
- Ok(read) => {
- trace!(
- "got {} bytes of response data on stream {}: {:x?}",
- read,
- stream_id,
- &buf[..read]
- );
- buf.truncate(read);
- Ok((stream_id, H3Result::Data { data: buf }))
- }
- Err(e) => {
- warn!("recv_h3::recv_body {:?}", e);
- bail!(e);
- }
- }
- }
- Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
- trace!(
- "got response headers {:?} on stream id {} has_body {}",
- list,
- stream_id,
- has_body
- );
- Ok((stream_id, H3Result::Ignore))
- }
- Ok((stream_id, quiche::h3::Event::Finished)) => {
- debug!("quiche::h3::Event::Finished on stream id {}", stream_id);
- Ok((stream_id, H3Result::Finished))
- }
- Ok((stream_id, quiche::h3::Event::Datagram)) => {
- debug!("quiche::h3::Event::Datagram on stream id {}", stream_id);
- Ok((stream_id, H3Result::Ignore))
- }
- // TODO: Check if it's necessary to handle GoAway event.
- Ok((stream_id, quiche::h3::Event::GoAway)) => {
- debug!("quiche::h3::Event::GoAway on stream id {}", stream_id);
- Ok((stream_id, H3Result::Ignore))
- }
- Err(e) => {
- debug!("recv_h3 {:?}", e);
- bail!(e);
- }
- }
-}
-
-#[allow(clippy::too_many_arguments)]
-async fn send_dns_query(
- quic_conn: &mut Pin<Box<quiche::Connection>>,
- udp_sk: &mut UdpSocket,
- h3_conn: &mut h3::Connection,
- query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>,
- pending_queries: &mut Vec<(DnsRequest, QueryResponder, BootTime)>,
- resp: QueryResponder,
- expired_time: BootTime,
- req: DnsRequest,
-) -> Result<()> {
- if !quic_conn.is_established() {
- bail!("quic connection is not ready");
- }
- match h3_conn.send_request(quic_conn, &req, true /*fin*/) {
- Ok(stream_id) => {
- query_map.insert(stream_id, (Vec::new(), resp));
- flush_tx(quic_conn, udp_sk).await?;
- debug!("send dns query successfully stream id: {}", stream_id);
- Ok(())
- }
- Err(quiche::h3::Error::StreamBlocked) => {
- warn!("try to send query but error on StreamBlocked");
- pending_queries.push((req, resp, expired_time));
- Ok(())
- }
- Err(e) => {
- resp.send(Response::Error { error: QueryError::ConnectionError }).ok();
- bail!(e);
- }
- }
-}
-
-async fn recv_rx(
- quic_conn: &mut Pin<Box<quiche::Connection>>,
- udp_sk: &mut UdpSocket,
- expired_time: &mut Option<BootTime>,
-) -> Result<()> {
- // TODO: Evaluate if we could make the buffer smaller.
- let mut buf = [0; 65535];
- let quic_idle_timeout_ms = Duration::from_millis(config::QUICHE_IDLE_TIMEOUT_MS);
- let ts = quic_conn.timeout().unwrap_or(quic_idle_timeout_ms);
-
- if let Some(next_expired) = BootTime::now().checked_add(quic_idle_timeout_ms) {
- expired_time.replace(next_expired);
- } else {
- expired_time.take();
- }
- debug!("recv_rx entry next timeout {:?} {:?}", ts, expired_time);
- match timeout(ts, udp_sk.recv_from(&mut buf)).await {
- Ok(v) => match v {
- Ok((size, from)) => {
- let recv_info = quiche::RecvInfo { from };
- let processed = match quic_conn.recv(&mut buf[..size], recv_info) {
- Ok(l) => l,
- Err(e) => {
- debug!("recv_rx error {:?}", e);
- bail!("quic recv failed: {:?}", e);
- }
- };
- debug!("processed {} bytes", processed);
- Ok(())
- }
- Err(e) => bail!("socket recv failed: {:?}", e),
- },
- Err(_) => {
- warn!("timeout did not receive value within {:?}", ts);
- quic_conn.on_timeout();
- Ok(())
- }
- }
-}
-
-async fn flush_tx(
- quic_conn: &mut Pin<Box<quiche::Connection>>,
- udp_sk: &mut UdpSocket,
-) -> Result<()> {
- let mut out = [0; config::MAX_DATAGRAM_SIZE];
- loop {
- let (write, _) = match quic_conn.send(&mut out) {
- Ok(v) => v,
- Err(quiche::Error::Done) => {
- debug!("done writing");
- break;
- }
- Err(e) => {
- quic_conn.close(false, 0x1, b"fail").ok();
- bail!(e);
- }
- };
- udp_sk.send(&out[..write]).await?;
- debug!("written {}", write);
- }
- Ok(())
-}
-
-async fn handle_probe_result(
- result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>),
- doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
- validation: &ValidationReporter,
-) {
- let (info, doh_conn) = match result {
- (info, Ok(doh_conn)) => {
- info!("probing_task success on net_id: {}", info.net_id);
- (info, doh_conn)
- }
- (info, Err((e, doh_conn))) => {
- error!("probe failed on network {}, {:?}", e, info.net_id);
- (info, doh_conn)
- // TODO: Retry probe?
- }
- };
- // If the network is removed or the server is replaced before probing,
- // ignore the probe result.
- match doh_conn_map.get(&info.net_id) {
- Some((server_info, _)) => {
- if *server_info != info {
- warn!(
- "The previous configuration for network {} was replaced before probe finished",
- info.net_id
- );
- return;
- }
- }
- _ => {
- warn!("network {} was removed before probe finished", info.net_id);
- return;
- }
- }
- validation(&info, doh_conn.state.is_connected()).await;
- doh_conn_map.insert(info.net_id, (info, Some(doh_conn)));
-}
-
-async fn probe_task(
- info: ServerInfo,
- mut doh: DohConnection,
- t: Duration,
-) -> (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>) {
- match doh.probe(t).await {
- Ok(_) => (info, Ok(doh)),
- Err(e) => (info, Err((anyhow!(e), doh))),
- }
-}
-
-fn make_connection_if_needed(
- info: &ServerInfo,
- doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
- config_cache: &config::Cache,
- tag_socket: SocketTagger,
-) -> Result<Option<DohConnection>> {
- // Check if connection exists.
- match doh_conn_map.get(&info.net_id) {
- // The connection exists but has failed. Re-probe.
- Some((server_info, Some(doh))) if *server_info == *info && doh.is_reprobe_required() => {
- let (_, doh) = doh_conn_map
- .insert(info.net_id, (info.clone(), None))
- .ok_or_else(|| anyhow!("unexpected error, missing connection"))?;
- return Ok(doh);
- }
- // The connection exists or the connection is under probing, ignore.
- Some((server_info, _)) if *server_info == *info => return Ok(None),
- // TODO: change the inner connection instead of removing?
- _ => doh_conn_map.remove(&info.net_id),
- };
- let config = config_cache.from_cert_path(&info.cert_path)?;
- let doh = DohConnection::new(info, config, tag_socket)?;
- doh_conn_map.insert(info.net_id, (info.clone(), None));
- Ok(Some(doh))
-}
-
-async fn handle_query_cmd(
- net_id: u32,
- base64_query: Base64Query,
- expired_time: BootTime,
- resp: QueryResponder,
- doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
-) {
- if let Some((info, quic_conn)) = doh_conn_map.get_mut(&net_id) {
- match (&info.domain, quic_conn) {
- // Connection is not ready, strict mode
- (Some(_), None) => {
- let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
- }
- // Connection is not ready, Opportunistic mode
- (None, None) => {
- let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
- }
- // Connection is ready
- (_, Some(quic_conn)) => {
- if let Ok(req) = encoding::dns_request(&base64_query, &info.url) {
- let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await;
- } else {
- let _ = resp.send(Response::Error { error: QueryError::Unexpected });
- }
- }
- }
- } else {
- error!("No connection is associated with the given net id {}", net_id);
- let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
- }
-}
-fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConnection>)>) -> bool {
- if doh_conn_map.is_empty() {
- return false;
- }
- for (_, doh_conn) in doh_conn_map.values() {
- if let Some(doh_conn) = doh_conn {
- if doh_conn.has_not_handled_queries() {
- return true;
- }
- }
- }
- false
-}
-
-async fn doh_handler(
- mut cmd_rx: CmdReceiver,
- validation: ValidationReporter,
- tag_socket: SocketTagger,
-) -> Result<()> {
- info!("doh_dispatcher entry");
- let config_cache = config::Cache::new();
-
- // Currently, only support 1 server per network.
- let mut doh_conn_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
- let mut probe_futures = FuturesUnordered::new();
- loop {
- tokio::select! {
- _ = async {
- let mut futures = vec![];
- for (_, doh_conn) in doh_conn_map.values_mut() {
- if let Some(doh_conn) = doh_conn {
- futures.push(doh_conn.process_queries());
- }
- }
- join_all(futures).await
- }, if need_process_queries(&doh_conn_map) => {},
- Some(result) = probe_futures.next() => {
- handle_probe_result(result, &mut doh_conn_map, &validation).await;
- info!("probe_futures remaining size: {}", probe_futures.len());
- },
- Some(cmd) = cmd_rx.recv() => {
- trace!("recv {:?}", cmd);
- match cmd {
- DohCommand::Probe { info, timeout: t } => {
- match make_connection_if_needed(&info, &mut doh_conn_map, &config_cache, tag_socket.clone()) {
- Ok(Some(doh)) => {
- // Create a new async task associated to the DoH connection.
- probe_futures.push(probe_task(info, doh, t));
- debug!("probe_futures size: {}", probe_futures.len());
- }
- Ok(None) => {
- // No further probe is needed.
- warn!("connection for network {} already exists", info.net_id);
- // TODO: Report the status again?
- }
- Err(e) => {
- error!("create connection for network {} error {:?}", info.net_id, e);
- validation(&info, false).await
- }
- }
- },
- DohCommand::Query { net_id, base64_query, expired_time, resp } => {
- handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map).await;
- },
- DohCommand::Clear { net_id } => {
- doh_conn_map.remove(&net_id);
- info!("Doh Clear server for netid: {}", net_id);
- config_cache.garbage_collect();
- },
- DohCommand::Exit => return Ok(()),
- }
- }
- }
- }
-}
-
-fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result<std::net::UdpSocket> {
- let bind_addr = match peer_addr {
- std::net::SocketAddr::V4(_) => "0.0.0.0:0",
- std::net::SocketAddr::V6(_) => "[::]:0",
- };
- let udp_sk = std::net::UdpSocket::bind(bind_addr)?;
- udp_sk.set_nonblocking(true)?;
- if mark_socket(udp_sk.as_raw_fd(), mark).is_err() {
- warn!("Mark socket failed, is it a test?");
- }
- udp_sk.connect(peer_addr)?;
-
- trace!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?);
- Ok(udp_sk)
-}
-
-fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
- // libc::setsockopt is a wrapper function calling into bionic setsockopt.
- // Both fd and mark are valid, which makes the function call mostly safe.
- if unsafe {
- libc::setsockopt(
- fd,
- libc::SOL_SOCKET,
- libc::SO_MARK,
- &mark as *const _ as *const libc::c_void,
- std::mem::size_of::<u32>() as libc::socklen_t,
- )
- } == 0
- {
- Ok(())
- } else {
- Err(anyhow::Error::new(std::io::Error::last_os_error()))
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use futures::FutureExt;
- use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
-
- const TEST_NET_ID: u32 = 50;
- const TEST_MARK: u32 = 0xD0033;
- const LOOPBACK_ADDR: &str = "127.0.0.1:443";
- const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";
-
- // TODO: Make some tests for DohConnection.
-
- fn make_testing_variables(
- ) -> (ServerInfo, HashMap<u32, (ServerInfo, Option<DohConnection>)>, config::Cache, Runtime)
- {
- let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
- let info = ServerInfo {
- net_id: TEST_NET_ID,
- url: Url::parse(LOCALHOST_URL).unwrap(),
- peer_addr: LOOPBACK_ADDR.parse().unwrap(),
- domain: None,
- sk_mark: 0,
- cert_path: None,
- };
- let config_cache = config::Cache::new();
- let rt = Builder::new_current_thread()
- .thread_name("test-runtime")
- .enable_all()
- .build()
- .expect("Failed to create testing tokio runtime");
- (info, test_map, config_cache, rt)
- }
-
- fn build_socket_tagger() -> SocketTagger {
- Arc::new(|_| async {}.boxed())
- }
-
- #[test]
- fn make_connection_if_needed() {
- let (info, mut test_map, config_cache, rt) = make_testing_variables();
- rt.block_on(async {
- // Expect to make a new connection.
- let mut doh = super::make_connection_if_needed(
- &info,
- &mut test_map,
- &config_cache,
- build_socket_tagger(),
- )
- .unwrap()
- .unwrap();
- assert_eq!(doh.info.net_id, info.net_id);
- assert!(matches!(doh.state, ConnectionState::Idle));
- doh.state = ConnectionState::Error;
- test_map.insert(info.net_id, (info.clone(), Some(doh)));
- // Expect that we will get a connection with fail status that we added to the map before.
- let mut doh = super::make_connection_if_needed(
- &info,
- &mut test_map,
- &config_cache,
- build_socket_tagger(),
- )
- .unwrap()
- .unwrap();
- assert_eq!(doh.info.net_id, info.net_id);
- assert!(matches!(doh.state, ConnectionState::Error));
- doh.state = make_dummy_connected_state();
- test_map.insert(info.net_id, (info.clone(), Some(doh)));
- // Expect that we will get None because the map contains a connection with ready status.
- assert!(super::make_connection_if_needed(
- &info,
- &mut test_map,
- &config_cache,
- build_socket_tagger()
- )
- .unwrap()
- .is_none());
- });
- }
-
- #[test]
- fn handle_query_cmd() {
- let (info, mut test_map, config_cache, rt) = make_testing_variables();
- let t = Duration::from_millis(100);
- rt.block_on(async {
- // Test no available server cases.
- let (resp_tx, resp_rx) = oneshot::channel();
- let query = encoding::probe_query().unwrap();
- super::handle_query_cmd(
- info.net_id,
- query.clone(),
- BootTime::now().checked_add(t).unwrap(),
- resp_tx,
- &mut test_map,
- )
- .await;
- assert_eq!(
- timeout(t, resp_rx).await.unwrap().unwrap(),
- Response::Error { error: QueryError::ServerNotReady }
- );
-
- let (resp_tx, resp_rx) = oneshot::channel();
- test_map.insert(info.net_id, (info.clone(), None));
- super::handle_query_cmd(
- info.net_id,
- query.clone(),
- BootTime::now().checked_add(t).unwrap(),
- resp_tx,
- &mut test_map,
- )
- .await;
- assert_eq!(
- timeout(t, resp_rx).await.unwrap().unwrap(),
- Response::Error { error: QueryError::ServerNotReady }
- );
-
- // Test the connection broken case.
- test_map.clear();
- let (resp_tx, resp_rx) = oneshot::channel();
- let mut doh = super::make_connection_if_needed(
- &info,
- &mut test_map,
- &config_cache,
- build_socket_tagger(),
- )
- .unwrap()
- .unwrap();
- doh.state = ConnectionState::Error;
- test_map.insert(info.net_id, (info.clone(), Some(doh)));
- super::handle_query_cmd(
- info.net_id,
- query.clone(),
- BootTime::now().checked_add(t).unwrap(),
- resp_tx,
- &mut test_map,
- )
- .await;
- assert_eq!(
- timeout(t, resp_rx).await.unwrap().unwrap(),
- Response::Error { error: QueryError::BrokenServer }
- );
- });
- }
-
- fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) {
- let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
- let udp_sk = UdpSocket::from_std(sk).unwrap();
- let mut scid = [0; quiche::MAX_CONN_ID_LEN];
- ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid").unwrap();
- let connid = quiche::ConnectionId::from_ref(&scid);
- let mut config = Config::from_cert_path(None).unwrap();
- let quic_conn =
- quiche::connect(None, &connid, LOOPBACK_ADDR.parse().unwrap(), &mut config.take())
- .unwrap();
- (quic_conn, udp_sk)
- }
-
- fn make_dummy_connected_state() -> super::ConnectionState {
- let (quic_conn, udp_sk) = make_testing_connection_variables();
- ConnectionState::Connected {
- quic_conn,
- udp_sk,
- h3_conn: None,
- query_map: HashMap::new(),
- expired_time: None,
- }
- }
-
- #[test]
- fn make_doh_udp_socket() {
- // Make a socket connecting to loopback with a test mark.
- let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
- // Check if the socket is connected to loopback.
- assert_eq!(
- sk.peer_addr().unwrap(),
- SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
- );
-
- // Check if the socket mark is correct.
- let fd: RawFd = sk.as_raw_fd();
-
- let mut mark: u32 = 50;
- let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
- unsafe {
- // Safety: It's fine since the fd belongs to this test.
- assert_eq!(
- libc::getsockopt(
- fd,
- libc::SOL_SOCKET,
- libc::SO_MARK,
- &mut mark as *mut _ as *mut libc::c_void,
- &mut size as *mut libc::socklen_t,
- ),
- 0
- );
- }
- assert_eq!(mark, TEST_MARK);
-
- // Check if the socket is non-blocking.
- unsafe {
- // Safety: It's fine since the fd belongs to this test.
- assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
- }
- }
-}
+mod network;
diff --git a/doh/ffi.rs b/doh/ffi.rs
index ce2f5a64..2df5be68 100644
--- a/doh/ffi.rs
+++ b/doh/ffi.rs
@@ -17,6 +17,8 @@
//! C API for the DoH backend for the Android DnsResolver module.
use crate::boot_time::{timeout, BootTime, Duration};
+use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
+use crate::network::{SocketTagger, ValidationReporter};
use futures::FutureExt;
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::{error, warn};
@@ -25,22 +27,19 @@ use std::net::{IpAddr, SocketAddr};
use std::ops::DerefMut;
use std::os::unix::io::RawFd;
use std::str::FromStr;
-use std::sync::Mutex;
+use std::sync::{Arc, Mutex};
use std::{ptr, slice};
use tokio::runtime::Runtime;
use tokio::sync::oneshot;
use tokio::task;
use url::Url;
-use super::DohDispatcher as Dispatcher;
-use super::{DohCommand, Response, ServerInfo, SocketTagger, ValidationReporter, DOH_PORT};
-
pub type ValidationCallback =
extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
- Box::new(move |info: &ServerInfo, success: bool| {
+ Arc::new(move |info: &ServerInfo, success: bool| {
async move {
let (ip_addr, domain) = match (
CString::new(info.peer_addr.ip().to_string()),
@@ -65,7 +64,6 @@ fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationRepo
fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
use std::os::unix::io::AsRawFd;
- use std::sync::Arc;
Arc::new(move |udp_socket: &std::net::UdpSocket| {
let fd = udp_socket.as_raw_fd();
async move {
@@ -107,6 +105,8 @@ pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// The trace log level.
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
+const DOH_PORT: u16 = 443;
+
fn level_from_u32(level: u32) -> Option<log::Level> {
use log::Level::*;
match level {
@@ -216,7 +216,7 @@ pub unsafe extern "C" fn doh_net_new(
return -libc::EINVAL;
}
};
- let cmd = DohCommand::Probe {
+ let cmd = Command::Probe {
info: ServerInfo {
net_id,
url,
@@ -257,7 +257,7 @@ pub unsafe extern "C" fn doh_query(
let (resp_tx, resp_rx) = oneshot::channel();
let t = Duration::from_millis(timeout_ms);
if let Some(expired_time) = BootTime::now().checked_add(t) {
- let cmd = DohCommand::Query {
+ let cmd = Command::Query {
net_id,
base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
expired_time,
@@ -309,7 +309,7 @@ pub unsafe extern "C" fn doh_query(
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
- if let Err(e) = doh.lock().send_cmd(DohCommand::Clear { net_id }) {
+ if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
error!("Failed to send the query: {:?}", e);
}
}
diff --git a/doh/network/driver.rs b/doh/network/driver.rs
new file mode 100644
index 00000000..6c17f35e
--- /dev/null
+++ b/doh/network/driver.rs
@@ -0,0 +1,190 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides a backing task to implement a network
+
+use crate::boot_time::{timeout, BootTime, Duration};
+use crate::config::Config;
+use crate::connection::Connection;
+use crate::dispatcher::{QueryError, Response};
+use crate::encoding;
+use anyhow::{anyhow, Result};
+use std::sync::Arc;
+use tokio::sync::{mpsc, watch};
+use tokio::task;
+
+use super::{Query, ServerInfo, SocketTagger, ValidationReporter};
+
+pub struct Driver {
+ info: ServerInfo,
+ config: Config,
+ connection: Connection,
+ command_rx: mpsc::Receiver<Command>,
+ status_tx: watch::Sender<Status>,
+ validation: ValidationReporter,
+ tag_socket: SocketTagger,
+}
+
+#[derive(Debug)]
+/// Requests the network can handle
+pub enum Command {
+ /// Send a DNS query to the network
+ Query(Query),
+ /// Run a probe to check the health of the network. Argument is timeout.
+ Probe(Duration),
+}
+
+#[derive(Clone, Debug)]
+/// Current Network Status
+///
+/// (Unprobed or Failed) can go to (Live or Failed) via Probe.
+/// Currently, there is no way to go from Live to Failed - probing a live network will short-circuit to returning valid, and query failures do not declare the network failed.
+pub enum Status {
+ /// Network has not been probed, it may or may not work
+ Unprobed,
+ /// Network is believed to be working
+ Live,
+ /// Network is broken, reason as argument
+ Failed(Arc<anyhow::Error>),
+}
+
+impl Status {
+ pub fn is_live(&self) -> bool {
+ matches!(self, Self::Live)
+ }
+ pub fn is_failed(&self) -> bool {
+ matches!(self, Self::Failed(_))
+ }
+}
+
+async fn build_connection(
+ info: &ServerInfo,
+ tag_socket: &SocketTagger,
+ config: &mut Config,
+) -> Result<Connection> {
+ use std::ops::DerefMut;
+ Ok(Connection::new(
+ info.domain.as_deref(),
+ info.peer_addr,
+ info.sk_mark,
+ tag_socket,
+ config.take().await.deref_mut(),
+ )
+ .await?)
+}
+
+impl Driver {
+ const MAX_BUFFERED_COMMANDS: usize = 10;
+
+ pub async fn new(
+ info: ServerInfo,
+ mut config: Config,
+ validation: ValidationReporter,
+ tag_socket: SocketTagger,
+ ) -> Result<(Self, mpsc::Sender<Command>, watch::Receiver<Status>)> {
+ let (command_tx, command_rx) = mpsc::channel(Self::MAX_BUFFERED_COMMANDS);
+ let (status_tx, status_rx) = watch::channel(Status::Unprobed);
+ let connection = build_connection(&info, &tag_socket, &mut config).await?;
+ Ok((
+ Self { info, config, connection, status_tx, command_rx, validation, tag_socket },
+ command_tx,
+ status_rx,
+ ))
+ }
+
+ pub async fn drive(mut self) -> Result<()> {
+ while let Some(cmd) = self.command_rx.recv().await {
+ if let Err(e) = match cmd {
+ Command::Probe(duration) => self.probe(duration).await,
+ Command::Query(query) => self.send_query(query).await,
+ } {
+ self.status_tx.send(Status::Failed(Arc::new(e)))?
+ };
+ }
+ Ok(())
+ }
+
+ async fn probe(&mut self, probe_timeout: Duration) -> Result<()> {
+ if self.status_tx.borrow().is_failed() {
+ // If our network is currently failed, it may be due to issues with the connection.
+ // Re-establish before re-probing
+ self.connection =
+ build_connection(&self.info, &self.tag_socket, &mut self.config).await?;
+ self.status_tx.send(Status::Unprobed)?;
+ }
+ if self.status_tx.borrow().is_live() {
+ // If we're already validated, short circuit
+ (self.validation)(&self.info, true).await;
+ return Ok(());
+ }
+ self.force_probe(probe_timeout).await
+ }
+
+ async fn force_probe(&mut self, probe_timeout: Duration) -> Result<()> {
+ let probe = encoding::probe_query()?;
+ let dns_request = encoding::dns_request(&probe, &self.info.url)?;
+ let expiry = BootTime::now().checked_add(probe_timeout);
+ let request = async {
+ match self.connection.query(dns_request, expiry).await {
+ Err(e) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(e)))),
+ Ok(rsp) => {
+ if let Some(_stream) = rsp.await {
+ // TODO verify stream contents
+ self.status_tx.send(Status::Live)
+ } else {
+ self.status_tx.send(Status::Failed(Arc::new(anyhow!("Empty response"))))
+ }
+ }
+ }
+ };
+ match timeout(probe_timeout, request).await {
+ // Timed out
+ Err(time) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(
+ "Probe timed out after {:?} (timeout={:?})",
+ time,
+ probe_timeout
+ )))),
+ // Query completed
+ Ok(r) => r,
+ }?;
+ let valid = self.status_tx.borrow().is_live();
+ (self.validation)(&self.info, valid).await;
+ Ok(())
+ }
+
+ async fn send_query(&mut self, query: Query) -> Result<()> {
+ let request = encoding::dns_request(&query.query, &self.info.url)?;
+ let stream_fut = self.connection.query(request, Some(query.expiry)).await?;
+ task::spawn(async move {
+ let stream = match stream_fut.await {
+ Some(stream) => stream,
+ None => {
+ // We don't care if the response is gone
+ let _ =
+ query.response.send(Response::Error { error: QueryError::ConnectionError });
+ return;
+ }
+ };
+ // We don't care if the response is gone.
+ let _ = if let Some(err) = stream.error {
+ query.response.send(Response::Error { error: QueryError::Reset(err) })
+ } else {
+ query.response.send(Response::Success { answer: stream.data })
+ };
+ });
+ Ok(())
+ }
+}
diff --git a/doh/network/mod.rs b/doh/network/mod.rs
new file mode 100644
index 00000000..5d26688c
--- /dev/null
+++ b/doh/network/mod.rs
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides the ability to query DNS for a specific network configuration
+
+use crate::boot_time::{BootTime, Duration};
+use crate::config::Config;
+use crate::dispatcher::{QueryError, Response};
+use anyhow::Result;
+use futures::future::BoxFuture;
+use log::warn;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::sync::{mpsc, oneshot, watch};
+use tokio::task;
+use url::Url;
+
+mod driver;
+
+use driver::{Command, Driver};
+
+pub use driver::Status;
+
+/// Closure to signal validation status to outside world
+pub type ValidationReporter = Arc<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>;
+/// Closure to tag socket during connection construction
+pub type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>;
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+pub struct ServerInfo {
+ pub net_id: u32,
+ pub url: Url,
+ pub peer_addr: SocketAddr,
+ pub domain: Option<String>,
+ pub sk_mark: u32,
+ pub cert_path: Option<String>,
+}
+
+#[derive(Debug)]
+/// DNS resolution query
+pub struct Query {
+ /// Raw DNS query, base64 encoded
+ pub query: String,
+ /// Place to send the answer
+ pub response: oneshot::Sender<Response>,
+ /// When this request is considered stale (will be ignored if not serviced by that point)
+ pub expiry: BootTime,
+}
+
+/// Handle to a particular network's DNS resolution
+pub struct Network {
+ info: ServerInfo,
+ status_rx: watch::Receiver<Status>,
+ command_tx: mpsc::Sender<Command>,
+}
+
+impl Network {
+ pub async fn new(
+ info: ServerInfo,
+ config: Config,
+ validation: ValidationReporter,
+ tagger: SocketTagger,
+ ) -> Result<Network> {
+ let (driver, command_tx, status_rx) =
+ Driver::new(info.clone(), config, validation, tagger).await?;
+ task::spawn(driver.drive());
+ Ok(Network { info, command_tx, status_rx })
+ }
+
+ pub async fn probe(&mut self, timeout: Duration) -> Result<()> {
+ self.command_tx.send(Command::Probe(timeout)).await?;
+ Ok(())
+ }
+
+ pub async fn query(&mut self, query: Query) -> Result<()> {
+ // The clone is used to prevent status_rx from being held across an await
+ let status: Status = self.status_rx.borrow().clone();
+ match status {
+ Status::Failed(_) => query
+ .response
+ .send(Response::Error { error: QueryError::BrokenServer })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Unprobed => query
+ .response
+ .send(Response::Error { error: QueryError::ServerNotReady })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Live => self.command_tx.send(Command::Query(query)).await?,
+ }
+ Ok(())
+ }
+
+ pub fn get_info(&self) -> &ServerInfo {
+ &self.info
+ }
+}