403 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Rust
		
	
	
	
			
		
		
	
	
			403 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Rust
		
	
	
	
| /*
 | |
|  * Copyright (C) 2021 The Android Open Source Project
 | |
|  *
 | |
|  * Licensed under the Apache License, Version 2.0 (the "License");
 | |
|  * you may not use this file except in compliance with the License.
 | |
|  * You may obtain a copy of the License at
 | |
|  *
 | |
|  *      http://www.apache.org/licenses/LICENSE-2.0
 | |
|  *
 | |
|  * Unless required by applicable law or agreed to in writing, software
 | |
|  * distributed under the License is distributed on an "AS IS" BASIS,
 | |
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
|  * See the License for the specific language governing permissions and
 | |
|  * limitations under the License.
 | |
|  */
 | |
| 
 | |
| //! C API for the DoH backend for the Android DnsResolver module.
 | |
| 
 | |
| use crate::boot_time::{timeout, BootTime, Duration};
 | |
| use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
 | |
| use crate::network::{SocketTagger, ValidationReporter};
 | |
| use futures::FutureExt;
 | |
| use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
 | |
| use log::{error, warn};
 | |
| use std::ffi::CString;
 | |
| use std::net::{IpAddr, SocketAddr};
 | |
| use std::ops::DerefMut;
 | |
| use std::os::unix::io::RawFd;
 | |
| use std::str::FromStr;
 | |
| use std::sync::{Arc, Mutex};
 | |
| use std::{ptr, slice};
 | |
| use tokio::runtime::Builder;
 | |
| use tokio::sync::oneshot;
 | |
| use tokio::task;
 | |
| use url::Url;
 | |
| 
 | |
| pub type ValidationCallback =
 | |
|     extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
 | |
| pub type TagSocketCallback = extern "C" fn(sock: RawFd);
 | |
| 
 | |
| #[repr(C)]
 | |
| pub struct FeatureFlags {
 | |
|     probe_timeout_ms: uint64_t,
 | |
|     idle_timeout_ms: uint64_t,
 | |
|     use_session_resumption: bool,
 | |
| }
 | |
| 
 | |
| fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
 | |
|     Arc::new(move |info: &ServerInfo, success: bool| {
 | |
|         async move {
 | |
|             let (ip_addr, domain) = match (
 | |
|                 CString::new(info.peer_addr.ip().to_string()),
 | |
|                 CString::new(info.domain.clone().unwrap_or_default()),
 | |
|             ) {
 | |
|                 (Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
 | |
|                 _ => {
 | |
|                     error!("validation_callback bad input");
 | |
|                     return;
 | |
|                 }
 | |
|             };
 | |
|             let netd_id = info.net_id;
 | |
|             task::spawn_blocking(move || {
 | |
|                 validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
 | |
|             })
 | |
|             .await
 | |
|             .unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
 | |
|         }
 | |
|         .boxed()
 | |
|     })
 | |
| }
 | |
| 
 | |
| fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
 | |
|     use std::os::unix::io::AsRawFd;
 | |
|     Arc::new(move |udp_socket: &std::net::UdpSocket| {
 | |
|         let fd = udp_socket.as_raw_fd();
 | |
|         async move {
 | |
|             task::spawn_blocking(move || {
 | |
|                 tag_socket_fn(fd);
 | |
|             })
 | |
|             .await
 | |
|             .unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
 | |
|         }
 | |
|         .boxed()
 | |
|     })
 | |
| }
 | |
| 
 | |
| pub struct DohDispatcher(Mutex<Dispatcher>);
 | |
| 
 | |
| impl DohDispatcher {
 | |
|     fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
 | |
|         self.0.lock().unwrap()
 | |
|     }
 | |
| }
 | |
| 
 | |
| const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
 | |
| 
 | |
| /// The return code of doh_query means that there is no answer.
 | |
| pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
 | |
| /// The return code of doh_query means that query can't be sent.
 | |
| pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
 | |
| /// The return code of doh_query to indicate that the query timed out.
 | |
