Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
473 changes: 473 additions & 0 deletions vm/devices/net/net_consomme/consomme/src/dns_resolver/dns_tcp.rs

Large diffs are not rendered by default.

122 changes: 105 additions & 17 deletions vm/devices/net/net_consomme/consomme/src/dns_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,35 @@ use mesh_channel_core::Receiver;
use mesh_channel_core::Sender;
use smoltcp::wire::EthernetAddress;
use smoltcp::wire::IpAddress;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use crate::DropReason;

pub mod dns_tcp;

#[cfg(unix)]
mod unix;

#[cfg(windows)]
mod windows;

/// DNS backend implementation used as the default backend on Unix targets.
#[cfg(unix)]
type PlatformDnsBackend = unix::UnixDnsResolverBackend;

/// DNS backend implementation used as the default backend on Windows targets.
#[cfg(windows)]
type PlatformDnsBackend = windows::WindowsDnsResolverBackend;

/// Length in bytes of the fixed DNS message header. A query must be strictly
/// longer than this to be submitted over the UDP path.
static DNS_HEADER_SIZE: usize = 12;

/// Identifies which transport protocol carries a DNS query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DnsTransport {
    /// The query is carried over UDP.
    Udp,
    /// The query is carried over TCP.
    Tcp,
}

#[derive(Debug, Clone)]
pub struct DnsFlow {
pub src_addr: IpAddress,
Expand All @@ -27,6 +43,11 @@ pub struct DnsFlow {
pub dst_port: u16,
pub gateway_mac: EthernetAddress,
pub client_mac: EthernetAddress,
// Used by the glibc and Windows DNS backends. The musl resolver
// implementation handles TCP internally, so this field is not
// used in the musl backend.
#[allow(dead_code)]
pub transport: DnsTransport,
}

#[derive(Debug, Clone)]
Expand All @@ -42,16 +63,26 @@ pub struct DnsResponse {
pub response_data: Vec<u8>,
}

/// Backend trait for resolving DNS queries.
///
/// Both `dns_query` in [`DnsRequest`] and `response_data` in [`DnsResponse`]
/// carry **raw DNS message bytes** with no transport-layer framing (e.g. no
/// TCP 2-byte length prefix). Transport framing is the responsibility of the
/// caller (see [`dns_tcp::DnsTcpHandler`]).
pub(crate) trait DnsBackend: Send + Sync {
/// Start resolving `request`.
///
/// Nothing is returned directly; the resulting [`DnsResponse`] is expected
/// to be delivered through `response_sender`.
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>);
}

