diff options
author | Matthew Maurer <mmaurer@google.com> | 2021-10-05 16:13:18 -0700 |
---|---|---|
committer | Matthew Maurer <mmaurer@google.com> | 2021-10-07 12:36:17 -0700 |
commit | 373b7f6235ef82380fd2a6bddcacf39cc14316c5 (patch) | |
tree | ff9543705b90280a4bd49c27bb826c4ed34384fb /doh | |
parent | 6b17842f445b68b075f247b841285a36c502adda (diff) | |
download | DnsResolver-373b7f6235ef82380fd2a6bddcacf39cc14316c5.tar.gz |
DoH: Factor out BootTime
* Move `BootTime` into the `boot_time` module
* Change `elapsed()` to match `Instant` API
* Add timeout + sleep functions, relative to `CLOCK_BOOTTIME`
* Change everywhere in DoH to use `boot_time` instead of `time`
* Add tests for boot_time module
BYPASS_INCLUSIVE_LANGUAGE_REASON="man is referring to the unix manual command, not a person"
Bug: 202081046
Bug: 200694560
Change-Id: I719965ff75abb0223ba20829ca0a3a4be1d07f40
Diffstat (limited to 'doh')
-rw-r--r-- | doh/boot_time.rs | 206 | ||||
-rw-r--r-- | doh/doh.rs | 55 | ||||
-rw-r--r-- | doh/ffi.rs | 4 |
3 files changed, 222 insertions, 43 deletions
diff --git a/doh/boot_time.rs b/doh/boot_time.rs new file mode 100644 index 00000000..1f6f97df --- /dev/null +++ b/doh/boot_time.rs @@ -0,0 +1,206 @@ +/* + * Copyright (C) 2021 The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! This module provides a time hack to work around the broken `Instant` type in the standard +//! library. +//! +//! `BootTime` looks like `Instant`, but represents `CLOCK_BOOTTIME` instead of `CLOCK_MONOTONIC`. +//! This means the clock increments correctly during suspend. + +pub use std::time::Duration; + +use std::io; + +use futures::future::pending; +use std::convert::TryInto; +use std::fmt; +use std::future::Future; +use std::os::unix::io::{AsRawFd, RawFd}; +use tokio::io::unix::AsyncFd; +use tokio::select; + +/// Represents a moment in time, with differences including time spent in suspend. Only valid for +/// a single boot - numbers from different boots are incomparable. +#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct BootTime { + d: Duration, +} + +// Return an error with the same structure as tokio::time::timeout to facilitate migration off it, +// and hopefully some day back to it. 
+/// Error returned by timeout +#[derive(Debug, PartialEq)] +pub struct Elapsed(()); + +impl fmt::Display for Elapsed { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + "deadline has elapsed".fmt(fmt) + } +} + +impl std::error::Error for Elapsed {} + +impl BootTime { + /// Gets a `BootTime` representing the current moment in time. + pub fn now() -> BootTime { + let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 }; + // # Safety + // clock_gettime's only action will be to possibly write to the pointer provided, + // and no borrows exist from that object other than the &mut used to construct the pointer + // itself. + if unsafe { libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) } != 0 + { + panic!( + "libc::clock_gettime(libc::CLOCK_BOOTTIME) failed: {:?}", + io::Error::last_os_error() + ); + } + BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) } + } + + /// Determines how long has elapsed since the provided `BootTime`. + pub fn elapsed(&self) -> Duration { + BootTime::now().checked_duration_since(*self).unwrap() + } + + /// Add a specified time delta to a moment in time. If this would overflow the representation, + /// returns `None`. + pub fn checked_add(&self, duration: Duration) -> Option<BootTime> { + Some(BootTime { d: self.d.checked_add(duration)? }) + } + + /// Finds the difference from an earlier point in time. If the provided time is later, returns + /// `None`. + pub fn checked_duration_since(&self, earlier: BootTime) -> Option<Duration> { + self.d.checked_sub(earlier.d) + } +} + +struct TimerFd(RawFd); + +impl Drop for TimerFd { + fn drop(&mut self) { + // # Safety + // The fd is owned by the TimerFd struct, and no memory access occurs as a result of this + // call. 
+ unsafe { + libc::close(self.0); + } + } +} + +impl AsRawFd for TimerFd { + fn as_raw_fd(&self) -> RawFd { + self.0 + } +} + +impl TimerFd { + fn create() -> io::Result<Self> { + // # Safety + // This libc call will either give us back a file descriptor or fail, it does not act on + // memory or resources. + let raw = unsafe { + libc::timerfd_create(libc::CLOCK_BOOTTIME, libc::TFD_NONBLOCK | libc::TFD_CLOEXEC) + }; + if raw < 0 { + return Err(io::Error::last_os_error()); + } + Ok(Self(raw)) + } + + fn set(&self, duration: Duration) { + let timer = libc::itimerspec { + it_interval: libc::timespec { tv_sec: 0, tv_nsec: 0 }, + it_value: libc::timespec { + tv_sec: duration.as_secs().try_into().unwrap(), + tv_nsec: duration.subsec_nanos().try_into().unwrap(), + }, + }; + // # Safety + // We own `timer` and there are no borrows to it other than the pointer we pass to + // timerfd_settime. timerfd_settime is explicitly documented to handle a null output + // parameter for its fourth argument by not filling out the output. The fd passed in at + // self.0 is owned by the `TimerFd` struct, so we aren't breaking anyone else's invariants. + if unsafe { libc::timerfd_settime(self.0, 0, &timer, std::ptr::null_mut()) } != 0 { + panic!("timerfd_settime failed: {:?}", io::Error::last_os_error()); + } + } +} + +/// Runs the provided future until completion or `duration` has passed on the `CLOCK_BOOTTIME` +/// clock. In the event of a timeout, returns an `Elapsed` error. +pub async fn timeout<T>(duration: Duration, future: impl Future<Output = T>) -> Result<T, Elapsed> { + // Ideally, all timeouts in a runtime would share a timerfd. That will be much more + // straightforward to implement when moving this functionality into `tokio`. + + // The failure conditions for this are rare (see `man 2 timerfd_create`) and the caller would + // not be able to do much in response to them. When integrated into tokio, this would be called + // during runtime setup. 
+ let timer_fd = TimerFd::create().unwrap(); + timer_fd.set(duration); + let async_fd = AsyncFd::new(timer_fd).unwrap(); + select! { + v = future => Ok(v), + _ = async_fd.readable() => Err(Elapsed(())), + } +} + +/// Provides a future which will complete once the provided duration has passed, as measured by the +/// `CLOCK_BOOTTIME` clock. +pub async fn sleep(duration: Duration) { + assert!(timeout(duration, pending::<()>()).await.is_err()); +} + +#[test] +fn monotonic_smoke() { + for _ in 0..1000 { + // If BootTime is not monotonic, .elapsed() will panic on the unwrap. + BootTime::now().elapsed(); + } +} + +#[test] +fn round_trip() { + use std::thread::sleep; + for _ in 0..10 { + let start = BootTime::now(); + sleep(Duration::from_millis(1)); + let end = BootTime::now(); + let delta = end.checked_duration_since(start).unwrap(); + assert_eq!(start.checked_add(delta).unwrap(), end); + } +} + +#[tokio::test] +async fn timeout_drift() { + let delta = Duration::from_millis(20); + for _ in 0..10 { + let start = BootTime::now(); + assert!(timeout(delta, pending::<()>()).await.is_err()); + let taken = start.elapsed(); + let drift = if taken > delta { taken - delta } else { delta - taken }; + assert!(drift < Duration::from_millis(5)); + } + + for _ in 0..10 { + let start = BootTime::now(); + sleep(delta).await; + let taken = start.elapsed(); + let drift = if taken > delta { taken - delta } else { delta - taken }; + assert!(drift < Duration::from_millis(5)); + } +} @@ -35,11 +35,13 @@ use tokio::net::UdpSocket; use tokio::runtime::{Builder, Runtime}; use tokio::sync::{mpsc, oneshot}; use tokio::task; -use tokio::time::{timeout, Duration, Instant}; use url::Url; +pub mod boot_time; mod ffi; +use boot_time::{timeout, BootTime, Duration}; + const MAX_BUFFERED_CMD_SIZE: usize = 400; const MAX_INCOMING_BUFFER_SIZE_WHOLE: u64 = 10000000; const MAX_INCOMING_BUFFER_SIZE_EACH: u64 = 1000000; @@ -91,7 +93,7 @@ enum Response { #[derive(Debug)] enum DohCommand { Probe { info: 
ServerInfo, timeout: Duration }, - Query { net_id: u32, base64_query: Base64Query, expired_time: Instant, resp: QueryResponder }, + Query { net_id: u32, base64_query: Base64Query, expired_time: BootTime, resp: QueryResponder }, Clear { net_id: u32 }, Exit, } @@ -132,35 +134,6 @@ impl<T: Deref> OptionDeref<T> for Option<T> { } } -#[derive(Copy, Clone, Debug)] -struct BootTime { - d: Duration, -} - -impl BootTime { - fn now() -> BootTime { - unsafe { - let mut t = libc::timespec { tv_sec: 0, tv_nsec: 0 }; - if libc::clock_gettime(libc::CLOCK_BOOTTIME, &mut t as *mut libc::timespec) != 0 { - panic!("get boot time failed: {:?}", std::io::Error::last_os_error()); - } - BootTime { d: Duration::new(t.tv_sec as u64, t.tv_nsec as u32) } - } - } - - fn elapsed(&self) -> Option<Duration> { - BootTime::now().duration_since(*self) - } - - fn checked_add(&self, duration: Duration) -> Option<BootTime> { - Some(BootTime { d: self.d.checked_add(duration)? }) - } - - fn duration_since(&self, earlier: BootTime) -> Option<Duration> { - self.d.checked_sub(earlier.d) - } -} - /// Context for a running DoH engine. pub struct DohDispatcher { /// Used to submit cmds to the I/O task. 
@@ -204,7 +177,7 @@ struct DohConnection { shared_config: Arc<Mutex<QuicheConfigCache>>, scid: SCID, state: ConnectionState, - pending_queries: Vec<(DnsRequest, QueryResponder, Instant)>, + pending_queries: Vec<(DnsRequest, QueryResponder, BootTime)>, cached_session: Option<Vec<u8>>, tag_socket_fn: TagSocketCallback, } @@ -333,7 +306,7 @@ impl DohConnection { }; if let Some(expired_time) = expired_time { - if let Some(elapsed) = expired_time.elapsed() { + if let Some(elapsed) = BootTime::now().checked_duration_since(*expired_time) { warn!( "Change the state to Idle due to connection timeout, {:?}, {}", elapsed, self.info.net_id @@ -429,7 +402,7 @@ impl DohConnection { &mut self, req: DnsRequest, resp: QueryResponder, - expired_time: Instant, + expired_time: BootTime, ) -> Result<()> { self.handle_if_connection_expired(); match &mut self.state { @@ -472,7 +445,7 @@ impl DohConnection { while !self.pending_queries.is_empty() { if let Some((req, resp, exp_time)) = self.pending_queries.pop() { // Ignore the expired queries. 
- if Instant::now().checked_duration_since(exp_time).is_some() { + if BootTime::now().checked_duration_since(exp_time).is_some() { warn!("Drop the obsolete query for network {}", self.info.net_id); continue; } @@ -596,9 +569,9 @@ async fn send_dns_query( udp_sk: &mut UdpSocket, h3_conn: &mut h3::Connection, query_map: &mut HashMap<u64, (Vec<u8>, QueryResponder)>, - pending_queries: &mut Vec<(DnsRequest, QueryResponder, Instant)>, + pending_queries: &mut Vec<(DnsRequest, QueryResponder, BootTime)>, resp: QueryResponder, - expired_time: Instant, + expired_time: BootTime, req: DnsRequest, ) -> Result<()> { if !quic_conn.is_established() { @@ -803,7 +776,7 @@ impl QuicheConfigCache { async fn handle_query_cmd( net_id: u32, base64_query: Base64Query, - expired_time: Instant, + expired_time: BootTime, resp: QueryResponder, doh_conn_map: &mut HashMap<u32, (ServerInfo, Option<DohConnection>)>, ) { @@ -1107,7 +1080,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), - Instant::now().checked_add(t).unwrap(), + BootTime::now().checked_add(t).unwrap(), resp_tx, &mut test_map, ) @@ -1122,7 +1095,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), - Instant::now().checked_add(t).unwrap(), + BootTime::now().checked_add(t).unwrap(), resp_tx, &mut test_map, ) @@ -1148,7 +1121,7 @@ mod tests { super::handle_query_cmd( info.net_id, query.clone(), - Instant::now().checked_add(t).unwrap(), + BootTime::now().checked_add(t).unwrap(), resp_tx, &mut test_map, ) @@ -16,6 +16,7 @@ //! C API for the DoH backend for the Android DnsResolver module. 
+use crate::boot_time::{timeout, BootTime, Duration}; use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t}; use log::error; use std::net::{IpAddr, SocketAddr}; @@ -26,7 +27,6 @@ use std::{ptr, slice}; use tokio::runtime::Runtime; use tokio::sync::oneshot; use tokio::task; -use tokio::time::{timeout, Duration, Instant}; use super::DohDispatcher as Dispatcher; use super::{DohCommand, Response, ServerInfo, TagSocketCallback, ValidationCallback, DOH_PORT}; @@ -205,7 +205,7 @@ pub unsafe extern "C" fn doh_query( let (resp_tx, resp_rx) = oneshot::channel(); let t = Duration::from_millis(timeout_ms); - if let Some(expired_time) = Instant::now().checked_add(t) { + if let Some(expired_time) = BootTime::now().checked_add(t) { let cmd = DohCommand::Query { net_id, base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD), |