aboutsummaryrefslogtreecommitdiff
path: root/doh/network
diff options
context:
space:
mode:
Diffstat (limited to 'doh/network')
-rw-r--r--doh/network/driver.rs190
-rw-r--r--doh/network/mod.rs112
2 files changed, 302 insertions, 0 deletions
diff --git a/doh/network/driver.rs b/doh/network/driver.rs
new file mode 100644
index 00000000..6c17f35e
--- /dev/null
+++ b/doh/network/driver.rs
@@ -0,0 +1,190 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides a backing task to implement a network
+
+use crate::boot_time::{timeout, BootTime, Duration};
+use crate::config::Config;
+use crate::connection::Connection;
+use crate::dispatcher::{QueryError, Response};
+use crate::encoding;
+use anyhow::{anyhow, Result};
+use std::sync::Arc;
+use tokio::sync::{mpsc, watch};
+use tokio::task;
+
+use super::{Query, ServerInfo, SocketTagger, ValidationReporter};
+
+pub struct Driver {
+ info: ServerInfo,
+ config: Config,
+ connection: Connection,
+ command_rx: mpsc::Receiver<Command>,
+ status_tx: watch::Sender<Status>,
+ validation: ValidationReporter,
+ tag_socket: SocketTagger,
+}
+
+#[derive(Debug)]
+/// Requests the network can handle
+pub enum Command {
+ /// Send a DNS query to the network
+ Query(Query),
+ /// Run a probe to check the health of the network. Argument is timeout.
+ Probe(Duration),
+}
+
+#[derive(Clone, Debug)]
+/// Current Network Status
+///
+/// (Unprobed or Failed) can go to (Live or Failed) via Probe.
+/// Currently, there is no way to go from Live to Failed - probing a live network will short-circuit to returning valid, and query failures do not declare the network failed.
+pub enum Status {
+ /// Network has not been probed, it may or may not work
+ Unprobed,
+ /// Network is believed to be working
+ Live,
+ /// Network is broken, reason as argument
+ Failed(Arc<anyhow::Error>),
+}
+
+impl Status {
+ pub fn is_live(&self) -> bool {
+ matches!(self, Self::Live)
+ }
+ pub fn is_failed(&self) -> bool {
+ matches!(self, Self::Failed(_))
+ }
+}
+
+async fn build_connection(
+ info: &ServerInfo,
+ tag_socket: &SocketTagger,
+ config: &mut Config,
+) -> Result<Connection> {
+ use std::ops::DerefMut;
+ Ok(Connection::new(
+ info.domain.as_deref(),
+ info.peer_addr,
+ info.sk_mark,
+ tag_socket,
+ config.take().await.deref_mut(),
+ )
+ .await?)
+}
+
+impl Driver {
+ const MAX_BUFFERED_COMMANDS: usize = 10;
+
+ pub async fn new(
+ info: ServerInfo,
+ mut config: Config,
+ validation: ValidationReporter,
+ tag_socket: SocketTagger,
+ ) -> Result<(Self, mpsc::Sender<Command>, watch::Receiver<Status>)> {
+ let (command_tx, command_rx) = mpsc::channel(Self::MAX_BUFFERED_COMMANDS);
+ let (status_tx, status_rx) = watch::channel(Status::Unprobed);
+ let connection = build_connection(&info, &tag_socket, &mut config).await?;
+ Ok((
+ Self { info, config, connection, status_tx, command_rx, validation, tag_socket },
+ command_tx,
+ status_rx,
+ ))
+ }
+
+ pub async fn drive(mut self) -> Result<()> {
+ while let Some(cmd) = self.command_rx.recv().await {
+ if let Err(e) = match cmd {
+ Command::Probe(duration) => self.probe(duration).await,
+ Command::Query(query) => self.send_query(query).await,
+ } {
+ self.status_tx.send(Status::Failed(Arc::new(e)))?
+ };
+ }
+ Ok(())
+ }
+
+ async fn probe(&mut self, probe_timeout: Duration) -> Result<()> {
+ if self.status_tx.borrow().is_failed() {
+ // If our network is currently failed, it may be due to issues with the connection.
+ // Re-establish before re-probing
+ self.connection =
+ build_connection(&self.info, &self.tag_socket, &mut self.config).await?;
+ self.status_tx.send(Status::Unprobed)?;
+ }
+ if self.status_tx.borrow().is_live() {
+ // If we're already validated, short circuit
+ (self.validation)(&self.info, true).await;
+ return Ok(());
+ }
+ self.force_probe(probe_timeout).await
+ }
+
+ async fn force_probe(&mut self, probe_timeout: Duration) -> Result<()> {
+ let probe = encoding::probe_query()?;
+ let dns_request = encoding::dns_request(&probe, &self.info.url)?;
+ let expiry = BootTime::now().checked_add(probe_timeout);
+ let request = async {
+ match self.connection.query(dns_request, expiry).await {
+ Err(e) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(e)))),
+ Ok(rsp) => {
+ if let Some(_stream) = rsp.await {
+ // TODO verify stream contents
+ self.status_tx.send(Status::Live)
+ } else {
+ self.status_tx.send(Status::Failed(Arc::new(anyhow!("Empty response"))))
+ }
+ }
+ }
+ };
+ match timeout(probe_timeout, request).await {
+ // Timed out
+ Err(time) => self.status_tx.send(Status::Failed(Arc::new(anyhow!(
+ "Probe timed out after {:?} (timeout={:?})",
+ time,
+ probe_timeout
+ )))),
+ // Query completed
+ Ok(r) => r,
+ }?;
+ let valid = self.status_tx.borrow().is_live();
+ (self.validation)(&self.info, valid).await;
+ Ok(())
+ }
+
+ async fn send_query(&mut self, query: Query) -> Result<()> {
+ let request = encoding::dns_request(&query.query, &self.info.url)?;
+ let stream_fut = self.connection.query(request, Some(query.expiry)).await?;
+ task::spawn(async move {
+ let stream = match stream_fut.await {
+ Some(stream) => stream,
+ None => {
+ // We don't care if the response is gone
+ let _ =
+ query.response.send(Response::Error { error: QueryError::ConnectionError });
+ return;
+ }
+ };
+ // We don't care if the response is gone.
+ let _ = if let Some(err) = stream.error {
+ query.response.send(Response::Error { error: QueryError::Reset(err) })
+ } else {
+ query.response.send(Response::Success { answer: stream.data })
+ };
+ });
+ Ok(())
+ }
+}
diff --git a/doh/network/mod.rs b/doh/network/mod.rs
new file mode 100644
index 00000000..5d26688c
--- /dev/null
+++ b/doh/network/mod.rs
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides the ability to query DNS for a specific network configuration
+
+use crate::boot_time::{BootTime, Duration};
+use crate::config::Config;
+use crate::dispatcher::{QueryError, Response};
+use anyhow::Result;
+use futures::future::BoxFuture;
+use log::warn;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::sync::{mpsc, oneshot, watch};
+use tokio::task;
+use url::Url;
+
+mod driver;
+
+use driver::{Command, Driver};
+
+pub use driver::Status;
+
+/// Closure to signal validation status to outside world
+pub type ValidationReporter = Arc<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>;
+/// Closure to tag socket during connection construction
+pub type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>;
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+pub struct ServerInfo {
+ pub net_id: u32,
+ pub url: Url,
+ pub peer_addr: SocketAddr,
+ pub domain: Option<String>,
+ pub sk_mark: u32,
+ pub cert_path: Option<String>,
+}
+
+#[derive(Debug)]
+/// DNS resolution query
+pub struct Query {
+ /// Raw DNS query, base64 encoded
+ pub query: String,
+ /// Place to send the answer
+ pub response: oneshot::Sender<Response>,
+ /// When this request is considered stale (will be ignored if not serviced by that point)
+ pub expiry: BootTime,
+}
+
+/// Handle to a particular network's DNS resolution
+pub struct Network {
+ info: ServerInfo,
+ status_rx: watch::Receiver<Status>,
+ command_tx: mpsc::Sender<Command>,
+}
+
+impl Network {
+ pub async fn new(
+ info: ServerInfo,
+ config: Config,
+ validation: ValidationReporter,
+ tagger: SocketTagger,
+ ) -> Result<Network> {
+ let (driver, command_tx, status_rx) =
+ Driver::new(info.clone(), config, validation, tagger).await?;
+ task::spawn(driver.drive());
+ Ok(Network { info, command_tx, status_rx })
+ }
+
+ pub async fn probe(&mut self, timeout: Duration) -> Result<()> {
+ self.command_tx.send(Command::Probe(timeout)).await?;
+ Ok(())
+ }
+
+ pub async fn query(&mut self, query: Query) -> Result<()> {
+ // The clone is used to prevent status_rx from being held across an await
+ let status: Status = self.status_rx.borrow().clone();
+ match status {
+ Status::Failed(_) => query
+ .response
+ .send(Response::Error { error: QueryError::BrokenServer })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Unprobed => query
+ .response
+ .send(Response::Error { error: QueryError::ServerNotReady })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Live => self.command_tx.send(Command::Query(query)).await?,
+ }
+ Ok(())
+ }
+
+ pub fn get_info(&self) -> &ServerInfo {
+ &self.info
+ }
+}