aboutsummaryrefslogtreecommitdiff
path: root/doh/ffi.rs
diff options
context:
space:
mode:
Diffstat (limited to 'doh/ffi.rs')
-rw-r--r--doh/ffi.rs54
1 files changed, 38 insertions, 16 deletions
diff --git a/doh/ffi.rs b/doh/ffi.rs
index 2276f654..63b98cc8 100644
--- a/doh/ffi.rs
+++ b/doh/ffi.rs
@@ -35,8 +35,12 @@ use tokio::sync::oneshot;
use tokio::task;
use url::Url;
-pub type ValidationCallback =
- extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
+pub type ValidationCallback = unsafe extern "C" fn(
+ net_id: uint32_t,
+ success: bool,
+ ip_addr: *const c_char,
+ host: *const c_char,
+);
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
#[repr(C)]
@@ -61,7 +65,9 @@ fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationRepo
}
};
let netd_id = info.net_id;
- task::spawn_blocking(move || {
+ // SAFETY: The string pointers are obtained from `CString`, so they must be valid C
+ // strings.
+ task::spawn_blocking(move || unsafe {
validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
})
.await
@@ -167,12 +173,16 @@ pub extern "C" fn doh_dispatcher_new(
}
/// Deletes a DoH engine created by doh_dispatcher_new().
+///
/// # Safety
+///
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
- Box::from_raw(doh).lock().exit_handler()
+ // SAFETY: The caller guarantees that `doh` was created by `doh_dispatcher_new` (which does so
+ // using `Box::into_raw`), and that it hasn't yet been deleted by this function.
+ unsafe { Box::from_raw(doh) }.lock().exit_handler()
}
/// Probes and stores the DoH server with the given configurations.
@@ -194,12 +204,15 @@ pub unsafe extern "C" fn doh_net_new(
network_type: uint32_t,
private_dns_mode: uint32_t,
) -> int32_t {
- let (url, domain, ip_addr, cert_path) = match (
- std::ffi::CStr::from_ptr(url).to_str(),
- std::ffi::CStr::from_ptr(domain).to_str(),
- std::ffi::CStr::from_ptr(ip_addr).to_str(),
- std::ffi::CStr::from_ptr(cert_path).to_str(),
- ) {
+ // SAFETY: The caller guarantees that these are all valid nul-terminated C strings.
+ let (url, domain, ip_addr, cert_path) = match unsafe {
+ (
+ std::ffi::CStr::from_ptr(url).to_str(),
+ std::ffi::CStr::from_ptr(domain).to_str(),
+ std::ffi::CStr::from_ptr(ip_addr).to_str(),
+ std::ffi::CStr::from_ptr(cert_path).to_str(),
+ )
+ } {
(Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
if domain.is_empty() {
(url, None, ip_addr.to_string(), None)
@@ -268,7 +281,9 @@ pub unsafe extern "C" fn doh_query(
response_len: size_t,
timeout_ms: uint64_t,
) -> ssize_t {
- let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
+ // SAFETY: The caller guarantees that `dns_query` is a valid pointer to a buffer of at least
+ // `dns_query_len` items.
+ let q = unsafe { slice::from_raw_parts_mut(dns_query, dns_query_len) };
let (resp_tx, resp_rx) = oneshot::channel();
let t = Duration::from_millis(timeout_ms);
@@ -298,7 +313,10 @@ pub unsafe extern "C" fn doh_query(
if answer.len() > response_len || answer.len() > isize::MAX as usize {
return DOH_RESULT_INTERNAL_ERROR;
}
- let response = slice::from_raw_parts_mut(response, answer.len());
+ // SAFETY: The caller guarantees that response points to a valid buffer at
+ // least `response_len` long, and we just checked that `answer.len()` is no
+ // longer than `response_len`.
+ let response = unsafe { slice::from_raw_parts_mut(response, answer.len()) };
response.copy_from_slice(&answer);
answer.len() as ssize_t
}
@@ -341,25 +359,27 @@ mod tests {
const LOOPBACK_ADDR: &str = "127.0.0.1:443";
const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";
- extern "C" fn success_cb(
+ unsafe extern "C" fn success_cb(
net_id: uint32_t,
success: bool,
ip_addr: *const c_char,
host: *const c_char,
) {
assert!(success);
+ // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
unsafe {
assert_validation_info(net_id, ip_addr, host);
}
}
- extern "C" fn fail_cb(
+ unsafe extern "C" fn fail_cb(
net_id: uint32_t,
success: bool,
ip_addr: *const c_char,
host: *const c_char,
) {
assert!(!success);
+ // SAFETY: The caller guarantees that ip_addr and host are valid nul-terminated C strings.
unsafe {
assert_validation_info(net_id, ip_addr, host);
}
@@ -373,10 +393,12 @@ mod tests {
host: *const c_char,
) {
assert_eq!(net_id, TEST_NET_ID);
- let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
+ // SAFETY: The caller guarantees that `ip_addr` is a valid nul-terminated C string.
+ let ip_addr = unsafe { std::ffi::CStr::from_ptr(ip_addr) }.to_str().unwrap();
let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
assert_eq!(ip_addr, expected_addr.ip().to_string());
- let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
+ // SAFETY: The caller guarantees that `host` is a valid nul-terminated C string.
+ let host = unsafe { std::ffi::CStr::from_ptr(host) }.to_str().unwrap();
assert_eq!(host, "");
}