aboutsummaryrefslogtreecommitdiff
path: root/doh/network/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'doh/network/mod.rs')
-rw-r--r--doh/network/mod.rs112
1 files changed, 112 insertions, 0 deletions
diff --git a/doh/network/mod.rs b/doh/network/mod.rs
new file mode 100644
index 00000000..5d26688c
--- /dev/null
+++ b/doh/network/mod.rs
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! Provides the ability to query DNS for a specific network configuration
+
+use crate::boot_time::{BootTime, Duration};
+use crate::config::Config;
+use crate::dispatcher::{QueryError, Response};
+use anyhow::Result;
+use futures::future::BoxFuture;
+use log::warn;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::sync::{mpsc, oneshot, watch};
+use tokio::task;
+use url::Url;
+
+mod driver;
+
+use driver::{Command, Driver};
+
+pub use driver::Status;
+
+/// Closure to signal validation status to outside world
+pub type ValidationReporter = Arc<dyn Fn(&ServerInfo, bool) -> BoxFuture<()> + Send + Sync>;
+/// Closure to tag socket during connection construction
+pub type SocketTagger = Arc<dyn Fn(&std::net::UdpSocket) -> BoxFuture<()> + Send + Sync>;
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+pub struct ServerInfo {
+ pub net_id: u32,
+ pub url: Url,
+ pub peer_addr: SocketAddr,
+ pub domain: Option<String>,
+ pub sk_mark: u32,
+ pub cert_path: Option<String>,
+}
+
+#[derive(Debug)]
+/// DNS resolution query
+pub struct Query {
+ /// Raw DNS query, base64 encoded
+ pub query: String,
+ /// Place to send the answer
+ pub response: oneshot::Sender<Response>,
+ /// When this request is considered stale (will be ignored if not serviced by that point)
+ pub expiry: BootTime,
+}
+
+/// Handle to a particular network's DNS resolution
+pub struct Network {
+ info: ServerInfo,
+ status_rx: watch::Receiver<Status>,
+ command_tx: mpsc::Sender<Command>,
+}
+
+impl Network {
+ pub async fn new(
+ info: ServerInfo,
+ config: Config,
+ validation: ValidationReporter,
+ tagger: SocketTagger,
+ ) -> Result<Network> {
+ let (driver, command_tx, status_rx) =
+ Driver::new(info.clone(), config, validation, tagger).await?;
+ task::spawn(driver.drive());
+ Ok(Network { info, command_tx, status_rx })
+ }
+
+ pub async fn probe(&mut self, timeout: Duration) -> Result<()> {
+ self.command_tx.send(Command::Probe(timeout)).await?;
+ Ok(())
+ }
+
+ pub async fn query(&mut self, query: Query) -> Result<()> {
+ // The clone is used to prevent status_rx from being held across an await
+ let status: Status = self.status_rx.borrow().clone();
+ match status {
+ Status::Failed(_) => query
+ .response
+ .send(Response::Error { error: QueryError::BrokenServer })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Unprobed => query
+ .response
+ .send(Response::Error { error: QueryError::ServerNotReady })
+ .unwrap_or_else(|_| {
+ warn!("Query result listener went away before receiving a response")
+ }),
+ Status::Live => self.command_tx.send(Command::Query(query)).await?,
+ }
+ Ok(())
+ }
+
+ pub fn get_info(&self) -> &ServerInfo {
+ &self.info
+ }
+}