diff options
Diffstat (limited to 'doh/ffi.rs')
-rw-r--r-- | doh/ffi.rs | 54 |
1 files changed, 38 insertions, 16 deletions
@@ -35,8 +35,12 @@ use tokio::sync::oneshot; use tokio::task; use url::Url; -pub type ValidationCallback = - extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char); +pub type ValidationCallback = unsafe extern "C" fn( + net_id: uint32_t, + success: bool, + ip_addr: *const c_char, + host: *const c_char, +); pub type TagSocketCallback = extern "C" fn(sock: RawFd); #[repr(C)] @@ -61,7 +65,9 @@ fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationRepo } }; let netd_id = info.net_id; - task::spawn_blocking(move || { + // SAFETY: The string pointers are obtained from `CString`, so they must be valid C + // strings. + task::spawn_blocking(move || unsafe { validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr()) }) .await @@ -167,12 +173,16 @@ pub extern "C" fn doh_dispatcher_new( } /// Deletes a DoH engine created by doh_dispatcher_new(). +/// /// # Safety +/// /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()` /// and not yet deleted by `doh_dispatcher_delete()`. #[no_mangle] pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) { - Box::from_raw(doh).lock().exit_handler() + // SAFETY: The caller guarantees that `doh` was created by `doh_dispatcher_new` (which does so + // using `Box::into_raw`), and that it hasn't yet been deleted by this function. + unsafe { Box::from_raw(doh) }.lock().exit_handler() } /// Probes and stores the DoH server with the given configurations. @@ -194,12 +204,15 @@ pub unsafe extern "C" fn doh_net_new( network_type: uint32_t, private_dns_mode: uint32_t, ) -> int32_t { - let (url, domain, ip_addr, cert_path) = match ( - std::ffi::CStr::from_ptr(url).to_str(), - std::ffi::CStr::from_ptr(domain).to_str(), - std::ffi::CStr::from_ptr(ip_addr).to_str(), - std::ffi::CStr::from_ptr(cert_path).to_str(), - ) { + // SAFETY: The caller guarantees that these are all valid nul-terminated C strings. + let (url, domain, ip_addr, cert_path) = match unsafe { + ( + std::ffi::CStr::from_ptr(url).to_str(), + std::ffi::CStr::from_ptr(domain).to_str(), + std::ffi::CStr::from_ptr(ip_addr).to_str(), + std::ffi::CStr::from_ptr(cert_path).to_str(), + ) + } { (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => { if domain.is_empty() { (url, None, ip_addr.to_string(), None) @@ -268,7 +281,9 @@ pub unsafe extern "C" fn doh_query( response_len: size_t, timeout_ms: uint64_t, ) -> ssize_t { - let q = slice::from_raw_parts_mut(dns_query, dns_query_len); + // SAFETY: The caller guarantees that `dns_query` is a valid pointer to a buffer of at least + // `dns_query_len` items. + let q = unsafe { slice::from_raw_parts_mut(dns_query, dns_query_len) }; let (resp_tx, resp_rx) = oneshot::channel(); let t = Duration::from_millis(timeout_ms); @@ -298,7 +313,10 @@ pub unsafe extern "C" fn doh_query( if answer.len() > response_len || answer.len() > isize::MAX as usize { return DOH_RESULT_INTERNAL_ERROR; } - let response = slice::from_raw_parts_mut(response, answer.len()); + // SAFETY: The caller guarantees that response points to a valid buffer at + // least `response_len` long, and we just checked that `answer.len()` is no + // longer than `response_len`. + let response = unsafe { slice::from_raw_parts_mut(response, answer.len()) }; response.copy_from_slice(&answer); answer.len() as ssize_t } @@ -341,25 +359,27 @@ mod tests { const LOOPBACK_ADDR: &str = "127.0.0.1:443"; const LOCALHOST_URL: &str = "https://mylocal.com/dns-query"; - extern "C" fn success_cb( + unsafe extern "C" fn success_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(success); + // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings. unsafe { assert_validation_info(net_id, ip_addr, host); } } - extern "C" fn fail_cb( + unsafe extern "C" fn fail_cb( net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char, ) { assert!(!success); + // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings. unsafe { assert_validation_info(net_id, ip_addr, host); } @@ -373,10 +393,12 @@ mod tests { host: *const c_char, ) { assert_eq!(net_id, TEST_NET_ID); - let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap(); + // SAFETY: The caller guarantees that `ip_addr` is a valid nul-terminated C string. + let ip_addr = unsafe { std::ffi::CStr::from_ptr(ip_addr) }.to_str().unwrap(); let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap(); assert_eq!(ip_addr, expected_addr.ip().to_string()); - let host = std::ffi::CStr::from_ptr(host).to_str().unwrap(); + // SAFETY: The caller guarantees that `host` is a valid nul-terminated C string. + let host = unsafe { std::ffi::CStr::from_ptr(host) }.to_str().unwrap(); assert_eq!(host, ""); } |