/// Front end that forwards DNS queries to a [`DnsBackend`] and caps how many
/// queries may be outstanding at once.
#[derive(Inspect)]
pub struct DnsResolver {
pub struct DnsResolver<B: DnsBackend = PlatformDnsBackend> {
#[inspect(skip)]
backend: Box<dyn DnsBackend>,
backend: Arc<B>,
/// Channel receiver for UDP DNS responses. Each call to
/// [`Self::submit_udp_query`] sends the response back through this
/// channel so that [`Self::poll_udp_response`] can retrieve it.
/// The TCP path uses its own per-connection channel instead.
#[inspect(skip)]
receiver: Receiver<DnsResponse>,
udp_receiver: Receiver<DnsResponse>,
/// Number of queries that have been handed to the backend and not yet
/// marked complete. UDP and TCP queries share this single counter.
pending_requests: usize,
/// Upper bound on `pending_requests`; queries submitted while the counter
/// is at this value are not forwarded to the backend.
max_pending_requests: usize,
}
Expand All @@ -68,10 +99,10 @@ impl DnsResolver {
pub fn new(max_pending_requests: usize) -> Result<Self, std::io::Error> {
use crate::dns_resolver::windows::WindowsDnsResolverBackend;

let receiver = Receiver::new();
let udp_receiver = Receiver::new();
Ok(Self {
backend: Box::new(WindowsDnsResolverBackend::new()?),
receiver,
backend: Arc::new(WindowsDnsResolverBackend::new()?),
udp_receiver,
pending_requests: 0,
max_pending_requests,
})
Expand All @@ -85,43 +116,100 @@ impl DnsResolver {
pub fn new(max_pending_requests: usize) -> Result<Self, std::io::Error> {
use crate::dns_resolver::unix::UnixDnsResolverBackend;

let receiver = Receiver::new();
let udp_receiver = Receiver::new();
Ok(Self {
backend: Box::new(UnixDnsResolverBackend::new()?),
receiver,
backend: Arc::new(UnixDnsResolverBackend::new()?),
udp_receiver,
pending_requests: 0,
max_pending_requests,
})
}
}

pub fn handle_dns(&mut self, request: &DnsRequest<'_>) -> Result<(), DropReason> {
if request.dns_query.len() <= DNS_HEADER_SIZE {
return Err(DropReason::Packet(smoltcp::wire::Error));
}

impl<B: DnsBackend> DnsResolver<B> {
// ── Shared ───────────────────────────────────────────────────────

/// Submit a DNS query to the backend with a caller-supplied response
/// sender. Returns `true` if accepted, `false` if the pending-request
/// limit has been reached.
///
/// An accepted query increments `pending_requests`. The matching decrement
/// happens in [`Self::poll_udp_response`] for the UDP path and in
/// [`Self::complete_tcp_query`] for the TCP path.
fn submit_query(
&mut self,
request: &DnsRequest<'_>,
response_sender: Sender<DnsResponse>,
) -> bool {
if self.pending_requests < self.max_pending_requests {
// Count the query as outstanding before handing it to the backend.
self.pending_requests += 1;
self.backend.query(request, self.receiver.sender());
self.backend.query(request, response_sender);
true
} else {
// At capacity: the query is not forwarded; only a rate-limited
// warning is emitted.
tracelimit::warn_ratelimited!(
current = self.pending_requests,
max = self.max_pending_requests,
"DNS request limit reached"
);
false
}
}

/// Check a UDP-borne DNS query for minimum length and pass it to the backend.
///
/// A query whose length does not exceed `DNS_HEADER_SIZE` is rejected with
/// [`DropReason::Packet`]. Otherwise it is submitted with a sender tied to
/// the UDP response channel, so the answer is later retrieved through
/// [`Self::poll_udp_response`].
///
/// `Ok(())` is returned even when the pending-request limit causes the query
/// to be turned away; that case is only logged by the shared submit path.
pub fn submit_udp_query(&mut self, request: &DnsRequest<'_>) -> Result<(), DropReason> {
    let has_payload = request.dns_query.len() > DNS_HEADER_SIZE;
    if !has_payload {
        return Err(DropReason::Packet(smoltcp::wire::Error));
    }

    let response_sender = self.udp_receiver.sender();
    // The accepted/rejected flag is intentionally not surfaced to the caller.
    let _accepted = self.submit_query(request, response_sender);
    Ok(())
}

pub fn poll_response(&mut self, cx: &mut Context<'_>) -> Poll<Option<DnsResponse>> {
match self.receiver.poll_recv(cx) {
/// Poll the UDP response channel for the next finished DNS response.
///
/// Only responses to queries submitted via [`Self::submit_udp_query`] arrive
/// here; TCP responses travel on their own per-connection channel and must
/// **not** be polled through this method.
///
/// Each response returned decrements the pending-request counter. A channel
/// error is reported as `Poll::Pending`, the same as "nothing ready yet".
pub fn poll_udp_response(&mut self, cx: &mut Context<'_>) -> Poll<Option<DnsResponse>> {
    if let Poll::Ready(Ok(response)) = self.udp_receiver.poll_recv(cx) {
        self.pending_requests -= 1;
        return Poll::Ready(Some(response));
    }
    Poll::Pending
}

/// Submit a DNS query with a caller-supplied response sender.
///
/// Returns `true` if the query was accepted, or `false` if the
/// pending-request limit has been reached.
///
/// The TCP handler calls this with its own [`Sender`] so responses
/// arrive on the per-connection channel rather than `udp_receiver`.
///
/// An accepted query increments the shared pending-request counter. That
/// counter is not decremented automatically for this path, so the caller
/// should pair each accepted query with [`Self::complete_tcp_query`].
pub fn submit_tcp_query(
&mut self,
request: &DnsRequest<'_>,
response_sender: Sender<DnsResponse>,
) -> bool {
// Pure delegation: validation of the query bytes is not performed here.
self.submit_query(request, response_sender)
}

/// Mark one TCP query as finished by lowering the pending-request counter.
///
/// Intended to be called once [`dns_tcp::DnsTcpHandler`] has consumed a TCP
/// response. The counter never drops below zero: calling this when nothing
/// is pending leaves it unchanged.
pub fn complete_tcp_query(&mut self) {
    if let Some(remaining) = self.pending_requests.checked_sub(1) {
        self.pending_requests = remaining;
    }
}

/// Build a resolver around an arbitrary backend (unit tests only).
///
/// The resolver starts with no pending requests, a fresh UDP response
/// channel, and `DEFAULT_MAX_PENDING_DNS_REQUESTS` as its request limit.
#[cfg(test)]
pub(crate) fn new_for_test(backend: Arc<B>) -> Self {
    Self {
        backend,
        udp_receiver: Receiver::new(),
        pending_requests: 0,
        max_pending_requests: DEFAULT_MAX_PENDING_DNS_REQUESTS,
    }
}
}

/// Internal DNS request structure used by backend implementations.
Expand Down
146 changes: 111 additions & 35 deletions vm/devices/net/net_consomme/consomme/src/dns_resolver/unix/glibc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ use super::DnsRequestInternal;
use super::DnsResponse;
use super::build_servfail_response;
use libc::c_int;
use libc::c_ulong;
use zerocopy::FromZeros;
use zerocopy::Immutable;
use zerocopy::IntoBytes;
use zerocopy::KnownLayout;

/// RES_USEVC option flag - use TCP (virtual circuit) instead of UDP.
///
/// The value must match `RES_USEVC` in the platform's `resolv.h`, which is
/// `0x00000008`. Note that `0x00000040` is `RES_RECURSE`, a flag that is part
/// of the resolver's default option set, so OR-ing that value into the
/// options would leave the resolver on UDP.
/// From glibc resolv/resolv.h: https://sourceware.org/git/?p=glibc.git;a=blob_plain;f=resolv/resolv.h;hb=HEAD
const RES_USEVC: c_ulong = 0x00000008;

/// Size of the `res_state` structure for different platforms.
/// These values were derived from including resolv.h and using sizeof(struct __res_state).
Expand All @@ -18,16 +27,44 @@ const RES_STATE_SIZE: usize = 552;
#[cfg(target_os = "linux")]
const RES_STATE_SIZE: usize = 568;

/// The prefix of the glibc `struct __res_state` that we need to access.
///
/// Only the leading fields are declared here; the remainder of the C
/// structure is treated as opaque bytes by the containing type.
///
/// This matches the layout defined in:
/// - glibc resolv/bits/types/res_state.h:
///   https://sourceware.org/git/?p=glibc.git;a=blob_plain;f=resolv/bits/types/res_state.h;hb=HEAD
/// - Apple libresolv resolv.h:
///   https://github.com/apple-oss-distributions/libresolv/blob/main/resolv.h
///
/// ```c
/// struct __res_state {
/// int retrans; /* retransmission time interval */
/// int retry; /* number of times to retransmit */
/// unsigned long options; /* option flags */
/// ...
/// }
/// ```
#[repr(C)]
#[derive(IntoBytes, Immutable, KnownLayout, FromZeros)]
struct ResStatePrefix {
/// Retransmission time interval (`retrans` in the C structure).
retrans: c_int,
/// Number of times to retransmit (`retry` in the C structure).
retry: c_int,
/// Option flags bitmask (`options` in the C structure).
options: c_ulong,
}

/// Wrapper around the glibc/macOS resolver state structure.
#[repr(C)]
#[derive(IntoBytes, Immutable, KnownLayout, FromZeros)]
pub struct ResState {
_data: [u8; RES_STATE_SIZE],
prefix: ResStatePrefix,
_rest: [u8; RES_STATE_SIZE - size_of::<ResStatePrefix>()],
}

impl ResState {
pub fn zeroed() -> Self {
Self {
_data: [0u8; RES_STATE_SIZE],
}
/// Set the options field in the resolver state.
///
/// This overwrites the whole bitmask. To add a single flag while keeping
/// the existing ones, pass `self.options() | FLAG`.
pub fn set_options(&mut self, options: c_ulong) {
self.prefix.options = options;
}

/// Get the options field from the resolver state.
///
/// Returns the raw option-flags bitmask as currently stored in the state.
pub fn options(&self) -> c_ulong {
self.prefix.options
}
}

Expand Down Expand Up @@ -59,7 +96,7 @@ unsafe extern "C" {
/// Handle a DNS query using reentrant resolver functions (macOS and GNU libc).
pub fn handle_dns_query(request: DnsRequestInternal) {
let mut answer = vec![0u8; 4096];
let mut state = ResState::zeroed();
let mut state = ResState::new_zeroed();

// SAFETY: res_ninit initializes the resolver state by reading /etc/resolv.conf.
// The state is properly sized and aligned.
Expand All @@ -74,6 +111,10 @@ pub fn handle_dns_query(request: DnsRequestInternal) {
return;
}

// Set RES_USEVC to force TCP for DNS queries.
if request.flow.transport == crate::dns_resolver::DnsTransport::Tcp {
state.set_options(state.options() | RES_USEVC);
}
// SAFETY: res_nsend is called with valid state, query buffer and answer buffer.
// All buffers are properly sized and aligned. The state was initialized above.
let answer_len = unsafe {
Expand Down Expand Up @@ -110,17 +151,9 @@ pub fn handle_dns_query(request: DnsRequestInternal) {
mod tests {
use super::*;

#[test]
fn test_res_ninit_and_res_nsend_callable() {
// Test that the reentrant resolver functions are callable
let mut state = ResState::zeroed();

// SAFETY: res_ninit initializes the resolver state
let init_result = unsafe { res_ninit(&mut state) };
assert_eq!(init_result, 0, "res_ninit() should succeed");

// Example DNS query buffer for google.com A record
let dns_query: Vec<u8> = vec![
/// Example DNS query buffer for google.com A record.
fn sample_dns_query() -> Vec<u8> {
vec![
0x12, 0x34, // Transaction ID
0x01, 0x00, // Flags: standard query
0x00, 0x01, // Questions: 1
Expand All @@ -131,23 +164,66 @@ mod tests {
0x00, // null terminator
0x00, 0x01, // Type: A
0x00, 0x01, // Class: IN
];

let mut answer = vec![0u8; 4096];

// SAFETY: res_nsend is called with valid state, query buffer and answer buffer.
let _answer_len = unsafe {
res_nsend(
&mut state,
dns_query.as_ptr(),
dns_query.len() as c_int,
answer.as_mut_ptr(),
answer.len() as c_int,
)
};

// Clean up
// SAFETY: res_nclose frees resources associated with the resolver state.
unsafe { res_nclose(&mut state) };
]
}

/// RAII wrapper for ResState that ensures proper cleanup.
///
/// Values are created with `new`, which runs `res_ninit`; dropping the
/// wrapper runs `res_nclose` on the contained state.
struct InitializedResState {
/// The underlying resolver state passed to the `res_n*` functions.
state: ResState,
}

impl InitializedResState {
    /// Create a zeroed resolver state and initialize it with `res_ninit`.
    ///
    /// # Panics
    ///
    /// Panics if `res_ninit` reports a non-zero status.
    fn new() -> Self {
        let mut state = ResState::new_zeroed();
        // SAFETY: `state` is a zero-initialized `ResState` that is exclusively
        // borrowed for the duration of the call.
        let rc = unsafe { res_ninit(&mut state) };
        assert_eq!(rc, 0, "res_ninit() should succeed");
        Self { state }
    }

    /// Pass `query` to `res_nsend` and return whatever length/status it reports.
    ///
    /// The answer bytes are written into a temporary 4096-byte buffer that is
    /// discarded when this method returns.
    fn send_query(&mut self, query: &[u8]) -> c_int {
        let mut answer = vec![0u8; 4096];
        let query_len = query.len() as c_int;
        let answer_capacity = answer.len() as c_int;
        // SAFETY: the state was initialized in `new`; `query` and `answer` are
        // live buffers and each pointer is paired with its own buffer length.
        unsafe {
            res_nsend(
                &mut self.state,
                query.as_ptr(),
                query_len,
                answer.as_mut_ptr(),
                answer_capacity,
            )
        }
    }
}

/// Closes the wrapped resolver state when the wrapper goes out of scope.
impl Drop for InitializedResState {
fn drop(&mut self) {
// SAFETY: res_nclose frees resources associated with the resolver state.
// Every value built through `InitializedResState::new` holds a state that
// `res_ninit` accepted, and `drop` runs at most once per value.
unsafe { res_nclose(&mut self.state) };
}
}

/// Smoke test: the reentrant resolver functions can be called end to end.
///
/// The value returned by the send is deliberately ignored, so the test does
/// not depend on whether the lookup itself succeeds.
#[test]
fn test_res_ninit_and_res_nsend_callable() {
    let mut resolver = InitializedResState::new();
    let query = sample_dns_query();
    let _answer_len = resolver.send_query(&query);
}

/// The options field can be read back and modified, and a query can still be
/// sent once `RES_USEVC` has been OR-ed into it.
#[test]
fn test_res_usevc_flag_for_tcp() {
    let mut resolver = InitializedResState::new();

    // Round-trip the options bitmask with RES_USEVC added.
    let with_usevc = resolver.state.options() | RES_USEVC;
    resolver.state.set_options(with_usevc);
    assert_ne!(
        resolver.state.options() & RES_USEVC,
        0,
        "RES_USEVC flag should be set"
    );

    // With RES_USEVC set, this should use TCP instead of UDP. The result is
    // ignored so the test does not depend on the lookup succeeding.
    let query = sample_dns_query();
    let _answer_len = resolver.send_query(&query);
}
}
Loading