diff options
Diffstat (limited to 'doh/network')
-rw-r--r-- | doh/network/driver.rs | 190 | ||||
-rw-r--r-- | doh/network/mod.rs | 112 |
2 files changed, 302 insertions, 0 deletions
diff --git a/doh/network/driver.rs b/doh/network/driver.rs new file mode 100644 index 00000000..6c17f35e --- /dev/null +++ b/doh/network/driver.rs @@ -0,0 +1,190 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Provides a backing task to implement a network + +use crate::boot_time::{timeout, BootTime, Duration}; +use crate::config::Config; +use crate::connection::Connection; +use crate::dispatcher::{QueryError, Response}; +use crate::encoding; +use anyhow::{anyhow, Result}; +use std::sync::Arc; +use tokio::sync::{mpsc, watch}; +use tokio::task; + +use super::{Query, ServerInfo, SocketTagger, ValidationReporter}; + +pub struct Driver { + info: ServerInfo, + config: Config, + connection: Connection, + command_rx: mpsc::Receiver<Command>, + status_tx: watch::Sender<Status>, + validation: ValidationReporter, + tag_socket: SocketTagger, +} + +#[derive(Debug)] +/// Requests the network can handle +pub enum Command { + /// Send a DNS query to the network + Query(Query), + /// Run a probe to check the health of the network. Argument is timeout. + Probe(Duration), +} + +#[derive(Clone, Debug)] +/// Current Network Status +/// +/// (Unprobed or Failed) can go to (Live or Failed) via Probe. +/// Currently, there is no way to go from Live to Failed - probing a live network will short-circuit to returning valid, and query failures do not declare the network failed. +pub enum Status { + /// Network has not been probed, it may or may not work + Unprobed, + /// Network is believed to be working + Live, + /// Network is broken, reason as argument + Failed(Arc<anyhow::Error>), +} + +impl Status { + pub fn is_live(&self) -> bool { + matches!(self, Self::Live) + } + pub fn is_failed(&self) -> bool { + matches!(self, Self::Failed(_)) + } +} + +async fn build_connection( + info: &ServerInfo, + tag_socket: &SocketTagger, + config: &mut Config, +) -> Result<Connection> { + use std::ops::DerefMut; + Ok(Connection::new( + info.domain.as_deref(), + info.peer_addr, + info.sk_mark, + tag_socket, + config.take().await.deref_mut(), + ) + .await?) +} + +impl Driver { + const MAX_BUFFERED_COMMANDS: usize = 10; + + pub async fn new( + info: ServerInfo, + mut config: Config, + validation: ValidationReporter, + tag_socket: SocketTagger, + ) -> Result<(Self, mpsc::Sender<Command>, watch::Receiver<Status>)> { + let (command_tx, command_rx) = mpsc::channel(Self::MAX_BUFFERED_COMMANDS); + let (status_tx, status_rx) = watch::channel(Status::Unprobed); + let connection = build_connection(&info, &tag_socket, &mut config).await?; + Ok(( + Self { info, config, connection, status_tx, command_rx, validation, tag_socket }, + command_tx, + status_rx, + )) + } + + pub async fn drive(mut self) -> Result<()> { + while let Some(cmd) = self.command_rx.recv().await { + if let Err(e) = match cmd { + Command::Probe(duration) => self.probe(duration).await, + Command::Query(query) => self.send_query(query).await, + } { + self.status_tx.send(Status::Failed(Arc::new(e)))? + }; + } + Ok(()) + } + + async fn probe(&mut self, probe_timeout: Duration) -> Result<()> { + if self.status_tx.borrow().is_failed() { + // If our network is currently failed, it may be due to issues with the connection. + // Re-establish before re-probing + self.connection = + build_connection(&self.info, &self.tag_socket, &mut self.config).await?; + self.status_tx.send(Status::Unprobed)?; + } + if self.status_tx.borrow().is_live() { + // If we're already validated, short circuit + (self.validation)(&self.info, true).await; + return Ok(()); + } + self.force_probe(probe_timeout).await + } + + async fn force_probe(&mut self, probe_timeout: Duration) -> Result<()> { + let probe = encoding::probe_query()?; + let dns_request = encoding::dns_request(&probe, &self.info.url)?; + let expiry = BootTime::now().checked_add(probe_timeout); + let request = async { + match self.connection.query(dns_request, expiry).await { + Err(e) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(e)))), + Ok(rsp) => { + if let Some(_stream) = rsp.await { + // TODO verify stream contents + self.status_tx.send(Status::Live) + } else { + self.status_tx.send(Status::Failed(Arc::new(anyhow!("Empty response")))) + } + } + } + }; + match timeout(probe_timeout, request).await { + // Timed out + Err(time) => self.status_tx.send(Status::Failed(Arc::new(anyhow!( + "Probe timed out after {:?} (timeout={:?})", + time, + probe_timeout + )))), + // Query completed + Ok(r) => r, + }?; + let valid = self.status_tx.borrow().is_live(); + (self.validation)(&self.info, valid).await; + Ok(()) + } + + async fn send_query(&mut self, query: Query) -> Result<()> { + let request = encoding::dns_request(&query.query, &self.info.url)?; + let stream_fut = self.connection.query(request, Some(query.expiry)).await?; + task::spawn(async move { + let stream = match stream_fut.await { + Some(stream) => stream, + None => { + // We don't care if the response is gone + let _ = + query.response.send(Response::Error { error: QueryError::ConnectionError }); + return; + } + }; + // We don't care if the response is gone. + let _ = if let Some(err) = stream.error { + query.response.send(Response::Error { error: QueryError::Reset(err) }) + } else { + query.response.send(Response::Success { answer: stream.data }) + }; + }); + Ok(()) + } +} diff --git a/doh/network/mod.rs b/doh/network/mod.rs new file mode 100644 index 00000000..5d26688c --- /dev/null +++ b/doh/network/mod.rs @@ -0,0 +1,112 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Provides the ability to query DNS for a specific network configuration + +use crate::boot_time::{BootTime, Duration}; +use crate::config::Config; +use crate::dispatcher::{QueryError, Response}; +use anyhow::Result; +use futures::future::BoxFuture; +use log::warn; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, watch}; +use tokio::task; +use url::Url; + +mod driver; + +use driver::{Command, Driver}; + +pub use driver::Status; + +/// Closure to signal validation status to outside world +pub type ValidationReporter = Arc<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>; +/// Closure to tag socket during connection construction +pub type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>; + +#[derive(Eq, PartialEq, Debug, Clone)] +pub struct ServerInfo { + pub net_id: u32, + pub url: Url, + pub peer_addr: SocketAddr, + pub domain: Option<String>, + pub sk_mark: u32, + pub cert_path: Option<String>, +} + +#[derive(Debug)] +/// DNS resolution query +pub struct Query { + /// Raw DNS query, base64 encoded + pub query: String, + /// Place to send the answer + pub response: oneshot::Sender<Response>, + /// When this request is considered stale (will be ignored if not serviced by that point) + pub expiry: BootTime, +} + +/// Handle to a particular network's DNS resolution +pub struct Network { + info: ServerInfo, + status_rx: watch::Receiver<Status>, + command_tx: mpsc::Sender<Command>, +} + +impl Network { + pub async fn new( + info: ServerInfo, + config: Config, + validation: ValidationReporter, + tagger: SocketTagger, + ) -> Result<Network> { + let (driver, command_tx, status_rx) = + Driver::new(info.clone(), config, validation, tagger).await?; + task::spawn(driver.drive()); + Ok(Network { info, command_tx, status_rx }) + } + + pub async fn probe(&mut self, timeout: Duration) -> Result<()> { + self.command_tx.send(Command::Probe(timeout)).await?; + Ok(()) + } + + pub async fn query(&mut self, query: Query) -> Result<()> { + // The clone is used to prevent status_rx from being held across an await + let status: Status = self.status_rx.borrow().clone(); + match status { + Status::Failed(_) => query + .response + .send(Response::Error { error: QueryError::BrokenServer }) + .unwrap_or_else(|_| { + warn!("Query result listener went away before receiving a response") + }), + Status::Unprobed => query + .response + .send(Response::Error { error: QueryError::ServerNotReady }) + .unwrap_or_else(|_| { + warn!("Query result listener went away before receiving a response") + }), + Status::Live => self.command_tx.send(Command::Query(query)).await?, + } + Ok(()) + } + + pub fn get_info(&self) -> &ServerInfo { + &self.info + } +} |