aboutsummaryrefslogtreecommitdiff
path: root/doh
diff options
context:
space:
mode:
authorMatthew Maurer <mmaurer@google.com>2021-09-16 16:16:58 -0700
committerMatthew Maurer <mmaurer@google.com>2021-10-04 20:53:20 -0700
commitaa0dac6f1b9b2a09bc3c39688521f9dc1ec8ac1f (patch)
tree6e37978dac4c343cf63cab4ab5c6ffab082aaa17 /doh
parentae5fe72c344017e4fecfef661ba4391dfde89e97 (diff)
downloadDnsResolver-aa0dac6f1b9b2a09bc3c39688521f9dc1ec8ac1f.tar.gz
DoH: Split out FFI logic to separate module
Bug: 202081046 Change-Id: Ie9093ab14a4eb4fc17381ad81f362dd209038d4d
Diffstat (limited to 'doh')
-rw-r--r--doh/doh.rs1366
-rw-r--r--doh/ffi.rs249
2 files changed, 1615 insertions, 0 deletions
diff --git a/doh/doh.rs b/doh/doh.rs
new file mode 100644
index 00000000..95daf648
--- /dev/null
+++ b/doh/doh.rs
@@ -0,0 +1,1366 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! DoH backend for the Android DnsResolver module.
+
+use anyhow::{anyhow, bail, Context, Result};
+use futures::future::join_all;
+use futures::stream::FuturesUnordered;
+use futures::StreamExt;
+use libc::{c_char, int32_t, uint32_t};
+use log::{debug, error, info, trace, warn};
+use quiche::h3;
+use ring::rand::SecureRandom;
+use std::collections::HashMap;
+use std::ffi::CString;
+use std::net::SocketAddr;
+use std::ops::Deref;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use tokio::net::UdpSocket;
+use tokio::runtime::{Builder, Runtime};
+use tokio::sync::{mpsc, oneshot};
+use tokio::task;
+use tokio::time::{timeout, Duration, Instant};
+use url::Url;
+
+mod ffi;
+
+const MAX_BUFFERED_CMD_SIZE: usize = 400;
+const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000;
+const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000;
+const MAX_CONCURRENT_STREAM_SIZE: u64 = 100;
+const MAX_DATAGRAM_SIZE: usize = 1350;
+const DOH_PORT: u16 = 443;
+const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000;
+const NS_T_AAAA: u8 = 28;
+const NS_C_IN: u8 = 1;
+// Used to randomly generate query prefix and query id.
+const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\
+ abcdefghijklmnopqrstuvwxyz\
+ 0123456789";
+
+type SCID = [u8; quiche::MAX_CONN_ID_LEN];
+type Base64Query = String;
+type CmdSender = mpsc::Sender<DohCommand>;
+type CmdReceiver = mpsc::Receiver<DohCommand>;
+type QueryResponder = oneshot::Sender<Response>;
+type DnsRequest = Vec<quiche::h3::Header>;
+type ValidationCallback =
+ extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
+type TagSocketCallback = extern "C" fn(sock: int32_t);
+
+#[derive(Eq, PartialEq, Debug)]
+enum QueryError {
+ BrokenServer,
+ ConnectionError,
+ ServerNotReady,
+ Unexpected,
+}
+
+#[derive(Eq, PartialEq, Debug, Clone)]
+struct ServerInfo {
+ net_id: u32,
+ url: Url,
+ peer_addr: SocketAddr,
+ domain: Option<String>,
+ sk_mark: u32,
+ cert_path: Option<String>,
+}
+
+#[derive(Eq, PartialEq, Debug)]
+enum Response {
+ Error { error: QueryError },
+ Success { answer: Vec<u8> },
+}
+
+#[derive(Debug)]
+enum DohCommand {
+ Probe { info: ServerInfo, timeout: Duration },
+ Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder },
+ Clear { net_id: u32 },
+ Exit,
+}
+
+#[allow(clippy::large_enum_variant)]
+enum ConnectionState {
+ Idle,
+ Connecting {
+ quic_conn: Option<Pin<Box<quiche::Connection>>>,
+ udp_sk: Option<UdpSocket>,
+ expired_time: Option<BootTime>,
+ },
+ Connected {
+ quic_conn: Pin<Box<quiche::Connection>>,
+ udp_sk: UdpSocket,
+ h3_conn: Option<h3::Connection>,
+ query_map: HashMap<u64, (Vec<u8>, QueryResponder)>,
+ expired_time: Option<BootTime>,
+ },
+ /// Indicate that the Connection can't be used due to
+ /// network or unexpected reasons.
+ Error,
+}
+
+enum H3Result {
+ Data { data: Vec<u8> },
+ Finished,
+ Ignore,
+}
+
+trait OptionDeref<T: Deref> {
+ fn as_deref(&self) -> Option<&T::Target>;
+}
+
+impl<T: Deref> OptionDeref<T> for Option<T> {
+ fn as_deref(&self) -> Option<&T::Target> {
+ self.as_ref().map(Deref::deref)
+ }
+}
+
+#[derive(Copy, Clone, Debug)]
+struct BootTime {
+ d: Duration,
+}
+
+impl BootTime {
+ fn now() -> BootTime {
+ unsafe {
+ let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 };
+ if libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) != 0 {
+ panic!("get boot time failed: {:?}", std::io::Error::last_os_error());
+ }
+ BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) }
+ }
+ }
+
+ fn elapsed(&self) -> Option<Duration> {
+ BootTime::now().duration_since(*self)
+ }
+
+ fn checked_add(&self, duration: Duration) -> Option<BootTime> {
+ Some(BootTime { d: self.d.checked_add(duration)? })
+ }
+
+ fn duration_since(&self, earlier: BootTime) -> Option<Duration> {
+ self.d.checked_sub(earlier.d)
+ }
+}
+
+/// Context for a running DoH engine.
+pub struct DohDispatcher {
+ /// Used to submit cmds to the I/O task.
+ cmd_sender: CmdSender,
+ join_handle: task::JoinHandle<Result<()>>,
+ runtime: Arc<Runtime>,
+}
+
+// DoH dispatcher
+impl DohDispatcher {
+ fn new(
+ validation_fn: ValidationCallback,
+ tag_socket_fn: TagSocketCallback,
+ ) -> Result<Box<DohDispatcher>> {
+ let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE);
+ let runtime = Arc::new(
+ Builder::new_multi_thread()
+ .worker_threads(2)
+ .enable_all()
+ .thread_name("doh-handler")
+ .build()
+ .expect("Failed to create tokio runtime"),
+ );
+ let join_handle =
+ runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn));
+ Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime }))
+ }
+
+ fn send_cmd(&self, cmd: DohCommand) -> Result<()> {
+ self.cmd_sender.blocking_send(cmd)?;
+ Ok(())
+ }
+
+ fn exit_handler(&mut self) {
+ if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() {
+ return;
+ }
+ let _ = self.runtime.block_on(&mut self.join_handle);
+ }
+}
+
+struct DohConnection {
+ info: ServerInfo,
+ shared_config: Arc<Mutex<QuicheConfigCache>>,
+ scid: SCID,
+ state: ConnectionState,
+ pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>,
+ cached_session: Option<Vec<u8>>,
+ tag_socket_fn: TagSocketCallback,
+}
+
+impl DohConnection {
+ fn new(
+ info: &ServerInfo,
+ shared_config: Arc<Mutex<QuicheConfigCache>>,
+ tag_socket_fn: TagSocketCallback,
+ ) -> Result<DohConnection> {
+ let mut scid = [0; quiche::MAX_CONN_ID_LEN];
+ ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?;
+ Ok(DohConnection {
+ info: info.clone(),
+ shared_config,
+ scid,
+ state: ConnectionState::Idle,
+ pending_queries: Vec::new(),
+ cached_session: None,
+ tag_socket_fn,
+ })
+ }
+
+ fn state_to_connecting(&mut self) -> Result<()> {
+ self.state = match self.state {
+ ConnectionState::Idle => {
+ let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?;
+ (self.tag_socket_fn)(udp_sk_std.as_raw_fd());
+ let udp_sk = UdpSocket::from_std(udp_sk_std)?;
+ let connid = quiche::ConnectionId::from_ref(&self.scid);
+ let mut cache = self.shared_config.lock().unwrap();
+ let config =
+ cache.get(&self.info.cert_path)?.ok_or_else(|| anyhow!("no quiche config"))?;
+ debug!("init the connection for Network {}", self.info.net_id);
+ let mut quic_conn = quiche::connect(
+ self.info.domain.as_deref(),
+ &connid,
+ self.info.peer_addr,
+ config,
+ )?;
+ if let Some(session) = &self.cached_session {
+ if quic_conn.set_session(session).is_err() {
+ warn!("can't restore session for network {}", self.info.net_id);
+ }
+ }
+ ConnectionState::Connecting {
+ quic_conn: Some(quic_conn),
+ udp_sk: Some(udp_sk),
+ expired_time: None,
+ }
+ }
+ ConnectionState::Error => {
+ self.state_to_idle();
+ return self.state_to_connecting();
+ }
+ ConnectionState::Connecting { .. } => return Ok(()),
+ ConnectionState::Connected { .. } => {
+ panic!("Invalid state transition to Connecting state!")
+ }
+ };
+ Ok(())
+ }
+
+ fn state_to_connected(&mut self) -> Result<()> {
+ self.state = match &mut self.state {
+ // Only Connecting -> Connected is valid.
+ ConnectionState::Connecting { quic_conn, udp_sk, .. } => {
+ if let (Some(mut quic_conn), Some(udp_sk)) = (quic_conn.take(), udp_sk.take()) {
+ let h3_config = h3::Config::new()?;
+ let h3_conn =
+ quiche::h3::Connection::with_transport(&mut quic_conn, &h3_config)?;
+ ConnectionState::Connected {
+ quic_conn,
+ udp_sk,
+ h3_conn: Some(h3_conn),
+ query_map: HashMap::new(),
+ expired_time: None,
+ }
+ } else {
+ bail!("state transition fail!");
+ }
+ }
+ // The rest should fail.
+ _ => panic!("Invalid state transition to Connected state!"),
+ };
+ Ok(())
+ }
+
+ fn state_to_idle(&mut self) {
+ self.state = match self.state {
+ // Only either Connected or Error -> Idle is valid.
+ // TODO: Error -> Idle is the re-probing case, add the relevant statistic.
+ ConnectionState::Connected { .. } | ConnectionState::Error => ConnectionState::Idle,
+ // The rest should fail.
+ _ => panic!("Invalid state transition to Idle state!"),
+ }
+ }
+
+ fn state_to_error(&mut self) {
+ self.pending_queries.clear();
+ self.state = ConnectionState::Error
+ }
+
+ fn is_reprobe_required(&self) -> bool {
+ matches!(self.state, ConnectionState::Error)
+ }
+
+ fn has_not_handled_queries(&self) -> bool {
+ match &self.state {
+ ConnectionState::Connecting { .. } | ConnectionState::Idle => {
+ !self.pending_queries.is_empty()
+ }
+ ConnectionState::Connected { query_map, .. } => {
+ !query_map.is_empty() || !self.pending_queries.is_empty()
+ }
+ _ => false,
+ }
+ }
+
+ fn handle_if_connection_expired(&mut self) {
+ let expired_time = match &mut self.state {
+ ConnectionState::Connecting { expired_time, .. } => expired_time,
+ ConnectionState::Connected { expired_time, .. } => expired_time,
+ // ignore
+ _ => return,
+ };
+
+ if let Some(expired_time) = expired_time {
+ if let Some(elapsed) = expired_time.elapsed() {
+ warn!(
+ "Change the state to Idle due to connection timeout, {:?}, {}",
+ elapsed, self.info.net_id
+ );
+ self.state_to_idle();
+ }
+ }
+ }
+
+ async fn probe(&mut self, t: Duration) -> Result<()> {
+ match timeout(t, async {
+ self.try_connect().await?;
+ info!("probe start for {}", self.info.net_id);
+ if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, expired_time, .. } =
+ &mut self.state
+ {
+ let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+ let req = match make_probe_query() {
+ Ok(q) => match make_dns_request(&q, &self.info.url) {
+ Ok(req) => req,
+ Err(e) => bail!(e),
+ },
+ Err(e) => bail!(e),
+ };
+ // Send the probe query.
+ let req_id = h3_conn.send_request(quic_conn, &req, true /*fin*/)?;
+ loop {
+ flush_tx(quic_conn, udp_sk).await?;
+ recv_rx(quic_conn, udp_sk, expired_time).await?;
+ loop {
+ match recv_h3(quic_conn, h3_conn) {
+ Ok((stream_id, H3Result::Finished)) => {
+ if stream_id == req_id {
+ return Ok(());
+ }
+ }
+ // TODO: Verify the answer
+ Ok((_stream_id, H3Result::Data { .. })) => {}
+ Ok((_stream_id, H3Result::Ignore)) => {}
+ Err(_) => break,
+ }
+ }
+ }
+ } else {
+ bail!("state error while performing probe()");
+ }
+ })
+ .await
+ {
+ Ok(v) => match v {
+ Ok(_) => Ok(()),
+ Err(e) => {
+ self.state_to_error();
+ bail!(e);
+ }
+ },
+ Err(e) => {
+ self.state_to_error();
+ bail!(e);
+ }
+ }
+ }
+
+ async fn try_connect(&mut self) -> Result<()> {
+ if matches!(self.state, ConnectionState::Connected { .. }) {
+ return Ok(());
+ }
+ self.state_to_connecting()?;
+ debug!("connecting to Network {}", self.info.net_id);
+
+ let (quic_conn, udp_sk, expired_time) = match &mut self.state {
+ ConnectionState::Connecting { quic_conn, udp_sk, expired_time, .. } => {
+ if let (Some(quic_conn), Some(udp_sk)) = (quic_conn.as_mut(), udp_sk.as_mut()) {
+ (quic_conn, udp_sk, expired_time)
+ } else {
+ bail!("unexpected error while performing connect()");
+ }
+ }
+ _ => bail!("state error while performing try_connect()"),
+ };
+
+ while !quic_conn.is_established() {
+ flush_tx(quic_conn, udp_sk).await?;
+ recv_rx(quic_conn, udp_sk, expired_time).await?;
+ }
+ self.cached_session = quic_conn.session();
+ self.state_to_connected()?;
+ info!("connected to Network {}", self.info.net_id);
+ Ok(())
+ }
+
+ async fn try_send_doh_query(
+ &mut self,
+ req: DnsRequest,
+ resp: QueryResponder,
+ expired_time: Instant,
+ ) -> Result<()> {
+ self.handle_if_connection_expired();
+ match &mut self.state {
+ ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, .. } => {
+ let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+ send_dns_query(
+ quic_conn,
+ udp_sk,
+ h3_conn,
+ query_map,
+ &mut self.pending_queries,
+ resp,
+ expired_time,
+ req,
+ )
+ .await?
+ }
+ ConnectionState::Connecting { .. } | ConnectionState::Idle => {
+ self.pending_queries.push((req, resp, expired_time))
+ }
+ ConnectionState::Error => {
+ error!(
+ "state is error while performing try_send_doh_query(), network: {}",
+ self.info.net_id
+ );
+ let _ = resp.send(Response::Error { error: QueryError::BrokenServer });
+ }
+ }
+ Ok(())
+ }
+
+ async fn process_queries(&mut self) -> Result<()> {
+ debug!("process_queries entry, Network {}", self.info.net_id);
+ self.try_connect().await?;
+ if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, expired_time } =
+ &mut self.state
+ {
+ let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?;
+ loop {
+ while !self.pending_queries.is_empty() {
+ if let Some((req, resp, exp_time)) = self.pending_queries.pop() {
+ // Ignore the expired queries.
+ if Instant::now().checked_duration_since(exp_time).is_some() {
+ warn!("Drop the obsolete query for network {}", self.info.net_id);
+ continue;
+ }
+ send_dns_query(
+ quic_conn,
+ udp_sk,
+ h3_conn,
+ query_map,
+ &mut self.pending_queries,
+ resp,
+ exp_time,
+ req,
+ )
+ .await?;
+ }
+ }
+ flush_tx(quic_conn, udp_sk).await?;
+ recv_rx(quic_conn, udp_sk, expired_time).await?;
+ loop {
+ match recv_h3(quic_conn, h3_conn) {
+ Ok((stream_id, H3Result::Data { mut data })) => {
+ if let Some((answer, _)) = query_map.get_mut(&stream_id) {
+ answer.append(&mut data);
+ } else {
+ // Should not happen
+ warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.info.net_id, stream_id);
+ }
+ }
+ Ok((stream_id, H3Result::Finished)) => {
+ if let Some((answer, resp)) = query_map.remove(&stream_id) {
+ debug!(
+ "sending answer back to resolv, Network {}, stream id: {}",
+ self.info.net_id, stream_id
+ );
+ resp.send(Response::Success { answer }).unwrap_or_else(|e| {
+ trace!(
+ "the receiver dropped {:?}, stream id: {}",
+ e,
+ stream_id
+ );
+ });
+ } else {
+ // Should not happen
+ warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.info.net_id, stream_id);
+ }
+ }
+ Ok((_stream_id, H3Result::Ignore)) => {}
+ Err(_) => break,
+ }
+ }
+ if quic_conn.is_closed() || !quic_conn.is_established() {
+ self.state_to_idle();
+ bail!("connection become idle");
+ }
+ }
+ } else {
+ self.state_to_error();
+ bail!("state error while performing process_queries(), network: {}", self.info.net_id);
+ }
+ }
+}
+
+fn recv_h3(
+ quic_conn: &mut Pin<Box<quiche::Connection>>,
+ h3_conn: &mut h3::Connection,
+) -> Result<(u64, H3Result)> {
+ match h3_conn.poll(quic_conn) {
+ // Process HTTP/3 events.
+ Ok((stream_id, quiche::h3::Event::Data)) => {
+ debug!("quiche::h3::Event::Data");
+ let mut buf = vec![0; MAX_DATAGRAM_SIZE];
+ match h3_conn.recv_body(quic_conn, stream_id, &mut buf) {
+ Ok(read) => {
+ trace!(
+ "got {} bytes of response data on stream {}: {:x?}",
+ read,
+ stream_id,
+ &buf[..read]
+ );
+ buf.truncate(read);
+ Ok((stream_id, H3Result::Data { data: buf }))
+ }
+ Err(e) => {
+ warn!("recv_h3::recv_body {:?}", e);
+ bail!(e);
+ }
+ }
+ }
+ Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => {
+ trace!(
+ "got response headers {:?} on stream id {} has_body {}",
+ list,
+ stream_id,
+ has_body
+ );
+ Ok((stream_id, H3Result::Ignore))
+ }
+ Ok((stream_id, quiche::h3::Event::Finished)) => {
+ debug!("quiche::h3::Event::Finished on stream id {}", stream_id);
+ Ok((stream_id, H3Result::Finished))
+ }
+ Ok((stream_id, quiche::h3::Event::Datagram)) => {
+ debug!("quiche::h3::Event::Datagram on stream id {}", stream_id);
+ Ok((stream_id, H3Result::Ignore))
+ }
+ // TODO: Check if it's necessary to handle GoAway event.
+ Ok((stream_id, quiche::h3::Event::GoAway)) => {
+ debug!("quiche::h3::Event::GoAway on stream id {}", stream_id);
+ Ok((stream_id, H3Result::Ignore))
+ }
+ Err(e) => {
+ debug!("recv_h3 {:?}", e);
+ bail!(e);
+ }
+ }
+}
+
+#[allow(clippy::too_many_arguments)]
+async fn send_dns_query(
+ quic_conn: &mut Pin<Box<quiche::Connection>>,
+ udp_sk: &mut UdpSocket,
+ h3_conn: &mut h3::Connection,
+ query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>,
+ pending_queries: &mut Vec<(DnsRequest, QueryResponder, Instant)>,
+ resp: QueryResponder,
+ expired_time: Instant,
+ req: DnsRequest,
+) -> Result<()> {
+ if !quic_conn.is_established() {
+ bail!("quic connection is not ready");
+ }
+ match h3_conn.send_request(quic_conn, &req, true /*fin*/) {
+ Ok(stream_id) => {
+ query_map.insert(stream_id, (Vec::new(), resp));
+ flush_tx(quic_conn, udp_sk).await?;
+ debug!("send dns query successfully stream id: {}", stream_id);
+ Ok(())
+ }
+ Err(quiche::h3::Error::StreamBlocked) => {
+ warn!("try to send query but error on StreamBlocked");
+ pending_queries.push((req, resp, expired_time));
+ Ok(())
+ }
+ Err(e) => {
+ resp.send(Response::Error { error: QueryError::ConnectionError }).ok();
+ bail!(e);
+ }
+ }
+}
+
+async fn recv_rx(
+ quic_conn: &mut Pin<Box<quiche::Connection>>,
+ udp_sk: &mut UdpSocket,
+ expired_time: &mut Option<BootTime>,
+) -> Result<()> {
+ // TODO: Evaluate if we could make the buffer smaller.
+ let mut buf = [0; 65535];
+ let quic_idle_timeout_ms = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS);
+ let ts = quic_conn.timeout().unwrap_or(quic_idle_timeout_ms);
+
+ if let Some(next_expired) = BootTime::now().checked_add(quic_idle_timeout_ms) {
+ expired_time.replace(next_expired);
+ } else {
+ expired_time.take();
+ }
+ debug!("recv_rx entry next timeout {:?} {:?}", ts, expired_time);
+ match timeout(ts, udp_sk.recv_from(&mut buf)).await {
+ Ok(v) => match v {
+ Ok((size, from)) => {
+ let recv_info = quiche::RecvInfo { from };
+ let processed = match quic_conn.recv(&mut buf[..size], recv_info) {
+ Ok(l) => l,
+ Err(e) => {
+ debug!("recv_rx error {:?}", e);
+ bail!("quic recv failed: {:?}", e);
+ }
+ };
+ debug!("processed {} bytes", processed);
+ Ok(())
+ }
+ Err(e) => bail!("socket recv failed: {:?}", e),
+ },
+ Err(_) => {
+ warn!("timeout did not receive value within {:?}", ts);
+ quic_conn.on_timeout();
+ Ok(())
+ }
+ }
+}
+
+async fn flush_tx(
+ quic_conn: &mut Pin<Box<quiche::Connection>>,
+ udp_sk: &mut UdpSocket,
+) -> Result<()> {
+ let mut out = [0; MAX_DATAGRAM_SIZE];
+ loop {
+ let (write, _) = match quic_conn.send(&mut out) {
+ Ok(v) => v,
+ Err(quiche::Error::Done) => {
+ debug!("done writing");
+ break;
+ }
+ Err(e) => {
+ quic_conn.close(false, 0x1, b"fail").ok();
+ bail!(e);
+ }
+ };
+ udp_sk.send(&out[..write]).await?;
+ debug!("written {}", write);
+ }
+ Ok(())
+}
+
+fn report_private_dns_validation(
+ info: &ServerInfo,
+ state: &ConnectionState,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+) {
+ let (ip_addr, domain) = match (
+ CString::new(info.peer_addr.ip().to_string()),
+ CString::new(info.domain.clone().unwrap_or_default()),
+ ) {
+ (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
+ _ => {
+ error!("report_private_dns_validation bad input");
+ return;
+ }
+ };
+ let netd_id = info.net_id;
+ let success = matches!(state, ConnectionState::Connected { .. });
+ runtime
+ .spawn_blocking(move || validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()));
+}
+
+fn handle_probe_result(
+ result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>),
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+) {
+ let (info, doh_conn) = match result {
+ (info, Ok(doh_conn)) => {
+ info!("probing_task success on net_id: {}", info.net_id);
+ (info, doh_conn)
+ }
+ (info, Err((e, doh_conn))) => {
+ error!("probe failed on network {}, {:?}", e, info.net_id);
+ (info, doh_conn)
+ // TODO: Retry probe?
+ }
+ };
+ // If the network is removed or the server is replaced before probing,
+ // ignore the probe result.
+ match doh_conn_map.get(&info.net_id) {
+ Some((server_info, _)) => {
+ if *server_info != info {
+ warn!(
+ "The previous configuration for network {} was replaced before probe finished",
+ info.net_id
+ );
+ return;
+ }
+ }
+ _ => {
+ warn!("network {} was removed before probe finished", info.net_id);
+ return;
+ }
+ }
+ report_private_dns_validation(&info, &doh_conn.state, runtime, validation_fn);
+ doh_conn_map.insert(info.net_id, (info, Some(doh_conn)));
+}
+
+async fn probe_task(
+ info: ServerInfo,
+ mut doh: DohConnection,
+ t: Duration,
+) -> (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>) {
+ match doh.probe(t).await {
+ Ok(_) => (info, Ok(doh)),
+ Err(e) => (info, Err((anyhow!(e), doh))),
+ }
+}
+
+fn make_connection_if_needed(
+ info: &ServerInfo,
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ shared_config: Arc<Mutex<QuicheConfigCache>>,
+ tag_socket_fn: TagSocketCallback,
+) -> Result<Option<DohConnection>> {
+ // Check if connection exists.
+ match doh_conn_map.get(&info.net_id) {
+ // The connection exists but has failed. Re-probe.
+ Some((server_info, Some(doh))) if *server_info == *info && doh.is_reprobe_required() => {
+ let (_, doh) = doh_conn_map
+ .insert(info.net_id, (info.clone(), None))
+ .ok_or_else(|| anyhow!("unexpected error, missing connection"))?;
+ return Ok(doh);
+ }
+ // The connection exists or the connection is under probing, ignore.
+ Some((server_info, _)) if *server_info == *info => return Ok(None),
+ // TODO: change the inner connection instead of removing?
+ _ => doh_conn_map.remove(&info.net_id),
+ };
+ let doh = DohConnection::new(info, shared_config, tag_socket_fn)?;
+ doh_conn_map.insert(info.net_id, (info.clone(), None));
+ Ok(Some(doh))
+}
+
+struct QuicheConfigCache {
+ cert_path: Option<String>,
+ config: Option<quiche::Config>,
+}
+
+impl QuicheConfigCache {
+ fn get(&mut self, cert_path: &Option<String>) -> Result<Option<&mut quiche::Config>> {
+ // No config is cached or the cached config isn't matched with the input cert_path
+ // Create it with the input cert_path.
+ if self.config.is_none() || self.cert_path != *cert_path {
+ self.config = Some(create_quiche_config(cert_path.as_deref())?);
+ self.cert_path = cert_path.clone();
+ }
+ return Ok(self.config.as_mut());
+ }
+}
+
+async fn handle_query_cmd(
+ net_id: u32,
+ base64_query: Base64Query,
+ expired_time: Instant,
+ resp: QueryResponder,
+ doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+) {
+ if let Some((info, quic_conn)) = doh_conn_map.get_mut(&net_id) {
+ match (&info.domain, quic_conn) {
+ // Connection is not ready, strict mode
+ (Some(_), None) => {
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+ // Connection is not ready, Opportunistic mode
+ (None, None) => {
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+ // Connection is ready
+ (_, Some(quic_conn)) => {
+ if let Ok(req) = make_dns_request(&base64_query, &info.url) {
+ let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await;
+ } else {
+ let _ = resp.send(Response::Error { error: QueryError::Unexpected });
+ }
+ }
+ }
+ } else {
+ error!("No connection is associated with the given net id {}", net_id);
+ let _ = resp.send(Response::Error { error: QueryError::ServerNotReady });
+ }
+}
+fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConnection>)>) -> bool {
+ if doh_conn_map.is_empty() {
+ return false;
+ }
+ for (_, doh_conn) in doh_conn_map.values() {
+ if let Some(doh_conn) = doh_conn {
+ if doh_conn.has_not_handled_queries() {
+ return true;
+ }
+ }
+ }
+ false
+}
+
+async fn doh_handler(
+ mut cmd_rx: CmdReceiver,
+ runtime: Arc<Runtime>,
+ validation_fn: ValidationCallback,
+ tag_socket_fn: TagSocketCallback,
+) -> Result<()> {
+ info!("doh_dispatcher entry");
+ let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None }));
+
+ // Currently, only support 1 server per network.
+ let mut doh_conn_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
+ let mut probe_futures = FuturesUnordered::new();
+ loop {
+ tokio::select! {
+ _ = async {
+ let mut futures = vec![];
+ for (_, doh_conn) in doh_conn_map.values_mut() {
+ if let Some(doh_conn) = doh_conn {
+ futures.push(doh_conn.process_queries());
+ }
+ }
+ join_all(futures).await
+ }, if need_process_queries(&doh_conn_map) => {},
+ Some(result) = probe_futures.next() => {
+ let runtime_clone = runtime.clone();
+ handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn);
+ info!("probe_futures remaining size: {}", probe_futures.len());
+ },
+ Some(cmd) = cmd_rx.recv() => {
+ trace!("recv {:?}", cmd);
+ match cmd {
+ DohCommand::Probe { info, timeout: t } => {
+ match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone(), tag_socket_fn) {
+ Ok(Some(doh)) => {
+ // Create a new async task associated to the DoH connection.
+ probe_futures.push(probe_task(info, doh, t));
+ debug!("probe_futures size: {}", probe_futures.len());
+ }
+ Ok(None) => {
+ // No further probe is needed.
+ warn!("connection for network {} already exists", info.net_id);
+ // TODO: Report the status again?
+ }
+ Err(e) => {
+ error!("create connection for network {} error {:?}", info.net_id, e);
+ report_private_dns_validation(&info, &ConnectionState::Error, runtime.clone(), validation_fn);
+ }
+ }
+ },
+ DohCommand::Query { net_id, base64_query, expired_time, resp } => {
+ handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map).await;
+ },
+ DohCommand::Clear { net_id } => {
+ doh_conn_map.remove(&net_id);
+ info!("Doh Clear server for netid: {}", net_id);
+ },
+ DohCommand::Exit => return Ok(()),
+ }
+ }
+ }
+ }
+}
+
+fn make_dns_request(base64_query: &str, url: &url::Url) -> Result<DnsRequest> {
+ let mut path = String::from(url.path());
+ path.push_str("?dns=");
+ path.push_str(base64_query);
+ let req = vec![
+ quiche::h3::Header::new(b":method", b"GET"),
+ quiche::h3::Header::new(b":scheme", b"https"),
+ quiche::h3::Header::new(
+ b":authority",
+ url.host_str().ok_or_else(|| anyhow!("failed to get host"))?.as_bytes(),
+ ),
+ quiche::h3::Header::new(b":path", path.as_bytes()),
+ quiche::h3::Header::new(b"user-agent", b"quiche"),
+ quiche::h3::Header::new(b"accept", b"application/dns-message"),
+ // TODO: is content-length required?
+ ];
+
+ Ok(req)
+}
+
+fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result<std::net::UdpSocket> {
+ let bind_addr = match peer_addr {
+ std::net::SocketAddr::V4(_) => "0.0.0.0:0",
+ std::net::SocketAddr::V6(_) => "[::]:0",
+ };
+ let udp_sk = std::net::UdpSocket::bind(bind_addr)?;
+ udp_sk.set_nonblocking(true)?;
+ if mark_socket(udp_sk.as_raw_fd(), mark).is_err() {
+ warn!("Mark socket failed, is it a test?");
+ }
+ udp_sk.connect(peer_addr)?;
+
+ trace!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?);
+ Ok(udp_sk)
+}
+
+fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> {
+ let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?;
+ config.set_application_protos(h3::APPLICATION_PROTOCOL)?;
+ match cert_path {
+ Some(path) => {
+ config.verify_peer(true);
+ config.load_verify_locations_from_directory(path)?;
+ }
+ None => config.verify_peer(false),
+ }
+
+ // Some of these configs are necessary, or the server can't respond the HTTP/3 request.
+ config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS);
+ config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE);
+ config.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE);
+ config.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH);
+ config.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH);
+ config.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH);
+ config.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE);
+ config.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE);
+ config.set_disable_active_migration(true);
+ Ok(config)
+}
+
+fn mark_socket(fd: RawFd, mark: u32) -> Result<()> {
+ // libc::setsockopt is a wrapper function calling into bionic setsockopt.
+ // Both fd and mark are valid, which makes the function call mostly safe.
+ if unsafe {
+ libc::setsockopt(
+ fd,
+ libc::SOL_SOCKET,
+ libc::SO_MARK,
+ &mark as *const _ as *const libc::c_void,
+ std::mem::size_of::<u32>() as libc::socklen_t,
+ )
+ } == 0
+ {
+ Ok(())
+ } else {
+ Err(anyhow::Error::new(std::io::Error::last_os_error()))
+ }
+}
+
+#[rustfmt::skip]
+fn make_probe_query() -> Result<String> {
+ let mut rnd = [0; 8];
+ ring::rand::SystemRandom::new().fill(&mut rnd).context("failed to generate probe rnd")?;
+ let c = |byte| CHARSET[(byte as usize) % CHARSET.len()];
+ let query = vec![
+ rnd[6], rnd[7], // [0-1] query ID
+ 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD).
+ 0, 1, // [4-5] QDCOUNT (number of queries)
+ 0, 0, // [6-7] ANCOUNT (number of answers)
+ 0, 0, // [8-9] NSCOUNT (number of name server records)
+ 0, 0, // [10-11] ARCOUNT (number of additional records)
+ 19, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), b'-', b'd', b'n',
+ b's', b'o', b'h', b't', b't', b'p', b's', b'-', b'd', b's',
+ 6, b'm', b'e', b't', b'r', b'i', b'c', 7, b'g', b's',
+ b't', b'a', b't', b'i', b'c', 3, b'c', b'o', b'm',
+ 0, // null terminator of FQDN (root TLD)
+ 0, NS_T_AAAA, // QTYPE
+ 0, NS_C_IN // QCLASS
+ ];
+ Ok(base64::encode_config(query, base64::URL_SAFE_NO_PAD))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use quiche::h3::NameValue;
+ use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
+
+ const TEST_NET_ID: u32 = 50;
+ const PROBE_QUERY_SIZE: usize = 56;
+ const H3_DNS_REQUEST_HEADER_SIZE: usize = 6;
+ const TEST_MARK: u32 = 0xD0033;
+ const LOOPBACK_ADDR: &str = "127.0.0.1:443";
+ const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";
+
+ // TODO: Make some tests for DohConnection and QuicheConfigCache.
+
+ fn make_testing_variables() -> (
+ ServerInfo,
+ HashMap<u32, (ServerInfo, Option<DohConnection>)>,
+ Arc<Mutex<QuicheConfigCache>>,
+ Arc<Runtime>,
+ ) {
+ let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new();
+ let info = ServerInfo {
+ net_id: TEST_NET_ID,
+ url: Url::parse(LOCALHOST_URL).unwrap(),
+ peer_addr: LOOPBACK_ADDR.parse().unwrap(),
+ domain: None,
+ sk_mark: 0,
+ cert_path: None,
+ };
+ let config_cache =
+ Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None }));
+
+ let rt = Arc::new(
+ Builder::new_current_thread()
+ .thread_name("test-runtime")
+ .enable_all()
+ .build()
+ .expect("Failed to create testing tokio runtime"),
+ );
+ (info, test_map, config_cache, rt)
+ }
+
+ extern "C" fn tag_socket_cb(sock: int32_t) {
+ assert!(sock >= 0);
+ }
+
+ #[test]
+ fn make_connection_if_needed() {
+ let (info, mut test_map, config, rt) = make_testing_variables();
+ rt.block_on(async {
+ // Expect to make a new connection.
+ let mut doh = super::make_connection_if_needed(
+ &info,
+ &mut test_map,
+ config.clone(),
+ tag_socket_cb,
+ )
+ .unwrap()
+ .unwrap();
+ assert_eq!(doh.info.net_id, info.net_id);
+ assert!(matches!(doh.state, ConnectionState::Idle));
+ doh.state = ConnectionState::Error;
+ test_map.insert(info.net_id, (info.clone(), Some(doh)));
+ // Expect that we will get a connection with fail status that we added to the map before.
+ let mut doh = super::make_connection_if_needed(
+ &info,
+ &mut test_map,
+ config.clone(),
+ tag_socket_cb,
+ )
+ .unwrap()
+ .unwrap();
+ assert_eq!(doh.info.net_id, info.net_id);
+ assert!(matches!(doh.state, ConnectionState::Error));
+ doh.state = make_dummy_connected_state();
+ test_map.insert(info.net_id, (info.clone(), Some(doh)));
+ // Expect that we will get None because the map contains a connection with ready status.
+ assert!(super::make_connection_if_needed(
+ &info,
+ &mut test_map,
+ config.clone(),
+ tag_socket_cb
+ )
+ .unwrap()
+ .is_none());
+ });
+ }
+
+ #[test]
+ fn handle_query_cmd() {
+ let (info, mut test_map, config, rt) = make_testing_variables();
+ let t = Duration::from_millis(100);
+
+ rt.block_on(async {
+ // Test no available server cases.
+ let (resp_tx, resp_rx) = oneshot::channel();
+ let query = super::make_probe_query().unwrap();
+ super::handle_query_cmd(
+ info.net_id,
+ query.clone(),
+ Instant::now().checked_add(t).unwrap(),
+ resp_tx,
+ &mut test_map,
+ )
+ .await;
+ assert_eq!(
+ timeout(t, resp_rx).await.unwrap().unwrap(),
+ Response::Error { error: QueryError::ServerNotReady }
+ );
+
+ let (resp_tx, resp_rx) = oneshot::channel();
+ test_map.insert(info.net_id, (info.clone(), None));
+ super::handle_query_cmd(
+ info.net_id,
+ query.clone(),
+ Instant::now().checked_add(t).unwrap(),
+ resp_tx,
+ &mut test_map,
+ )
+ .await;
+ assert_eq!(
+ timeout(t, resp_rx).await.unwrap().unwrap(),
+ Response::Error { error: QueryError::ServerNotReady }
+ );
+
+ // Test the connection broken case.
+ test_map.clear();
+ let (resp_tx, resp_rx) = oneshot::channel();
+ let mut doh = super::make_connection_if_needed(
+ &info,
+ &mut test_map,
+ config.clone(),
+ tag_socket_cb,
+ )
+ .unwrap()
+ .unwrap();
+ doh.state = ConnectionState::Error;
+ test_map.insert(info.net_id, (info.clone(), Some(doh)));
+ super::handle_query_cmd(
+ info.net_id,
+ query.clone(),
+ Instant::now().checked_add(t).unwrap(),
+ resp_tx,
+ &mut test_map,
+ )
+ .await;
+ assert_eq!(
+ timeout(t, resp_rx).await.unwrap().unwrap(),
+ Response::Error { error: QueryError::BrokenServer }
+ );
+ });
+ }
+
+ extern "C" fn success_cb(
+ net_id: uint32_t,
+ success: bool,
+ ip_addr: *const c_char,
+ host: *const c_char,
+ ) {
+ assert!(success);
+ unsafe {
+ assert_validation_info(net_id, ip_addr, host);
+ }
+ }
+
+ extern "C" fn fail_cb(
+ net_id: uint32_t,
+ success: bool,
+ ip_addr: *const c_char,
+ host: *const c_char,
+ ) {
+ assert!(!success);
+ unsafe {
+ assert_validation_info(net_id, ip_addr, host);
+ }
+ }
+
+ // # Safety
+ // `ip_addr`, `host` are null terminated strings
+ unsafe fn assert_validation_info(
+ net_id: uint32_t,
+ ip_addr: *const c_char,
+ host: *const c_char,
+ ) {
+ assert_eq!(net_id, TEST_NET_ID);
+ let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
+ let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
+ assert_eq!(ip_addr, expected_addr.ip().to_string());
+ let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
+ assert_eq!(host, "");
+ }
+
+ fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) {
+ let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
+ let udp_sk = UdpSocket::from_std(sk).unwrap();
+ let mut scid = [0; quiche::MAX_CONN_ID_LEN];
+ ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid").unwrap();
+ let connid = quiche::ConnectionId::from_ref(&scid);
+ let mut config = super::create_quiche_config(None).unwrap();
+ let quic_conn =
+ quiche::connect(None, &connid, LOOPBACK_ADDR.parse().unwrap(), &mut config).unwrap();
+ (quic_conn, udp_sk)
+ }
+
+ fn make_dummy_connected_state() -> super::ConnectionState {
+ let (quic_conn, udp_sk) = make_testing_connection_variables();
+ ConnectionState::Connected {
+ quic_conn,
+ udp_sk,
+ h3_conn: None,
+ query_map: HashMap::new(),
+ expired_time: None,
+ }
+ }
+
+ fn make_dummy_connecting_state() -> super::ConnectionState {
+ let (quic_conn, udp_sk) = make_testing_connection_variables();
+ ConnectionState::Connecting {
+ quic_conn: Some(quic_conn),
+ udp_sk: Some(udp_sk),
+ expired_time: None,
+ }
+ }
+
+ #[test]
+ fn report_private_dns_validation() {
+ let info = ServerInfo {
+ net_id: TEST_NET_ID,
+ url: Url::parse(LOCALHOST_URL).unwrap(),
+ peer_addr: LOOPBACK_ADDR.parse().unwrap(),
+ domain: None,
+ sk_mark: 0,
+ cert_path: None,
+ };
+ let rt = Arc::new(
+ Builder::new_current_thread()
+ .thread_name("test-runtime")
+ .enable_io()
+ .build()
+ .expect("Failed to create testing tokio runtime"),
+ );
+ let default_panic = std::panic::take_hook();
+ // Exit the test if the worker inside tokio runtime panicked.
+ std::panic::set_hook(Box::new(move |info| {
+ default_panic(info);
+ std::process::exit(1);
+ }));
+ rt.block_on(async {
+ super::report_private_dns_validation(
+ &info,
+ &make_dummy_connected_state(),
+ rt.clone(),
+ success_cb,
+ );
+ super::report_private_dns_validation(
+ &info,
+ &ConnectionState::Error,
+ rt.clone(),
+ fail_cb,
+ );
+ super::report_private_dns_validation(
+ &info,
+ &make_dummy_connecting_state(),
+ rt.clone(),
+ fail_cb,
+ );
+ super::report_private_dns_validation(
+ &info,
+ &ConnectionState::Idle,
+ rt.clone(),
+ fail_cb,
+ );
+ });
+ }
+
+ #[test]
+ fn make_probe_query_and_request() {
+ let probe_query = super::make_probe_query().unwrap();
+ let url = Url::parse(LOCALHOST_URL).unwrap();
+ let request = make_dns_request(&probe_query, &url).unwrap();
+ // Verify H3 DNS request.
+ assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE);
+ assert_eq!(request[0].name(), b":method");
+ assert_eq!(request[0].value(), b"GET");
+ assert_eq!(request[1].name(), b":scheme");
+ assert_eq!(request[1].value(), b"https");
+ assert_eq!(request[2].name(), b":authority");
+ assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes());
+ assert_eq!(request[3].name(), b":path");
+ let mut path = String::from(url.path());
+ path.push_str("?dns=");
+ path.push_str(&probe_query);
+ assert_eq!(request[3].value(), path.as_bytes());
+ assert_eq!(request[5].name(), b"accept");
+ assert_eq!(request[5].value(), b"application/dns-message");
+
+ // Verify DNS probe packet.
+ let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap();
+ assert_eq!(bytes.len(), PROBE_QUERY_SIZE);
+ // TODO: Parse the result to ensure it's a valid DNS packet.
+ }
+
+ #[test]
+ fn create_quiche_config() {
+ assert!(
+ super::create_quiche_config(None).is_ok(),
+ "quiche config without cert creating failed"
+ );
+ assert!(
+ super::create_quiche_config(Some("data/local/tmp/")).is_ok(),
+ "quiche config with cert creating failed"
+ );
+ }
+
+ #[test]
+ fn make_doh_udp_socket() {
+ // Make a socket connecting to loopback with a test mark.
+ let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap();
+ // Check if the socket is connected to loopback.
+ assert_eq!(
+ sk.peer_addr().unwrap(),
+ SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT))
+ );
+
+ // Check if the socket mark is correct.
+ let fd: RawFd = sk.as_raw_fd();
+
+ let mut mark: u32 = 50;
+ let mut size = std::mem::size_of::<u32>() as libc::socklen_t;
+ unsafe {
+ // Safety: It's fine since the fd belongs to this test.
+ assert_eq!(
+ libc::getsockopt(
+ fd,
+ libc::SOL_SOCKET,
+ libc::SO_MARK,
+ &mut mark as *mut _ as *mut libc::c_void,
+ &mut size as *mut libc::socklen_t,
+ ),
+ 0
+ );
+ }
+ assert_eq!(mark, TEST_MARK);
+
+ // Check if the socket is non-blocking.
+ unsafe {
+ // Safety: It's fine since the fd belongs to this test.
+ assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK);
+ }
+ }
+}
diff --git a/doh/ffi.rs b/doh/ffi.rs
new file mode 100644
index 00000000..1202337d
--- /dev/null
+++ b/doh/ffi.rs
@@ -0,0 +1,249 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+//! C API for the DoH backend for the Android DnsResolver module.
+
+use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
+use log::error;
+use std::net::{IpAddr, SocketAddr};
+use std::str::FromStr;
+use std::{ptr, slice};
+use tokio::runtime::Runtime;
+use tokio::sync::oneshot;
+use tokio::task;
+use tokio::time::{timeout, Duration, Instant};
+
+use super::{
+ DohCommand, DohDispatcher, Response, ServerInfo, TagSocketCallback, ValidationCallback,
+ DOH_PORT,
+};
+
+const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
+
+/// The return code of doh_query means that there is no answer.
+pub const RESULT_INTERNAL_ERROR: ssize_t = -1;
+/// The return code of doh_query means that query can't be sent.
+pub const RESULT_CAN_NOT_SEND: ssize_t = -2;
+/// The return code of doh_query to indicate that the query timed out.
+pub const RESULT_TIMEOUT: ssize_t = -255;
+/// The error log level.
+pub const LOG_LEVEL_ERROR: u32 = 0;
+/// The warning log level.
+pub const LOG_LEVEL_WARN: u32 = 1;
+/// The info log level.
+pub const LOG_LEVEL_INFO: u32 = 2;
+/// The debug log level.
+pub const LOG_LEVEL_DEBUG: u32 = 3;
+/// The trace log level.
+pub const LOG_LEVEL_TRACE: u32 = 4;
+
+/// Performs static initialization for android logger.
+#[no_mangle]
+pub extern "C" fn doh_init_logger(level: u32) {
+ let level = match level {
+ LOG_LEVEL_WARN => log::Level::Warn,
+ LOG_LEVEL_DEBUG => log::Level::Debug,
+ _ => log::Level::Error,
+ };
+ android_logger::init_once(android_logger::Config::default().with_min_level(level));
+}
+
+/// Set the log level.
+#[no_mangle]
+pub extern "C" fn doh_set_log_level(level: u32) {
+ let level = match level {
+ LOG_LEVEL_ERROR => log::LevelFilter::Error,
+ LOG_LEVEL_WARN => log::LevelFilter::Warn,
+ LOG_LEVEL_INFO => log::LevelFilter::Info,
+ LOG_LEVEL_DEBUG => log::LevelFilter::Debug,
+ LOG_LEVEL_TRACE => log::LevelFilter::Trace,
+ _ => log::LevelFilter::Off,
+ };
+ log::set_max_level(level);
+}
+
+/// Performs the initialization for the DoH engine.
+/// Creates and returns a DoH engine instance.
+#[no_mangle]
+pub extern "C" fn doh_dispatcher_new(
+ validation_fn: ValidationCallback,
+ tag_socket_fn: TagSocketCallback,
+) -> *mut DohDispatcher {
+ match DohDispatcher::new(validation_fn, tag_socket_fn) {
+ Ok(c) => Box::into_raw(c),
+ Err(e) => {
+ error!("doh_dispatcher_new: failed: {:?}", e);
+ ptr::null_mut()
+ }
+ }
+}
+
+/// Deletes a DoH engine created by doh_dispatcher_new().
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+#[no_mangle]
+pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
+ Box::from_raw(doh).exit_handler()
+}
+
+/// Probes and stores the DoH server with the given configurations.
+/// Use the negative errno-style codes as the return value to represent the result.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
+#[no_mangle]
+pub unsafe extern "C" fn doh_net_new(
+ doh: &mut DohDispatcher,
+ net_id: uint32_t,
+ url: *const c_char,
+ domain: *const c_char,
+ ip_addr: *const c_char,
+ sk_mark: libc::uint32_t,
+ cert_path: *const c_char,
+ timeout_ms: libc::uint64_t,
+) -> int32_t {
+ let (url, domain, ip_addr, cert_path) = match (
+ std::ffi::CStr::from_ptr(url).to_str(),
+ std::ffi::CStr::from_ptr(domain).to_str(),
+ std::ffi::CStr::from_ptr(ip_addr).to_str(),
+ std::ffi::CStr::from_ptr(cert_path).to_str(),
+ ) {
+ (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
+ if domain.is_empty() {
+ (url, None, ip_addr.to_string(), None)
+ } else if !cert_path.is_empty() {
+ (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
+ } else {
+ (
+ url,
+ Some(domain.to_string()),
+ ip_addr.to_string(),
+ Some(SYSTEM_CERT_PATH.to_string()),
+ )
+ }
+ }
+ _ => {
+ error!("bad input"); // Should not happen
+ return -libc::EINVAL;
+ }
+ };
+
+ let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) {
+ (Ok(url), Ok(ip_addr)) => (url, ip_addr),
+ _ => {
+ error!("bad ip or url"); // Should not happen
+ return -libc::EINVAL;
+ }
+ };
+ let cmd = DohCommand::Probe {
+ info: ServerInfo {
+ net_id,
+ url,
+ peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
+ domain,
+ sk_mark,
+ cert_path,
+ },
+ timeout: Duration::from_millis(timeout_ms),
+ };
+ if let Err(e) = doh.send_cmd(cmd) {
+ error!("Failed to send the probe: {:?}", e);
+ return -libc::EPIPE;
+ }
+ 0
+}
+
+/// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
+/// The return code should be either one of the public constant RESULT_* to indicate the error or
+/// the size of the answer.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+/// `dns_query` must point to a buffer at least `dns_query_len` in size.
+/// `response` must point to a buffer at least `response_len` in size.
+#[no_mangle]
+pub unsafe extern "C" fn doh_query(
+ doh: &mut DohDispatcher,
+ net_id: uint32_t,
+ dns_query: *mut u8,
+ dns_query_len: size_t,
+ response: *mut u8,
+ response_len: size_t,
+ timeout_ms: uint64_t,
+) -> ssize_t {
+ let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
+
+ let (resp_tx, resp_rx) = oneshot::channel();
+ let t = Duration::from_millis(timeout_ms);
+ if let Some(expired_time) = Instant::now().checked_add(t) {
+ let cmd = DohCommand::Query {
+ net_id,
+ base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
+ expired_time,
+ resp: resp_tx,
+ };
+
+ if let Err(e) = doh.send_cmd(cmd) {
+ error!("Failed to send the query: {:?}", e);
+ return RESULT_CAN_NOT_SEND;
+ }
+ } else {
+ error!("Bad timeout parameter: {}", timeout_ms);
+ return RESULT_CAN_NOT_SEND;
+ }
+
+ if let Ok(rt) = Runtime::new() {
+ let local = task::LocalSet::new();
+ match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
+ Ok(v) => match v {
+ Ok(v) => match v {
+ Response::Success { answer } => {
+ if answer.len() > response_len || answer.len() > isize::MAX as usize {
+ return RESULT_INTERNAL_ERROR;
+ }
+ let response = slice::from_raw_parts_mut(response, answer.len());
+ response.copy_from_slice(&answer);
+ answer.len() as ssize_t
+ }
+ _ => RESULT_CAN_NOT_SEND,
+ },
+ Err(e) => {
+ error!("no result {}", e);
+ RESULT_CAN_NOT_SEND
+ }
+ },
+ Err(e) => {
+ error!("timeout: {}", e);
+ RESULT_TIMEOUT
+ }
+ }
+ } else {
+ RESULT_CAN_NOT_SEND
+ }
+}
+
+/// Clears the DoH servers associated with the given |netid|.
+/// # Safety
+/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
+/// and not yet deleted by `doh_dispatcher_delete()`.
+#[no_mangle]
+pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) {
+ if let Err(e) = doh.send_cmd(DohCommand::Clear { net_id }) {
+ error!("Failed to send the query: {:?}", e);
+ }
+}