android13/packages/modules/DnsResolver/doh/ffi.rs

403 lines
13 KiB
Rust

/*
* Copyright (C) 2021 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
//! C API for the DoH backend for the Android DnsResolver module.
use crate::boot_time::{timeout, BootTime, Duration};
use crate::dispatcher::{Command, Dispatcher, Response, ServerInfo};
use crate::network::{SocketTagger, ValidationReporter};
use futures::FutureExt;
use libc::{c_char, int32_t, size_t, ssize_t, uint32_t, uint64_t};
use log::{error, warn};
use std::ffi::CString;
use std::net::{IpAddr, SocketAddr};
use std::ops::DerefMut;
use std::os::unix::io::RawFd;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::{ptr, slice};
use tokio::runtime::Builder;
use tokio::sync::oneshot;
use tokio::task;
use url::Url;
pub type ValidationCallback =
extern "C" fn(net_id: uint32_t, success: bool, ip_addr: *const c_char, host: *const c_char);
pub type TagSocketCallback = extern "C" fn(sock: RawFd);
#[repr(C)]
pub struct FeatureFlags {
probe_timeout_ms: uint64_t,
idle_timeout_ms: uint64_t,
use_session_resumption: bool,
}
fn wrap_validation_callback(validation_fn: ValidationCallback) -> ValidationReporter {
Arc::new(move |info: &ServerInfo, success: bool| {
async move {
let (ip_addr, domain) = match (
CString::new(info.peer_addr.ip().to_string()),
CString::new(info.domain.clone().unwrap_or_default()),
) {
(Ok(ip_addr), Ok(domain)) => (ip_addr, domain),
_ => {
error!("validation_callback bad input");
return;
}
};
let netd_id = info.net_id;
task::spawn_blocking(move || {
validation_fn(netd_id, success, ip_addr.as_ptr(), domain.as_ptr())
})
.await
.unwrap_or_else(|e| warn!("Validation function task failed: {}", e))
}
.boxed()
})
}
fn wrap_tag_socket_callback(tag_socket_fn: TagSocketCallback) -> SocketTagger {
use std::os::unix::io::AsRawFd;
Arc::new(move |udp_socket: &std::net::UdpSocket| {
let fd = udp_socket.as_raw_fd();
async move {
task::spawn_blocking(move || {
tag_socket_fn(fd);
})
.await
.unwrap_or_else(|e| warn!("Socket tag function task failed: {}", e))
}
.boxed()
})
}
pub struct DohDispatcher(Mutex<Dispatcher>);
impl DohDispatcher {
fn lock(&self) -> impl DerefMut<Target = Dispatcher> + '_ {
self.0.lock().unwrap()
}
}
const SYSTEM_CERT_PATH: &str = "/system/etc/security/cacerts";
/// The return code of doh_query means that there is no answer.
pub const DOH_RESULT_INTERNAL_ERROR: ssize_t = -1;
/// The return code of doh_query means that query can't be sent.
pub const DOH_RESULT_CAN_NOT_SEND: ssize_t = -2;
/// The return code of doh_query to indicate that the query timed out.
pub const DOH_RESULT_TIMEOUT: ssize_t = -255;
/// The error log level.
pub const DOH_LOG_LEVEL_ERROR: u32 = 0;
/// The warning log level.
pub const DOH_LOG_LEVEL_WARN: u32 = 1;
/// The info log level.
pub const DOH_LOG_LEVEL_INFO: u32 = 2;
/// The debug log level.
pub const DOH_LOG_LEVEL_DEBUG: u32 = 3;
/// The trace log level.
pub const DOH_LOG_LEVEL_TRACE: u32 = 4;
const DOH_PORT: u16 = 443;
fn level_from_u32(level: u32) -> Option<log::Level> {
use log::Level::*;
match level {
DOH_LOG_LEVEL_ERROR => Some(Error),
DOH_LOG_LEVEL_WARN => Some(Warn),
DOH_LOG_LEVEL_INFO => Some(Info),
DOH_LOG_LEVEL_DEBUG => Some(Debug),
DOH_LOG_LEVEL_TRACE => Some(Trace),
_ => None,
}
}
/// Performs static initialization for android logger.
/// If an invalid level is passed, defaults to logging errors only.
/// If called more than once, it will have no effect on subsequent calls.
#[no_mangle]
pub extern "C" fn doh_init_logger(level: u32) {
let log_level = level_from_u32(level).unwrap_or(log::Level::Error);
android_logger::init_once(android_logger::Config::default().with_min_level(log_level));
}
/// Set the log level.
/// If an invalid level is passed, defaults to logging errors only.
#[no_mangle]
pub extern "C" fn doh_set_log_level(level: u32) {
let level_filter = level_from_u32(level)
.map(|level| level.to_level_filter())
.unwrap_or(log::LevelFilter::Error);
log::set_max_level(level_filter);
}
/// Performs the initialization for the DoH engine.
/// Creates and returns a DoH engine instance.
#[no_mangle]
pub extern "C" fn doh_dispatcher_new(
validation_fn: ValidationCallback,
tag_socket_fn: TagSocketCallback,
) -> *mut DohDispatcher {
match Dispatcher::new(
wrap_validation_callback(validation_fn),
wrap_tag_socket_callback(tag_socket_fn),
) {
Ok(c) => Box::into_raw(Box::new(DohDispatcher(Mutex::new(c)))),
Err(e) => {
error!("doh_dispatcher_new: failed: {:?}", e);
ptr::null_mut()
}
}
}
/// Deletes a DoH engine created by doh_dispatcher_new().
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub unsafe extern "C" fn doh_dispatcher_delete(doh: *mut DohDispatcher) {
Box::from_raw(doh).lock().exit_handler()
}
/// Probes and stores the DoH server with the given configurations.
/// Use the negative errno-style codes as the return value to represent the result.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
/// `url`, `domain`, `ip_addr`, `cert_path` are null terminated strings.
#[no_mangle]
pub unsafe extern "C" fn doh_net_new(
doh: &DohDispatcher,
net_id: uint32_t,
url: *const c_char,
domain: *const c_char,
ip_addr: *const c_char,
sk_mark: libc::uint32_t,
cert_path: *const c_char,
flags: &FeatureFlags,
) -> int32_t {
let (url, domain, ip_addr, cert_path) = match (
std::ffi::CStr::from_ptr(url).to_str(),
std::ffi::CStr::from_ptr(domain).to_str(),
std::ffi::CStr::from_ptr(ip_addr).to_str(),
std::ffi::CStr::from_ptr(cert_path).to_str(),
) {
(Ok(url), Ok(domain), Ok(ip_addr), Ok(cert_path)) => {
if domain.is_empty() {
(url, None, ip_addr.to_string(), None)
} else if !cert_path.is_empty() {
(url, Some(domain.to_string()), ip_addr.to_string(), Some(cert_path.to_string()))
} else {
(
url,
Some(domain.to_string()),
ip_addr.to_string(),
Some(SYSTEM_CERT_PATH.to_string()),
)
}
}
_ => {
error!("bad input"); // Should not happen
return -libc::EINVAL;
}
};
let (url, ip_addr) = match (Url::parse(url), IpAddr::from_str(&ip_addr)) {
(Ok(url), Ok(ip_addr)) => (url, ip_addr),
_ => {
error!("bad ip or url"); // Should not happen
return -libc::EINVAL;
}
};
let cmd = Command::Probe {
info: ServerInfo {
net_id,
url,
peer_addr: SocketAddr::new(ip_addr, DOH_PORT),
domain,
sk_mark,
cert_path,
idle_timeout_ms: flags.idle_timeout_ms,
use_session_resumption: flags.use_session_resumption,
},
timeout: Duration::from_millis(flags.probe_timeout_ms),
};
if let Err(e) = doh.lock().send_cmd(cmd) {
error!("Failed to send the probe: {:?}", e);
return -libc::EPIPE;
}
0
}
/// Sends a DNS query via the network associated to the given |net_id| and waits for the response.
/// The return code should be either one of the public constant DOH_RESULT_* to indicate the error
/// or the size of the answer.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
/// `dns_query` must point to a buffer at least `dns_query_len` in size.
/// `response` must point to a buffer at least `response_len` in size.
#[no_mangle]
pub unsafe extern "C" fn doh_query(
doh: &DohDispatcher,
net_id: uint32_t,
dns_query: *mut u8,
dns_query_len: size_t,
response: *mut u8,
response_len: size_t,
timeout_ms: uint64_t,
) -> ssize_t {
let q = slice::from_raw_parts_mut(dns_query, dns_query_len);
let (resp_tx, resp_rx) = oneshot::channel();
let t = Duration::from_millis(timeout_ms);
if let Some(expired_time) = BootTime::now().checked_add(t) {
let cmd = Command::Query {
net_id,
base64_query: base64::encode_config(q, base64::URL_SAFE_NO_PAD),
expired_time,
resp: resp_tx,
};
if let Err(e) = doh.lock().send_cmd(cmd) {
error!("Failed to send the query: {:?}", e);
return DOH_RESULT_CAN_NOT_SEND;
}
} else {
error!("Bad timeout parameter: {}", timeout_ms);
return DOH_RESULT_CAN_NOT_SEND;
}
if let Ok(rt) = Builder::new_current_thread().enable_all().build() {
let local = task::LocalSet::new();
match local.block_on(&rt, async { timeout(t, resp_rx).await }) {
Ok(v) => match v {
Ok(v) => match v {
Response::Success { answer } => {
if answer.len() > response_len || answer.len() > isize::MAX as usize {
return DOH_RESULT_INTERNAL_ERROR;
}
let response = slice::from_raw_parts_mut(response, answer.len());
response.copy_from_slice(&answer);
answer.len() as ssize_t
}
rsp => {
error!("Non-successful response: {:?}", rsp);
DOH_RESULT_CAN_NOT_SEND
}
},
Err(e) => {
error!("no result {}", e);
DOH_RESULT_CAN_NOT_SEND
}
},
Err(e) => {
error!("timeout: {}", e);
DOH_RESULT_TIMEOUT
}
}
} else {
DOH_RESULT_CAN_NOT_SEND
}
}
/// Clears the DoH servers associated with the given |netid|.
/// # Safety
/// `doh` must be a non-null pointer previously created by `doh_dispatcher_new()`
/// and not yet deleted by `doh_dispatcher_delete()`.
#[no_mangle]
pub extern "C" fn doh_net_delete(doh: &DohDispatcher, net_id: uint32_t) {
if let Err(e) = doh.lock().send_cmd(Command::Clear { net_id }) {
error!("Failed to send the query: {:?}", e);
}
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_NET_ID: u32 = 50;
const LOOPBACK_ADDR: &str = "127.0.0.1:443";
const LOCALHOST_URL: &str = "https://mylocal.com/dns-query";
extern "C" fn success_cb(
net_id: uint32_t,
success: bool,
ip_addr: *const c_char,
host: *const c_char,
) {
assert!(success);
unsafe {
assert_validation_info(net_id, ip_addr, host);
}
}
extern "C" fn fail_cb(
net_id: uint32_t,
success: bool,
ip_addr: *const c_char,
host: *const c_char,
) {
assert!(!success);
unsafe {
assert_validation_info(net_id, ip_addr, host);
}
}
// # Safety
// `ip_addr`, `host` are null terminated strings
unsafe fn assert_validation_info(
net_id: uint32_t,
ip_addr: *const c_char,
host: *const c_char,
) {
assert_eq!(net_id, TEST_NET_ID);
let ip_addr = std::ffi::CStr::from_ptr(ip_addr).to_str().unwrap();
let expected_addr: SocketAddr = LOOPBACK_ADDR.parse().unwrap();
assert_eq!(ip_addr, expected_addr.ip().to_string());
let host = std::ffi::CStr::from_ptr(host).to_str().unwrap();
assert_eq!(host, "");
}
#[tokio::test]
async fn wrap_validation_callback_converts_correctly() {
let info = ServerInfo {
net_id: TEST_NET_ID,
url: Url::parse(LOCALHOST_URL).unwrap(),
peer_addr: LOOPBACK_ADDR.parse().unwrap(),
domain: None,
sk_mark: 0,
cert_path: None,
idle_timeout_ms: 0,
use_session_resumption: true,
};
wrap_validation_callback(success_cb)(&info, true).await;
wrap_validation_callback(fail_cb)(&info, false).await;
}
extern "C" fn tag_socket_cb(raw_fd: RawFd) {
assert!(raw_fd > 0)
}
#[tokio::test]
async fn wrap_tag_socket_callback_converts_correctly() {
let sock = std::net::UdpSocket::bind("127.0.0.1:0").unwrap();
wrap_tag_socket_callback(tag_socket_cb)(&sock).await;
}
}