| pub const DOH_RESULT_TIMEOUT: ssize_t = -255;
 | |
| 
 | |
| /// The error log level.
 | |
| pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
 | |
| /// The warning log level.
 | |
| pub const DOH_LOG_LEVEL_WARN: u32 = 1;
 | |
| /// The info log level.
 | |
| pub const DOH_LOG_LEVEL_INFO: u32 = 2;
 | |
| /// The debug log level.
 | |
| pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
 | |
| /// The trace log level.
 | |
| pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
 | |
| 
 | |
| const DOH_PORT: u16 = 443;
 | |
| 
 | |
| fn level_from_u32(level: u32) -> Option<log::Level> {
 | |
|     use log::Level::*;
 | |
|     match level {
 | |
|         DOH_LOG_LEVEL_ERROR => Some(Error),
 | |
|         DOH_LOG_LEVEL_WARN => Some(Warn),
 | |
|         DOH_LOG_LEVEL_INFO => Some(Info),
 | |
|         DOH_LOG_LEVEL_DEBUG => Some(Debug),
 | |
|         DOH_LOG_LEVEL_TRACE => Some(Trace),
 | |
|         _ => None,
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Performs static initialization for android logger.
 | |
| /// If an invalid level is passed, defaults to logging errors only.
 | |
| /// If called more than once, it will have no effect on subsequent calls.
 | |
| #[no_mangle]
 | |
| pub extern "C" fn doh_init_logger(level: u32) {
 | |
|     let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
 | |
|     android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
 | |
| }
 | |
| 
 | |
| /// Set the log level.
 | |
| /// If an invalid level is passed, defaults to logging errors only.
 | |
| #[no_mangle]
 | |
| pub extern "C" fn doh_set_log_level(level: u32) {
 | |
|     let level_filter = level_from_u32(level)
 | |
|         .map(|level| level.to_level_filter())
 | |
|         .unwrap_or(log::LevelFilter::Error);
 | |
|     log::set_max_level(level_filter);
 | |
| }
 | |
| 
 | |
| /// Performs the initialization for the DoH engine.
 | |
| /// Creates and returns a DoH engine instance.
 | |
| #[no_mangle]
 | |
| pub extern "C" fn doh_dispatcher_new(
 | |
|     validation_fn: ValidationCallback,
 | |
|     tag_socket_fn: TagSocketCallback,
 | |
| ) -> *mut DohDispatcher {
 | |
|     match Dispatcher::new(
 | |
|         wrap_validation_callback(validation_fn),
 | |
|         wrap_tag_socket_callback(tag_socket_fn),
 | |
|     ) {
 | |
|         Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
 | |
|         Err(e) => {
 | |
|             error!("doh_dispatcher_new: failed: {:?}", e);
 | |
|             ptr::null_mut()
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Deletes a DoH engine created by doh_dispatcher_new().
 | |
| /// # Safety
 | |
| /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
 | |
| /// and not yet deleted by `doh_dispatcher_delete()`.
 | |
| #[no_mangle]
 | |
| pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
 | |
|     Box::from_raw(doh).lock().exit_handler()
 | |
| }
 | |
| 
 | |
| /// Probes and stores the DoH server with the given configurations.
 | |
| /// Use the negative errno-style codes as the return value to represent the result.
 | |
| /// # Safety
 | |
| /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
 | |
| /// and not yet deleted by `doh_dispatcher_delete()`.
 | |
| /// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
 | |
| #[no_mangle]
 | |
| pub unsafe extern "C" fn doh_net_new(
 | |
|     doh: &DohDispatcher,
 | |
|     net_id: uint32_t,
 | |
|     url: *const c_char,
 | |
|     domain: *const c_char,
 | |
|     ip_addr: *const c_char,
 | |
|     sk_mark: libc::uint32_t,
 | |
|     cert_path: *const c_char,
 | |
|     flags: &FeatureFlags,
 | |
| ) -> int32_t {
 | |
|     let (url, domain, ip_addr, cert_path) = match (
 | |
|         std::ffi::CStr::from_ptr(url).to_str(),
 | |
|         std::ffi::CStr::from_ptr(domain).to_str(),
 | |
|         std::ffi::CStr::from_ptr(ip_addr).to_str(),
 | |
|         std::ffi::CStr::from_ptr(cert_path).to_str(),
 | |
|     ) {
 | |
|         (Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
 | |
|             if domain.is_empty() {
 | |
|                 (url, None, ip_addr.to_string(), None)
 | |
|             } else if !cert_path.is_empty() {
 | |
|                 (url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
 | |
|             } else {
 | |
|                 (
 | |
|                     url,
 | |
|                     Some(domain.to_string()),
 | |
|                     ip_addr.to_string(),
 | |
|                     Some(SYSTEM_CERT_PATH.to_string()),
 | |
|                 )
 | |
|             }
 | |
|         }
 | |
|         _ => {
 | |
|             error!("bad input"); // Should not happen
 | |
|             return -libc::EINVAL;
 | |
|         }
 | |
|     };
 | |
| 
 | |
|     let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
 | |
|         (Ok(url), Ok(ip_addr)) => (url, ip_addr),
 | |
|         _ => {
 | |
|             error!("bad ip or url"); // Should not happen
 | |
|             return -libc::EINVAL;
 | |
|         }
 | |
|     };
 | |
|     let cmd = Command::Probe {
 | |
|         info: ServerInfo {
 | |
|             net_id,
 | |
|             url,
 | |
|             peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
 | |
|             domain,
 | |
|             sk_mark,
 | |
|             cert_path,
 | |
|             idle_timeout_ms: flags.idle_timeout_ms,
 | |
|             use_session_resumption: flags.use_session_resumption,
 | |
|         },
 | |
|         timeout: Duration::from_millis(flags.probe_timeout_ms),
 | |
|     };
 | |
|     if let Err(e) = doh.lock().send_cmd(cmd) {
 | |
|         error!("Failed to send the probe: {:?}", e);
 | |
|         return -libc::EPIPE;
 | |
|     }
 | |
|     0
 | |
| }
 | |
| 
 | |
| /// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
 | |
| /// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
 | |
| /// or the size of the answer.
 | |
| /// # Safety
 | |
| /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
 | |
| /// and not yet deleted by `doh_dispatcher_delete()`.
 | |
| /// `dns_query` must point to a buffer at least `dns_query_len` in size.
 | |
| /// `response` must point to a buffer at least `response_len` in size.
 | |
| #[no_mangle]
 | |
| pub unsafe extern "C" fn doh_query(
 | |
|     doh: &DohDispatcher,
 | |
|     net_id: uint32_t,
 | |
|     dns_query: *mut u8,
 | |
|     dns_query_len: size_t,
 | |
|     response: *mut u8,
 | |
|     response_len: size_t,
 | |
|     timeout_ms: uint64_t,
 | |
| ) -> ssize_t {
 | |
|     let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
 | |
| 
 | |
|     let (resp_tx, resp_rx) = oneshot::channel();
 | |
|     let t = Duration::from_millis(timeout_ms);
 | |
|     if let Some(expired_time) = BootTime::now().checked_add(t) {
 | |
|         let cmd = Command::Query {
 | |
|             net_id,
 | |
|             base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
 | |
|             expired_time,
 | |
|             resp: resp_tx,
 | |
|         };
 | |
| 
 | |
|         if let Err(e) = doh.lock().send_cmd(cmd) {
 | |
|             error!("Failed to send the query: {:?}", e);
 | |
|             return DOH_RESULT_CAN_NOT_SEND;
 | |
|         }
 | |
|     } else {
 | |
|         error!("Bad timeout parameter: {}", timeout_ms);
 | |
|         return DOH_RESULT_CAN_NOT_SEND;
 | |
|     }
 | |
| 
 | |
|     if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
 | |
|         let local = task::LocalSet::new();
 | |
|         match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
 | |
|             Ok(v) => match v {
 | |
|                 Ok(v) => match v {
 | |
|                     Response::Success { answer } => {
 | |
|                         if answer.len() > response_len || answer.len() > isize::MAX as usize {
 | |
|                             return DOH_RESULT_INTERNAL_ERROR;
 | |
|                         }
 | |
|                         let response = slice::from_raw_parts_mut(response, answer.len());
 | |
|                         response.copy_from_slice(&answer);
 | |
|                         answer.len() as ssize_t
 | |
|                     }
 | |
|                     rsp => {
 | |
|                         error!("Non-successful response: {:?}", rsp);
 | |
|                         DOH_RESULT_CAN_NOT_SEND
 | |
|                     }
 | |
|                 },
 | |
|                 Err(e) => {
 | |
|                     error!("no result {}", e);
 | |
|                     DOH_RESULT_CAN_NOT_SEND
 | |
|                 }
 | |
|             },
 | |
|             Err(e) => {
 | |
|                 error!("timeout: {}", e);
 | |
|                 DOH_RESULT_TIMEOUT
 | |
|             }
 | |
|         }
 | |
|     } else {
 | |
|         DOH_RESULT_CAN_NOT_SEND
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// Clears the DoH servers associated with the given |netid|.
 | |
| /// # Safety
 | |
| /// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
 | |
| /// and not yet deleted by `doh_dispatcher_delete()`.
 | |
| #[no_mangle]
 | |
| pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
 | |
|     if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
 | |
|         error!("Failed to send the query: {:?}", e);
 | |
|     }
 | |
| }
 | |
| 
 | |
#[cfg(test)]
mod tests {
    use super::*;

    const TEST_NET_ID: u32 = 50;
    const LOOPBACK_ADDR: &str = "127.0.0.1:443";
    const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";

    // NOTE(review): the callbacks below assert on a blocking task, and the wrappers under test
    // only log a failed task as a warning, so a failing assertion here may not fail the test.
    // Confirm, and consider recording results and asserting in the test body instead.

    /// Validation callback that expects a successful report for the test server.
    extern "C" fn success_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(success);
        // SAFETY: the wrapper under test passes pointers obtained from `CString`s.
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    /// Validation callback that expects a failed report for the test server.
    extern "C" fn fail_cb(
        net_id: uint32_t,
        success: bool,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert!(!success);
        // SAFETY: the wrapper under test passes pointers obtained from `CString`s.
        unsafe {
            assert_validation_info(net_id, ip_addr, host);
        }
    }

    // Checks that the reported network id, IP address and (empty) host match the test server.
    //
    // # Safety
    // `ip_addr`, `host` are null terminated strings
    unsafe fn assert_validation_info(
        net_id: uint32_t,
        ip_addr: *const c_char,
        host: *const c_char,
    ) {
        assert_eq!(net_id, TEST_NET_ID);
        let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
        let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
        assert_eq!(ip_addr, expected_addr.ip().to_string());
        let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
        // The test server has no domain, which is reported as an empty string.
        assert_eq!(host, "");
    }

    #[tokio::test]
    async fn wrap_validation_callback_converts_correctly() {
        let info = ServerInfo {
            net_id: TEST_NET_ID,
            url: Url::parse(LOCALHOST_URL).unwrap(),
            peer_addr: LOOPBACK_ADDR.parse().unwrap(),
            domain: None,
            sk_mark: 0,
            cert_path: None,
            idle_timeout_ms: 0,
            use_session_resumption: true,
        };

        wrap_validation_callback(success_cb)(&info, true).await;
        wrap_validation_callback(fail_cb)(&info, false).await;
    }

    /// Socket tagging callback that expects a valid file descriptor.
    extern "C" fn tag_socket_cb(raw_fd: RawFd) {
        // Descriptor 0 is valid too (e.g. when stdin is closed), so only reject negatives.
        assert!(raw_fd >= 0)
    }

    #[tokio::test]
    async fn wrap_tag_socket_callback_converts_correctly() {
        let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
        wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
    }
}
 |