diff options
author | Matthew Maurer <mmaurer@google.com> | 2021-09-16 16:16:58 -0700 |
---|---|---|
committer | Matthew Maurer <mmaurer@google.com> | 2021-10-04 20:53:20 -0700 |
commit | aa0dac6f1b9b2a09bc3c39688521f9dc1ec8ac1f (patch) | |
tree | 6e37978dac4c343cf63cab4ab5c6ffab082aaa17 /doh | |
parent | ae5fe72c344017e4fecfef661ba4391dfde89e97 (diff) | |
download | DnsResolver-aa0dac6f1b9b2a09bc3c39688521f9dc1ec8ac1f.tar.gz |
DoH: Split out FFI logic to separate module
Bug: 202081046
Change-Id: Ie9093ab14a4eb4fc17381ad81f362dd209038d4d
Diffstat (limited to 'doh')
-rw-r--r-- | doh/doh.rs | 1366 | ||||
-rw-r--r-- | doh/ffi.rs | 249 |
2 files changed, 1615 insertions, 0 deletions
diff --git a/doh/doh.rs b/doh/doh.rs new file mode 100644 index 00000000..95daf648 --- /dev/null +++ b/doh/doh.rs @@ -0,0 +1,1366 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! DoH backend for the Android DnsResolver module. + +use anyhow::{anyhow, bail, Context, Result}; +use futures::future::join_all; +use futures::stream::FuturesUnordered; +use futures::StreamExt; +use libc::{c_char, int32_t, uint32_t}; +use log::{debug, error, info, trace, warn}; +use quiche::h3; +use ring::rand::SecureRandom; +use std::collections::HashMap; +use std::ffi::CString; +use std::net::SocketAddr; +use std::ops::Deref; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use tokio::net::UdpSocket; +use tokio::runtime::{Builder, Runtime}; +use tokio::sync::{mpsc, oneshot}; +use tokio::task; +use tokio::time::{timeout, Duration, Instant}; +use url::Url; + +mod ffi; + +const MAX_BUFFERED_CMD_SIZE: usize = 400; +const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000; +const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000; +const MAX_CONCURRENT_STREAM_SIZE: u64 = 100; +const MAX_DATAGRAM_SIZE: usize = 1350; +const DOH_PORT: u16 = 443; +const QUICHE_IDLE_TIMEOUT_MS: u64 = 180000; +const NS_T_AAAA: u8 = 28; +const NS_C_IN: u8 = 1; +// Used to randomly generate query prefix and query id. +const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ\ + abcdefghijklmnopqrstuvwxyz\ + 0123456789"; + +type SCID = [u8; quiche::MAX_CONN_ID_LEN]; +type Base64Query = String; +type CmdSender = mpsc::Sender<DohCommand>; +type CmdReceiver = mpsc::Receiver<DohCommand>; +type QueryResponder = oneshot::Sender<Response>; +type DnsRequest = Vec<quiche::h3::Header>; +type ValidationCallback = + extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); +type TagSocketCallback = extern "C" fn(sock: int32_t); + +#[derive(Eq, PartialEq, Debug)] +enum QueryError { + BrokenServer, + ConnectionError, + ServerNotReady, + Unexpected, +} + +#[derive(Eq, PartialEq, Debug, Clone)] +struct ServerInfo { + net_id: u32, + url: Url, + peer_addr: SocketAddr, + domain: Option<String>, + sk_mark: u32, + cert_path: Option<String>, +} + +#[derive(Eq, PartialEq, Debug)] +enum Response { + Error { error: QueryError }, + Success { answer: Vec<u8> }, +} + +#[derive(Debug)] +enum DohCommand { + Probe { info: ServerInfo, timeout: Duration }, + Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder }, + Clear { net_id: u32 }, + Exit, +} + +#[allow(clippy::large_enum_variant)] +enum ConnectionState { + Idle, + Connecting { + quic_conn: Option<Pin<Box<quiche::Connection>>>, + udp_sk: Option<UdpSocket>, + expired_time: Option<BootTime>, + }, + Connected { + quic_conn: Pin<Box<quiche::Connection>>, + udp_sk: UdpSocket, + h3_conn: Option<h3::Connection>, + query_map: HashMap<u64, (Vec<u8>, QueryResponder)>, + expired_time: Option<BootTime>, + }, + /// Indicate that the Connection can't be used due to + /// network or unexpected reasons. + Error, +} + +enum H3Result { + Data { data: Vec<u8> }, + Finished, + Ignore, +} + +trait OptionDeref<T: Deref> { + fn as_deref(&self) -> Option<&T::Target>; +} + +impl<T: Deref> OptionDeref<T> for Option<T> { + fn as_deref(&self) -> Option<&T::Target> { + self.as_ref().map(Deref::deref) + } +} + +#[derive(Copy, Clone, Debug)] +struct BootTime { + d: Duration, +} + +impl BootTime { + fn now() -> BootTime { + unsafe { + let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 }; + if libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) != 0 { + panic!("get boot time failed: {:?}", std::io::Error::last_os_error()); + } + BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) } + } + } + + fn elapsed(&self) -> Option<Duration> { + BootTime::now().duration_since(*self) + } + + fn checked_add(&self, duration: Duration) -> Option<BootTime> { + Some(BootTime { d: self.d.checked_add(duration)? }) + } + + fn duration_since(&self, earlier: BootTime) -> Option<Duration> { + self.d.checked_sub(earlier.d) + } +} + +/// Context for a running DoH engine. +pub struct DohDispatcher { + /// Used to submit cmds to the I/O task. + cmd_sender: CmdSender, + join_handle: task::JoinHandle<Result<()>>, + runtime: Arc<Runtime>, +} + +// DoH dispatcher +impl DohDispatcher { + fn new( + validation_fn: ValidationCallback, + tag_socket_fn: TagSocketCallback, + ) -> Result<Box<DohDispatcher>> { + let (cmd_sender, cmd_receiver) = mpsc::channel::<DohCommand>(MAX_BUFFERED_CMD_SIZE); + let runtime = Arc::new( + Builder::new_multi_thread() + .worker_threads(2) + .enable_all() + .thread_name("doh-handler") + .build() + .expect("Failed to create tokio runtime"), + ); + let join_handle = + runtime.spawn(doh_handler(cmd_receiver, runtime.clone(), validation_fn, tag_socket_fn)); + Ok(Box::new(DohDispatcher { cmd_sender, join_handle, runtime })) + } + + fn send_cmd(&self, cmd: DohCommand) -> Result<()> { + self.cmd_sender.blocking_send(cmd)?; + Ok(()) + } + + fn exit_handler(&mut self) { + if self.cmd_sender.blocking_send(DohCommand::Exit).is_err() { + return; + } + let _ = self.runtime.block_on(&mut self.join_handle); + } +} + +struct DohConnection { + info: ServerInfo, + shared_config: Arc<Mutex<QuicheConfigCache>>, + scid: SCID, + state: ConnectionState, + pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, + cached_session: Option<Vec<u8>>, + tag_socket_fn: TagSocketCallback, +} + +impl DohConnection { + fn new( + info: &ServerInfo, + shared_config: Arc<Mutex<QuicheConfigCache>>, + tag_socket_fn: TagSocketCallback, + ) -> Result<DohConnection> { + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid")?; + Ok(DohConnection { + info: info.clone(), + shared_config, + scid, + state: ConnectionState::Idle, + pending_queries: Vec::new(), + cached_session: None, + tag_socket_fn, + }) + } + + fn state_to_connecting(&mut self) -> Result<()> { + self.state = match self.state { + ConnectionState::Idle => { + let udp_sk_std = make_doh_udp_socket(self.info.peer_addr, self.info.sk_mark)?; + (self.tag_socket_fn)(udp_sk_std.as_raw_fd()); + let udp_sk = UdpSocket::from_std(udp_sk_std)?; + let connid = quiche::ConnectionId::from_ref(&self.scid); + let mut cache = self.shared_config.lock().unwrap(); + let config = + cache.get(&self.info.cert_path)?.ok_or_else(|| anyhow!("no quiche config"))?; + debug!("init the connection for Network {}", self.info.net_id); + let mut quic_conn = quiche::connect( + self.info.domain.as_deref(), + &connid, + self.info.peer_addr, + config, + )?; + if let Some(session) = &self.cached_session { + if quic_conn.set_session(session).is_err() { + warn!("can't restore session for network {}", self.info.net_id); + } + } + ConnectionState::Connecting { + quic_conn: Some(quic_conn), + udp_sk: Some(udp_sk), + expired_time: None, + } + } + ConnectionState::Error => { + self.state_to_idle(); + return self.state_to_connecting(); + } + ConnectionState::Connecting { .. } => return Ok(()), + ConnectionState::Connected { .. } => { + panic!("Invalid state transition to Connecting state!") + } + }; + Ok(()) + } + + fn state_to_connected(&mut self) -> Result<()> { + self.state = match &mut self.state { + // Only Connecting -> Connected is valid. + ConnectionState::Connecting { quic_conn, udp_sk, .. } => { + if let (Some(mut quic_conn), Some(udp_sk)) = (quic_conn.take(), udp_sk.take()) { + let h3_config = h3::Config::new()?; + let h3_conn = + quiche::h3::Connection::with_transport(&mut quic_conn, &h3_config)?; + ConnectionState::Connected { + quic_conn, + udp_sk, + h3_conn: Some(h3_conn), + query_map: HashMap::new(), + expired_time: None, + } + } else { + bail!("state transition fail!"); + } + } + // The rest should fail. + _ => panic!("Invalid state transition to Connected state!"), + }; + Ok(()) + } + + fn state_to_idle(&mut self) { + self.state = match self.state { + // Only either Connected or Error -> Idle is valid. + // TODO: Error -> Idle is the re-probing case, add the relevant statistic. + ConnectionState::Connected { .. } | ConnectionState::Error => ConnectionState::Idle, + // The rest should fail. + _ => panic!("Invalid state transition to Idle state!"), + } + } + + fn state_to_error(&mut self) { + self.pending_queries.clear(); + self.state = ConnectionState::Error + } + + fn is_reprobe_required(&self) -> bool { + matches!(self.state, ConnectionState::Error) + } + + fn has_not_handled_queries(&self) -> bool { + match &self.state { + ConnectionState::Connecting { .. } | ConnectionState::Idle => { + !self.pending_queries.is_empty() + } + ConnectionState::Connected { query_map, .. } => { + !query_map.is_empty() || !self.pending_queries.is_empty() + } + _ => false, + } + } + + fn handle_if_connection_expired(&mut self) { + let expired_time = match &mut self.state { + ConnectionState::Connecting { expired_time, .. } => expired_time, + ConnectionState::Connected { expired_time, .. } => expired_time, + // ignore + _ => return, + }; + + if let Some(expired_time) = expired_time { + if let Some(elapsed) = expired_time.elapsed() { + warn!( + "Change the state to Idle due to connection timeout, {:?}, {}", + elapsed, self.info.net_id + ); + self.state_to_idle(); + } + } + } + + async fn probe(&mut self, t: Duration) -> Result<()> { + match timeout(t, async { + self.try_connect().await?; + info!("probe start for {}", self.info.net_id); + if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, expired_time, .. } = + &mut self.state + { + let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; + let req = match make_probe_query() { + Ok(q) => match make_dns_request(&q, &self.info.url) { + Ok(req) => req, + Err(e) => bail!(e), + }, + Err(e) => bail!(e), + }; + // Send the probe query. + let req_id = h3_conn.send_request(quic_conn, &req, true /*fin*/)?; + loop { + flush_tx(quic_conn, udp_sk).await?; + recv_rx(quic_conn, udp_sk, expired_time).await?; + loop { + match recv_h3(quic_conn, h3_conn) { + Ok((stream_id, H3Result::Finished)) => { + if stream_id == req_id { + return Ok(()); + } + } + // TODO: Verify the answer + Ok((_stream_id, H3Result::Data { .. })) => {} + Ok((_stream_id, H3Result::Ignore)) => {} + Err(_) => break, + } + } + } + } else { + bail!("state error while performing probe()"); + } + }) + .await + { + Ok(v) => match v { + Ok(_) => Ok(()), + Err(e) => { + self.state_to_error(); + bail!(e); + } + }, + Err(e) => { + self.state_to_error(); + bail!(e); + } + } + } + + async fn try_connect(&mut self) -> Result<()> { + if matches!(self.state, ConnectionState::Connected { .. }) { + return Ok(()); + } + self.state_to_connecting()?; + debug!("connecting to Network {}", self.info.net_id); + + let (quic_conn, udp_sk, expired_time) = match &mut self.state { + ConnectionState::Connecting { quic_conn, udp_sk, expired_time, .. } => { + if let (Some(quic_conn), Some(udp_sk)) = (quic_conn.as_mut(), udp_sk.as_mut()) { + (quic_conn, udp_sk, expired_time) + } else { + bail!("unexpected error while performing connect()"); + } + } + _ => bail!("state error while performing try_connect()"), + }; + + while !quic_conn.is_established() { + flush_tx(quic_conn, udp_sk).await?; + recv_rx(quic_conn, udp_sk, expired_time).await?; + } + self.cached_session = quic_conn.session(); + self.state_to_connected()?; + info!("connected to Network {}", self.info.net_id); + Ok(()) + } + + async fn try_send_doh_query( + &mut self, + req: DnsRequest, + resp: QueryResponder, + expired_time: Instant, + ) -> Result<()> { + self.handle_if_connection_expired(); + match &mut self.state { + ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, .. } => { + let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; + send_dns_query( + quic_conn, + udp_sk, + h3_conn, + query_map, + &mut self.pending_queries, + resp, + expired_time, + req, + ) + .await? + } + ConnectionState::Connecting { .. } | ConnectionState::Idle => { + self.pending_queries.push((req, resp, expired_time)) + } + ConnectionState::Error => { + error!( + "state is error while performing try_send_doh_query(), network: {}", + self.info.net_id + ); + let _ = resp.send(Response::Error { error: QueryError::BrokenServer }); + } + } + Ok(()) + } + + async fn process_queries(&mut self) -> Result<()> { + debug!("process_queries entry, Network {}", self.info.net_id); + self.try_connect().await?; + if let ConnectionState::Connected { quic_conn, udp_sk, h3_conn, query_map, expired_time } = + &mut self.state + { + let h3_conn = h3_conn.as_mut().ok_or_else(|| anyhow!("h3 conn isn't available"))?; + loop { + while !self.pending_queries.is_empty() { + if let Some((req, resp, exp_time)) = self.pending_queries.pop() { + // Ignore the expired queries. + if Instant::now().checked_duration_since(exp_time).is_some() { + warn!("Drop the obsolete query for network {}", self.info.net_id); + continue; + } + send_dns_query( + quic_conn, + udp_sk, + h3_conn, + query_map, + &mut self.pending_queries, + resp, + exp_time, + req, + ) + .await?; + } + } + flush_tx(quic_conn, udp_sk).await?; + recv_rx(quic_conn, udp_sk, expired_time).await?; + loop { + match recv_h3(quic_conn, h3_conn) { + Ok((stream_id, H3Result::Data { mut data })) => { + if let Some((answer, _)) = query_map.get_mut(&stream_id) { + answer.append(&mut data); + } else { + // Should not happen + warn!("No associated receiver found while receiving Data, Network {}, stream id: {}", self.info.net_id, stream_id); + } + } + Ok((stream_id, H3Result::Finished)) => { + if let Some((answer, resp)) = query_map.remove(&stream_id) { + debug!( + "sending answer back to resolv, Network {}, stream id: {}", + self.info.net_id, stream_id + ); + resp.send(Response::Success { answer }).unwrap_or_else(|e| { + trace!( + "the receiver dropped {:?}, stream id: {}", + e, + stream_id + ); + }); + } else { + // Should not happen + warn!("No associated receiver found while receiving Finished, Network {}, stream id: {}", self.info.net_id, stream_id); + } + } + Ok((_stream_id, H3Result::Ignore)) => {} + Err(_) => break, + } + } + if quic_conn.is_closed() || !quic_conn.is_established() { + self.state_to_idle(); + bail!("connection become idle"); + } + } + } else { + self.state_to_error(); + bail!("state error while performing process_queries(), network: {}", self.info.net_id); + } + } +} + +fn recv_h3( + quic_conn: &mut Pin<Box<quiche::Connection>>, + h3_conn: &mut h3::Connection, +) -> Result<(u64, H3Result)> { + match h3_conn.poll(quic_conn) { + // Process HTTP/3 events. + Ok((stream_id, quiche::h3::Event::Data)) => { + debug!("quiche::h3::Event::Data"); + let mut buf = vec![0; MAX_DATAGRAM_SIZE]; + match h3_conn.recv_body(quic_conn, stream_id, &mut buf) { + Ok(read) => { + trace!( + "got {} bytes of response data on stream {}: {:x?}", + read, + stream_id, + &buf[..read] + ); + buf.truncate(read); + Ok((stream_id, H3Result::Data { data: buf })) + } + Err(e) => { + warn!("recv_h3::recv_body {:?}", e); + bail!(e); + } + } + } + Ok((stream_id, quiche::h3::Event::Headers { list, has_body })) => { + trace!( + "got response headers {:?} on stream id {} has_body {}", + list, + stream_id, + has_body + ); + Ok((stream_id, H3Result::Ignore)) + } + Ok((stream_id, quiche::h3::Event::Finished)) => { + debug!("quiche::h3::Event::Finished on stream id {}", stream_id); + Ok((stream_id, H3Result::Finished)) + } + Ok((stream_id, quiche::h3::Event::Datagram)) => { + debug!("quiche::h3::Event::Datagram on stream id {}", stream_id); + Ok((stream_id, H3Result::Ignore)) + } + // TODO: Check if it's necessary to handle GoAway event. + Ok((stream_id, quiche::h3::Event::GoAway)) => { + debug!("quiche::h3::Event::GoAway on stream id {}", stream_id); + Ok((stream_id, H3Result::Ignore)) + } + Err(e) => { + debug!("recv_h3 {:?}", e); + bail!(e); + } + } +} + +#[allow(clippy::too_many_arguments)] +async fn send_dns_query( + quic_conn: &mut Pin<Box<quiche::Connection>>, + udp_sk: &mut UdpSocket, + h3_conn: &mut h3::Connection, + query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>, + pending_queries: &mut Vec<(DnsRequest, QueryResponder, Instant)>, + resp: QueryResponder, + expired_time: Instant, + req: DnsRequest, +) -> Result<()> { + if !quic_conn.is_established() { + bail!("quic connection is not ready"); + } + match h3_conn.send_request(quic_conn, &req, true /*fin*/) { + Ok(stream_id) => { + query_map.insert(stream_id, (Vec::new(), resp)); + flush_tx(quic_conn, udp_sk).await?; + debug!("send dns query successfully stream id: {}", stream_id); + Ok(()) + } + Err(quiche::h3::Error::StreamBlocked) => { + warn!("try to send query but error on StreamBlocked"); + pending_queries.push((req, resp, expired_time)); + Ok(()) + } + Err(e) => { + resp.send(Response::Error { error: QueryError::ConnectionError }).ok(); + bail!(e); + } + } +} + +async fn recv_rx( + quic_conn: &mut Pin<Box<quiche::Connection>>, + udp_sk: &mut UdpSocket, + expired_time: &mut Option<BootTime>, +) -> Result<()> { + // TODO: Evaluate if we could make the buffer smaller. + let mut buf = [0; 65535]; + let quic_idle_timeout_ms = Duration::from_millis(QUICHE_IDLE_TIMEOUT_MS); + let ts = quic_conn.timeout().unwrap_or(quic_idle_timeout_ms); + + if let Some(next_expired) = BootTime::now().checked_add(quic_idle_timeout_ms) { + expired_time.replace(next_expired); + } else { + expired_time.take(); + } + debug!("recv_rx entry next timeout {:?} {:?}", ts, expired_time); + match timeout(ts, udp_sk.recv_from(&mut buf)).await { + Ok(v) => match v { + Ok((size, from)) => { + let recv_info = quiche::RecvInfo { from }; + let processed = match quic_conn.recv(&mut buf[..size], recv_info) { + Ok(l) => l, + Err(e) => { + debug!("recv_rx error {:?}", e); + bail!("quic recv failed: {:?}", e); + } + }; + debug!("processed {} bytes", processed); + Ok(()) + } + Err(e) => bail!("socket recv failed: {:?}", e), + }, + Err(_) => { + warn!("timeout did not receive value within {:?}", ts); + quic_conn.on_timeout(); + Ok(()) + } + } +} + +async fn flush_tx( + quic_conn: &mut Pin<Box<quiche::Connection>>, + udp_sk: &mut UdpSocket, +) -> Result<()> { + let mut out = [0; MAX_DATAGRAM_SIZE]; + loop { + let (write, _) = match quic_conn.send(&mut out) { + Ok(v) => v, + Err(quiche::Error::Done) => { + debug!("done writing"); + break; + } + Err(e) => { + quic_conn.close(false, 0x1, b"fail").ok(); + bail!(e); + } + }; + udp_sk.send(&out[..write]).await?; + debug!("written {}", write); + } + Ok(()) +} + +fn report_private_dns_validation( + info: &ServerInfo, + state: &ConnectionState, + runtime: Arc<Runtime>, + validation_fn: ValidationCallback, +) { + let (ip_addr, domain) = match ( + CString::new(info.peer_addr.ip().to_string()), + CString::new(info.domain.clone().unwrap_or_default()), + ) { + (Ok(ip_addr), Ok(domain)) => (ip_addr, domain), + _ => { + error!("report_private_dns_validation bad input"); + return; + } + }; + let netd_id = info.net_id; + let success = matches!(state, ConnectionState::Connected { .. }); + runtime + .spawn_blocking(move || validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())); +} + +fn handle_probe_result( + result: (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>), + doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, + runtime: Arc<Runtime>, + validation_fn: ValidationCallback, +) { + let (info, doh_conn) = match result { + (info, Ok(doh_conn)) => { + info!("probing_task success on net_id: {}", info.net_id); + (info, doh_conn) + } + (info, Err((e, doh_conn))) => { + error!("probe failed on network {}, {:?}", e, info.net_id); + (info, doh_conn) + // TODO: Retry probe? + } + }; + // If the network is removed or the server is replaced before probing, + // ignore the probe result. + match doh_conn_map.get(&info.net_id) { + Some((server_info, _)) => { + if *server_info != info { + warn!( + "The previous configuration for network {} was replaced before probe finished", + info.net_id + ); + return; + } + } + _ => { + warn!("network {} was removed before probe finished", info.net_id); + return; + } + } + report_private_dns_validation(&info, &doh_conn.state, runtime, validation_fn); + doh_conn_map.insert(info.net_id, (info, Some(doh_conn))); +} + +async fn probe_task( + info: ServerInfo, + mut doh: DohConnection, + t: Duration, +) -> (ServerInfo, Result<DohConnection, (anyhow::Error, DohConnection)>) { + match doh.probe(t).await { + Ok(_) => (info, Ok(doh)), + Err(e) => (info, Err((anyhow!(e), doh))), + } +} + +fn make_connection_if_needed( + info: &ServerInfo, + doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, + shared_config: Arc<Mutex<QuicheConfigCache>>, + tag_socket_fn: TagSocketCallback, +) -> Result<Option<DohConnection>> { + // Check if connection exists. + match doh_conn_map.get(&info.net_id) { + // The connection exists but has failed. Re-probe. + Some((server_info, Some(doh))) if *server_info == *info && doh.is_reprobe_required() => { + let (_, doh) = doh_conn_map + .insert(info.net_id, (info.clone(), None)) + .ok_or_else(|| anyhow!("unexpected error, missing connection"))?; + return Ok(doh); + } + // The connection exists or the connection is under probing, ignore. + Some((server_info, _)) if *server_info == *info => return Ok(None), + // TODO: change the inner connection instead of removing? + _ => doh_conn_map.remove(&info.net_id), + }; + let doh = DohConnection::new(info, shared_config, tag_socket_fn)?; + doh_conn_map.insert(info.net_id, (info.clone(), None)); + Ok(Some(doh)) +} + +struct QuicheConfigCache { + cert_path: Option<String>, + config: Option<quiche::Config>, +} + +impl QuicheConfigCache { + fn get(&mut self, cert_path: &Option<String>) -> Result<Option<&mut quiche::Config>> { + // No config is cached or the cached config isn't matched with the input cert_path + // Create it with the input cert_path. + if self.config.is_none() || self.cert_path != *cert_path { + self.config = Some(create_quiche_config(cert_path.as_deref())?); + self.cert_path = cert_path.clone(); + } + return Ok(self.config.as_mut()); + } +} + +async fn handle_query_cmd( + net_id: u32, + base64_query: Base64Query, + expired_time: Instant, + resp: QueryResponder, + doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, +) { + if let Some((info, quic_conn)) = doh_conn_map.get_mut(&net_id) { + match (&info.domain, quic_conn) { + // Connection is not ready, strict mode + (Some(_), None) => { + let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); + } + // Connection is not ready, Opportunistic mode + (None, None) => { + let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); + } + // Connection is ready + (_, Some(quic_conn)) => { + if let Ok(req) = make_dns_request(&base64_query, &info.url) { + let _ = quic_conn.try_send_doh_query(req, resp, expired_time).await; + } else { + let _ = resp.send(Response::Error { error: QueryError::Unexpected }); + } + } + } + } else { + error!("No connection is associated with the given net id {}", net_id); + let _ = resp.send(Response::Error { error: QueryError::ServerNotReady }); + } +} +fn need_process_queries(doh_conn_map: &HashMap<u32, (ServerInfo, Option<DohConnection>)>) -> bool { + if doh_conn_map.is_empty() { + return false; + } + for (_, doh_conn) in doh_conn_map.values() { + if let Some(doh_conn) = doh_conn { + if doh_conn.has_not_handled_queries() { + return true; + } + } + } + false +} + +async fn doh_handler( + mut cmd_rx: CmdReceiver, + runtime: Arc<Runtime>, + validation_fn: ValidationCallback, + tag_socket_fn: TagSocketCallback, +) -> Result<()> { + info!("doh_dispatcher entry"); + let config_cache = Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); + + // Currently, only support 1 server per network. + let mut doh_conn_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); + let mut probe_futures = FuturesUnordered::new(); + loop { + tokio::select! { + _ = async { + let mut futures = vec![]; + for (_, doh_conn) in doh_conn_map.values_mut() { + if let Some(doh_conn) = doh_conn { + futures.push(doh_conn.process_queries()); + } + } + join_all(futures).await + }, if need_process_queries(&doh_conn_map) => {}, + Some(result) = probe_futures.next() => { + let runtime_clone = runtime.clone(); + handle_probe_result(result, &mut doh_conn_map, runtime_clone, validation_fn); + info!("probe_futures remaining size: {}", probe_futures.len()); + }, + Some(cmd) = cmd_rx.recv() => { + trace!("recv {:?}", cmd); + match cmd { + DohCommand::Probe { info, timeout: t } => { + match make_connection_if_needed(&info, &mut doh_conn_map, config_cache.clone(), tag_socket_fn) { + Ok(Some(doh)) => { + // Create a new async task associated to the DoH connection. + probe_futures.push(probe_task(info, doh, t)); + debug!("probe_futures size: {}", probe_futures.len()); + } + Ok(None) => { + // No further probe is needed. + warn!("connection for network {} already exists", info.net_id); + // TODO: Report the status again? + } + Err(e) => { + error!("create connection for network {} error {:?}", info.net_id, e); + report_private_dns_validation(&info, &ConnectionState::Error, runtime.clone(), validation_fn); + } + } + }, + DohCommand::Query { net_id, base64_query, expired_time, resp } => { + handle_query_cmd(net_id, base64_query, expired_time, resp, &mut doh_conn_map).await; + }, + DohCommand::Clear { net_id } => { + doh_conn_map.remove(&net_id); + info!("Doh Clear server for netid: {}", net_id); + }, + DohCommand::Exit => return Ok(()), + } + } + } + } +} + +fn make_dns_request(base64_query: &str, url: &url::Url) -> Result<DnsRequest> { + let mut path = String::from(url.path()); + path.push_str("?dns="); + path.push_str(base64_query); + let req = vec![ + quiche::h3::Header::new(b":method", b"GET"), + quiche::h3::Header::new(b":scheme", b"https"), + quiche::h3::Header::new( + b":authority", + url.host_str().ok_or_else(|| anyhow!("failed to get host"))?.as_bytes(), + ), + quiche::h3::Header::new(b":path", path.as_bytes()), + quiche::h3::Header::new(b"user-agent", b"quiche"), + quiche::h3::Header::new(b"accept", b"application/dns-message"), + // TODO: is content-length required? + ]; + + Ok(req) +} + +fn make_doh_udp_socket(peer_addr: SocketAddr, mark: u32) -> Result<std::net::UdpSocket> { + let bind_addr = match peer_addr { + std::net::SocketAddr::V4(_) => "0.0.0.0:0", + std::net::SocketAddr::V6(_) => "[::]:0", + }; + let udp_sk = std::net::UdpSocket::bind(bind_addr)?; + udp_sk.set_nonblocking(true)?; + if mark_socket(udp_sk.as_raw_fd(), mark).is_err() { + warn!("Mark socket failed, is it a test?"); + } + udp_sk.connect(peer_addr)?; + + trace!("connecting to {:} from {:}", peer_addr, udp_sk.local_addr()?); + Ok(udp_sk) +} + +fn create_quiche_config(cert_path: Option<&str>) -> Result<quiche::Config> { + let mut config = quiche::Config::new(quiche::PROTOCOL_VERSION)?; + config.set_application_protos(h3::APPLICATION_PROTOCOL)?; + match cert_path { + Some(path) => { + config.verify_peer(true); + config.load_verify_locations_from_directory(path)?; + } + None => config.verify_peer(false), + } + + // Some of these configs are necessary, or the server can't respond the HTTP/3 request. + config.set_max_idle_timeout(QUICHE_IDLE_TIMEOUT_MS); + config.set_max_recv_udp_payload_size(MAX_DATAGRAM_SIZE); + config.set_initial_max_data(MAX_INCOMING_BUFFER_SIZE_WHOLE); + config.set_initial_max_stream_data_bidi_local(MAX_INCOMING_BUFFER_SIZE_EACH); + config.set_initial_max_stream_data_bidi_remote(MAX_INCOMING_BUFFER_SIZE_EACH); + config.set_initial_max_stream_data_uni(MAX_INCOMING_BUFFER_SIZE_EACH); + config.set_initial_max_streams_bidi(MAX_CONCURRENT_STREAM_SIZE); + config.set_initial_max_streams_uni(MAX_CONCURRENT_STREAM_SIZE); + config.set_disable_active_migration(true); + Ok(config) +} + +fn mark_socket(fd: RawFd, mark: u32) -> Result<()> { + // libc::setsockopt is a wrapper function calling into bionic setsockopt. + // Both fd and mark are valid, which makes the function call mostly safe. + if unsafe { + libc::setsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_MARK, + &mark as *const _ as *const libc::c_void, + std::mem::size_of::<u32>() as libc::socklen_t, + ) + } == 0 + { + Ok(()) + } else { + Err(anyhow::Error::new(std::io::Error::last_os_error())) + } +} + +#[rustfmt::skip] +fn make_probe_query() -> Result<String> { + let mut rnd = [0; 8]; + ring::rand::SystemRandom::new().fill(&mut rnd).context("failed to generate probe rnd")?; + let c = |byte| CHARSET[(byte as usize) % CHARSET.len()]; + let query = vec![ + rnd[6], rnd[7], // [0-1] query ID + 1, 0, // [2-3] flags; query[2] = 1 for recursion desired (RD). + 0, 1, // [4-5] QDCOUNT (number of queries) + 0, 0, // [6-7] ANCOUNT (number of answers) + 0, 0, // [8-9] NSCOUNT (number of name server records) + 0, 0, // [10-11] ARCOUNT (number of additional records) + 19, c(rnd[0]), c(rnd[1]), c(rnd[2]), c(rnd[3]), c(rnd[4]), c(rnd[5]), b'-', b'd', b'n', + b's', b'o', b'h', b't', b't', b'p', b's', b'-', b'd', b's', + 6, b'm', b'e', b't', b'r', b'i', b'c', 7, b'g', b's', + b't', b'a', b't', b'i', b'c', 3, b'c', b'o', b'm', + 0, // null terminator of FQDN (root TLD) + 0, NS_T_AAAA, // QTYPE + 0, NS_C_IN // QCLASS + ]; + Ok(base64::encode_config(query, base64::URL_SAFE_NO_PAD)) +} + +#[cfg(test)] +mod tests { + use super::*; + use quiche::h3::NameValue; + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; + + const TEST_NET_ID: u32 = 50; + const PROBE_QUERY_SIZE: usize = 56; + const H3_DNS_REQUEST_HEADER_SIZE: usize = 6; + const TEST_MARK: u32 = 0xD0033; + const LOOPBACK_ADDR: &str = "127.0.0.1:443"; + const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; + + // TODO: Make some tests for DohConnection and QuicheConfigCache. + + fn make_testing_variables() -> ( + ServerInfo, + HashMap<u32, (ServerInfo, Option<DohConnection>)>, + Arc<Mutex<QuicheConfigCache>>, + Arc<Runtime>, + ) { + let test_map: HashMap<u32, (ServerInfo, Option<DohConnection>)> = HashMap::new(); + let info = ServerInfo { + net_id: TEST_NET_ID, + url: Url::parse(LOCALHOST_URL).unwrap(), + peer_addr: LOOPBACK_ADDR.parse().unwrap(), + domain: None, + sk_mark: 0, + cert_path: None, + }; + let config_cache = + Arc::new(Mutex::new(QuicheConfigCache { cert_path: None, config: None })); + + let rt = Arc::new( + Builder::new_current_thread() + .thread_name("test-runtime") + .enable_all() + .build() + .expect("Failed to create testing tokio runtime"), + ); + (info, test_map, config_cache, rt) + } + + extern "C" fn tag_socket_cb(sock: int32_t) { + assert!(sock >= 0); + } + + #[test] + fn make_connection_if_needed() { + let (info, mut test_map, config, rt) = make_testing_variables(); + rt.block_on(async { + // Expect to make a new connection. + let mut doh = super::make_connection_if_needed( + &info, + &mut test_map, + config.clone(), + tag_socket_cb, + ) + .unwrap() + .unwrap(); + assert_eq!(doh.info.net_id, info.net_id); + assert!(matches!(doh.state, ConnectionState::Idle)); + doh.state = ConnectionState::Error; + test_map.insert(info.net_id, (info.clone(), Some(doh))); + // Expect that we will get a connection with fail status that we added to the map before. + let mut doh = super::make_connection_if_needed( + &info, + &mut test_map, + config.clone(), + tag_socket_cb, + ) + .unwrap() + .unwrap(); + assert_eq!(doh.info.net_id, info.net_id); + assert!(matches!(doh.state, ConnectionState::Error)); + doh.state = make_dummy_connected_state(); + test_map.insert(info.net_id, (info.clone(), Some(doh))); + // Expect that we will get None because the map contains a connection with ready status. + assert!(super::make_connection_if_needed( + &info, + &mut test_map, + config.clone(), + tag_socket_cb + ) + .unwrap() + .is_none()); + }); + } + + #[test] + fn handle_query_cmd() { + let (info, mut test_map, config, rt) = make_testing_variables(); + let t = Duration::from_millis(100); + + rt.block_on(async { + // Test no available server cases. + let (resp_tx, resp_rx) = oneshot::channel(); + let query = super::make_probe_query().unwrap(); + super::handle_query_cmd( + info.net_id, + query.clone(), + Instant::now().checked_add(t).unwrap(), + resp_tx, + &mut test_map, + ) + .await; + assert_eq!( + timeout(t, resp_rx).await.unwrap().unwrap(), + Response::Error { error: QueryError::ServerNotReady } + ); + + let (resp_tx, resp_rx) = oneshot::channel(); + test_map.insert(info.net_id, (info.clone(), None)); + super::handle_query_cmd( + info.net_id, + query.clone(), + Instant::now().checked_add(t).unwrap(), + resp_tx, + &mut test_map, + ) + .await; + assert_eq!( + timeout(t, resp_rx).await.unwrap().unwrap(), + Response::Error { error: QueryError::ServerNotReady } + ); + + // Test the connection broken case. + test_map.clear(); + let (resp_tx, resp_rx) = oneshot::channel(); + let mut doh = super::make_connection_if_needed( + &info, + &mut test_map, + config.clone(), + tag_socket_cb, + ) + .unwrap() + .unwrap(); + doh.state = ConnectionState::Error; + test_map.insert(info.net_id, (info.clone(), Some(doh))); + super::handle_query_cmd( + info.net_id, + query.clone(), + Instant::now().checked_add(t).unwrap(), + resp_tx, + &mut test_map, + ) + .await; + assert_eq!( + timeout(t, resp_rx).await.unwrap().unwrap(), + Response::Error { error: QueryError::BrokenServer } + ); + }); + } + + extern "C" fn success_cb( + net_id: uint32_t, + success: bool, + ip_addr: *const c_char, + host: *const c_char, + ) { + assert!(success); + unsafe { + assert_validation_info(net_id, ip_addr, host); + } + } + + extern "C" fn fail_cb( + net_id: uint32_t, + success: bool, + ip_addr: *const c_char, + host: *const c_char, + ) { + assert!(!success); + unsafe { + assert_validation_info(net_id, ip_addr, host); + } + } + + // # Safety + // `ip_addr`, `host` are null terminated strings + unsafe fn assert_validation_info( + net_id: uint32_t, + ip_addr: *const c_char, + host: *const c_char, + ) { + assert_eq!(net_id, TEST_NET_ID); + let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); + let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); + assert_eq!(ip_addr, expected_addr.ip().to_string()); + let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); + assert_eq!(host, ""); + } + + fn make_testing_connection_variables() -> (Pin<Box<quiche::Connection>>, UdpSocket) { + let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); + let udp_sk = UdpSocket::from_std(sk).unwrap(); + let mut scid = [0; quiche::MAX_CONN_ID_LEN]; + ring::rand::SystemRandom::new().fill(&mut scid).context("failed to generate scid").unwrap(); + let connid = quiche::ConnectionId::from_ref(&scid); + let mut config = super::create_quiche_config(None).unwrap(); + let quic_conn = + quiche::connect(None, &connid, LOOPBACK_ADDR.parse().unwrap(), &mut config).unwrap(); + (quic_conn, udp_sk) + } + + fn make_dummy_connected_state() -> super::ConnectionState { + let (quic_conn, udp_sk) = make_testing_connection_variables(); + ConnectionState::Connected { + quic_conn, + udp_sk, + h3_conn: None, + query_map: HashMap::new(), + expired_time: None, + } + } + + fn make_dummy_connecting_state() -> super::ConnectionState { + let (quic_conn, udp_sk) = make_testing_connection_variables(); + ConnectionState::Connecting { + quic_conn: Some(quic_conn), + udp_sk: Some(udp_sk), + expired_time: None, + } + } + + #[test] + fn report_private_dns_validation() { + let info = ServerInfo { + net_id: TEST_NET_ID, + url: Url::parse(LOCALHOST_URL).unwrap(), + peer_addr: LOOPBACK_ADDR.parse().unwrap(), + domain: None, + sk_mark: 0, + cert_path: None, + }; + let rt = Arc::new( + Builder::new_current_thread() + .thread_name("test-runtime") + .enable_io() + .build() + .expect("Failed to create testing tokio runtime"), + ); + let default_panic = std::panic::take_hook(); + // Exit the test if the worker inside tokio runtime panicked. + std::panic::set_hook(Box::new(move |info| { + default_panic(info); + std::process::exit(1); + })); + rt.block_on(async { + super::report_private_dns_validation( + &info, + &make_dummy_connected_state(), + rt.clone(), + success_cb, + ); + super::report_private_dns_validation( + &info, + &ConnectionState::Error, + rt.clone(), + fail_cb, + ); + super::report_private_dns_validation( + &info, + &make_dummy_connecting_state(), + rt.clone(), + fail_cb, + ); + super::report_private_dns_validation( + &info, + &ConnectionState::Idle, + rt.clone(), + fail_cb, + ); + }); + } + + #[test] + fn make_probe_query_and_request() { + let probe_query = super::make_probe_query().unwrap(); + let url = Url::parse(LOCALHOST_URL).unwrap(); + let request = make_dns_request(&probe_query, &url).unwrap(); + // Verify H3 DNS request. + assert_eq!(request.len(), H3_DNS_REQUEST_HEADER_SIZE); + assert_eq!(request[0].name(), b":method"); + assert_eq!(request[0].value(), b"GET"); + assert_eq!(request[1].name(), b":scheme"); + assert_eq!(request[1].value(), b"https"); + assert_eq!(request[2].name(), b":authority"); + assert_eq!(request[2].value(), url.host_str().unwrap().as_bytes()); + assert_eq!(request[3].name(), b":path"); + let mut path = String::from(url.path()); + path.push_str("?dns="); + path.push_str(&probe_query); + assert_eq!(request[3].value(), path.as_bytes()); + assert_eq!(request[5].name(), b"accept"); + assert_eq!(request[5].value(), b"application/dns-message"); + + // Verify DNS probe packet. + let bytes = base64::decode_config(probe_query, base64::URL_SAFE_NO_PAD).unwrap(); + assert_eq!(bytes.len(), PROBE_QUERY_SIZE); + // TODO: Parse the result to ensure it's a valid DNS packet. + } + + #[test] + fn create_quiche_config() { + assert!( + super::create_quiche_config(None).is_ok(), + "quiche config without cert creating failed" + ); + assert!( + super::create_quiche_config(Some("data/local/tmp/")).is_ok(), + "quiche config with cert creating failed" + ); + } + + #[test] + fn make_doh_udp_socket() { + // Make a socket connecting to loopback with a test mark. + let sk = super::make_doh_udp_socket(LOOPBACK_ADDR.parse().unwrap(), TEST_MARK).unwrap(); + // Check if the socket is connected to loopback. + assert_eq!( + sk.peer_addr().unwrap(), + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), DOH_PORT)) + ); + + // Check if the socket mark is correct. + let fd: RawFd = sk.as_raw_fd(); + + let mut mark: u32 = 50; + let mut size = std::mem::size_of::<u32>() as libc::socklen_t; + unsafe { + // Safety: It's fine since the fd belongs to this test. + assert_eq!( + libc::getsockopt( + fd, + libc::SOL_SOCKET, + libc::SO_MARK, + &mut mark as *mut _ as *mut libc::c_void, + &mut size as *mut libc::socklen_t, + ), + 0 + ); + } + assert_eq!(mark, TEST_MARK); + + // Check if the socket is non-blocking. + unsafe { + // Safety: It's fine since the fd belongs to this test. + assert_eq!(libc::fcntl(fd, libc::F_GETFL, 0) & libc::O_NONBLOCK, libc::O_NONBLOCK); + } + } +} diff --git a/doh/ffi.rs b/doh/ffi.rs new file mode 100644 index 00000000..1202337d --- /dev/null +++ b/doh/ffi.rs @@ -0,0 +1,249 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! C API for the DoH backend for the Android DnsResolver module. + +use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t}; +use log::error; +use std::net::{IpAddr, SocketAddr}; +use std::str::FromStr; +use std::{ptr, slice}; +use tokio::runtime::Runtime; +use tokio::sync::oneshot; +use tokio::task; +use tokio::time::{timeout, Duration, Instant}; + +use super::{ + DohCommand, DohDispatcher, Response, ServerInfo, TagSocketCallback, ValidationCallback, + DOH_PORT, +}; + +const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts"; + +/// The return code of doh_query means that there is no answer. +pub const RESULT_INTERNAL_ERROR: ssize_t = -1; +/// The return code of doh_query means that query can't be sent. +pub const RESULT_CAN_NOT_SEND: ssize_t = -2; +/// The return code of doh_query to indicate that the query timed out. +pub const RESULT_TIMEOUT: ssize_t = -255; +/// The error log level. +pub const LOG_LEVEL_ERROR: u32 = 0; +/// The warning log level. +pub const LOG_LEVEL_WARN: u32 = 1; +/// The info log level. +pub const LOG_LEVEL_INFO: u32 = 2; +/// The debug log level. +pub const LOG_LEVEL_DEBUG: u32 = 3; +/// The trace log level. +pub const LOG_LEVEL_TRACE: u32 = 4; + +/// Performs static initialization for android logger. +#[no_mangle] +pub extern "C" fn doh_init_logger(level: u32) { + let level = match level { + LOG_LEVEL_WARN => log::Level::Warn, + LOG_LEVEL_DEBUG => log::Level::Debug, + _ => log::Level::Error, + }; + android_logger::init_once(android_logger::Config::default().with_min_level(level)); +} + +/// Set the log level. +#[no_mangle] +pub extern "C" fn doh_set_log_level(level: u32) { + let level = match level { + LOG_LEVEL_ERROR => log::LevelFilter::Error, + LOG_LEVEL_WARN => log::LevelFilter::Warn, + LOG_LEVEL_INFO => log::LevelFilter::Info, + LOG_LEVEL_DEBUG => log::LevelFilter::Debug, + LOG_LEVEL_TRACE => log::LevelFilter::Trace, + _ => log::LevelFilter::Off, + }; + log::set_max_level(level); +} + +/// Performs the initialization for the DoH engine. +/// Creates and returns a DoH engine instance. +#[no_mangle] +pub extern "C" fn doh_dispatcher_new( + validation_fn: ValidationCallback, + tag_socket_fn: TagSocketCallback, +) -> *mut DohDispatcher { + match DohDispatcher::new(validation_fn, tag_socket_fn) { + Ok(c) => Box::into_raw(c), + Err(e) => { + error!("doh_dispatcher_new: failed: {:?}", e); + ptr::null_mut() + } + } +} + +/// Deletes a DoH engine created by doh_dispatcher_new(). +/// # Safety +/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()` +/// and not yet deleted by `doh_dispatcher_delete()`. +#[no_mangle] +pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) { + Box::from_raw(doh).exit_handler() +} + +/// Probes and stores the DoH server with the given configurations. +/// Use the negative errno-style codes as the return value to represent the result. +/// # Safety +/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()` +/// and not yet deleted by `doh_dispatcher_delete()`. +/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings. +#[no_mangle] +pub unsafe extern "C" fn doh_net_new( + doh: &mut DohDispatcher, + net_id: uint32_t, + url: *const c_char, + domain: *const c_char, + ip_addr: *const c_char, + sk_mark: libc::uint32_t, + cert_path: *const c_char, + timeout_ms: libc::uint64_t, +) -> int32_t { + let (url, domain, ip_addr, cert_path) = match ( + std::ffi::CStr::from_ptr(url).to_str(), + std::ffi::CStr::from_ptr(domain).to_str(), + std::ffi::CStr::from_ptr(ip_addr).to_str(), + std::ffi::CStr::from_ptr(cert_path).to_str(), + ) { + (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => { + if domain.is_empty() { + (url, None, ip_addr.to_string(), None) + } else if !cert_path.is_empty() { + (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string())) + } else { + ( + url, + Some(domain.to_string()), + ip_addr.to_string(), + Some(SYSTEM_CERT_PATH.to_string()), + ) + } + } + _ => { + error!("bad input"); // Should not happen + return -libc::EINVAL; + } + }; + + let (url, ip_addr) = match (url::Url::parse(url), IpAddr::from_str(&ip_addr)) { + (Ok(url), Ok(ip_addr)) => (url, ip_addr), + _ => { + error!("bad ip or url"); // Should not happen + return -libc::EINVAL; + } + }; + let cmd = DohCommand::Probe { + info: ServerInfo { + net_id, + url, + peer_addr: SocketAddr::new(ip_addr, DOH_PORT), + domain, + sk_mark, + cert_path, + }, + timeout: Duration::from_millis(timeout_ms), + }; + if let Err(e) = doh.send_cmd(cmd) { + error!("Failed to send the probe: {:?}", e); + return -libc::EPIPE; + } + 0 +} + +/// Sends a DNS query via the network associated to the given |net_id| and waits for the response. +/// The return code should be either one of the public constant RESULT_* to indicate the error or +/// the size of the answer. +/// # Safety +/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()` +/// and not yet deleted by `doh_dispatcher_delete()`. +/// `dns_query` must point to a buffer at least `dns_query_len` in size. +/// `response` must point to a buffer at least `response_len` in size. +#[no_mangle] +pub unsafe extern "C" fn doh_query( + doh: &mut DohDispatcher, + net_id: uint32_t, + dns_query: *mut u8, + dns_query_len: size_t, + response: *mut u8, + response_len: size_t, + timeout_ms: uint64_t, +) -> ssize_t { + let q = slice::from_raw_parts_mut(dns_query, dns_query_len); + + let (resp_tx, resp_rx) = oneshot::channel(); + let t = Duration::from_millis(timeout_ms); + if let Some(expired_time) = Instant::now().checked_add(t) { + let cmd = DohCommand::Query { + net_id, + base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD), + expired_time, + resp: resp_tx, + }; + + if let Err(e) = doh.send_cmd(cmd) { + error!("Failed to send the query: {:?}", e); + return RESULT_CAN_NOT_SEND; + } + } else { + error!("Bad timeout parameter: {}", timeout_ms); + return RESULT_CAN_NOT_SEND; + } + + if let Ok(rt) = Runtime::new() { + let local = task::LocalSet::new(); + match local.block_on(&rt, async { timeout(t, resp_rx).await }) { + Ok(v) => match v { + Ok(v) => match v { + Response::Success { answer } => { + if answer.len() > response_len || answer.len() > isize::MAX as usize { + return RESULT_INTERNAL_ERROR; + } + let response = slice::from_raw_parts_mut(response, answer.len()); + response.copy_from_slice(&answer); + answer.len() as ssize_t + } + _ => RESULT_CAN_NOT_SEND, + }, + Err(e) => { + error!("no result {}", e); + RESULT_CAN_NOT_SEND + } + }, + Err(e) => { + error!("timeout: {}", e); + RESULT_TIMEOUT + } + } + } else { + RESULT_CAN_NOT_SEND + } +} + +/// Clears the DoH servers associated with the given |netid|. +/// # Safety +/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()` +/// and not yet deleted by `doh_dispatcher_delete()`. +#[no_mangle] +pub extern "C" fn doh_net_delete(doh: &mut DohDispatcher, net_id: uint32_t) { + if let Err(e) = doh.send_cmd(DohCommand::Clear { net_id }) { + error!("Failed to send the query: {:?}", e); + } +} |