Skip to content

Commit

Permalink
Add sanity checks in DNS message decoding (#169)
Browse files Browse the repository at this point in the history
And: 
* fixed a few new cargo clippy warnings.
* defined DnsIncoming::HEADER_LEN const.
  • Loading branch information
keepsimple1 authored Feb 3, 2024
1 parent b503309 commit e5bda10
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 7 deletions.
74 changes: 70 additions & 4 deletions src/dns_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,8 @@ pub(crate) struct DnsIncoming {
}

impl DnsIncoming {
const HEADER_LEN: usize = 12;

pub(crate) fn new(data: Vec<u8>) -> Result<Self> {
let mut incoming = Self {
offset: 0,
Expand Down Expand Up @@ -870,7 +872,7 @@ impl DnsIncoming {
}

fn read_header(&mut self) -> Result<()> {
if self.data.len() < 12 {
if self.data.len() < Self::HEADER_LEN {
return Err(Error::Msg(format!(
"DNS incoming: header is too short: {} bytes",
self.data.len()
Expand All @@ -885,7 +887,7 @@ impl DnsIncoming {
self.num_authorities = u16_from_be_slice(&data[8..10]);
self.num_additionals = u16_from_be_slice(&data[10..12]);

self.offset = 12;
self.offset = Self::HEADER_LEN;

debug!(
"read_header: id {}, {} questions {} answers {} authorities {} additionals",
Expand Down Expand Up @@ -954,6 +956,16 @@ impl DnsIncoming {
for _ in 0..n {
let name = self.read_name()?;
let slice = &self.data[self.offset..];

// Muse have at least TYPE, CLASS, TTL, RDLENGTH fields: 10 bytes.
if slice.len() < 10 {
return Err(Error::Msg(format!(
"read_others: RR '{}' is too short after name: {} bytes",
&name,
slice.len()
)));
}

let ty = u16_from_be_slice(&slice[..2]);
let class = u16_from_be_slice(&slice[2..4]);
let ttl = u32_from_be_slice(&slice[4..8]);
Expand Down Expand Up @@ -1113,7 +1125,18 @@ impl DnsIncoming {
0x00 => {
// regular utf8 string with length
offset += 1;
name += str::from_utf8(&data[offset..(offset + length as usize)])
let ending = offset + length as usize;

// Never read beyond the whole data length.
if ending > data.len() {
return Err(Error::Msg(format!(
"read_name: ending {} exceeds data length {}",
ending,
data.len()
)));
}

name += str::from_utf8(&data[offset..ending])
.map_err(|e| Error::Msg(format!("read_name: from_utf8: {}", e)))?;
name += ".";
offset += length as usize;
Expand Down Expand Up @@ -1176,7 +1199,10 @@ fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {

#[cfg(test)]
mod tests {
use super::{DnsIncoming, DnsOutgoing, FLAGS_QR_QUERY, TYPE_PTR};
use super::{
DnsIncoming, DnsOutgoing, DnsSrv, CLASS_IN, CLASS_UNIQUE, FLAGS_QR_QUERY,
FLAGS_QR_RESPONSE, TYPE_PTR,
};

#[test]
fn test_read_name_invalid_length() {
Expand All @@ -1186,7 +1212,9 @@ mod tests {
let data = out.to_packet_data();

// construct invalid data.
let max_len = data.len() as u8;
let mut data_with_invalid_name_length = data.clone();
let mut data_with_larger_name_length = data.clone();
let name_length_offset = 12;

// 0x9 is the length of `name`
Expand All @@ -1203,5 +1231,43 @@ mod tests {
if let Err(e) = invalid {
println!("error: {}", e);
}

// Another error case: `length`` is larger than the actual string length.
data_with_larger_name_length[name_length_offset] = max_len + 1;
let invalid = DnsIncoming::new(data_with_larger_name_length);
assert!(invalid.is_err());
if let Err(e) = invalid {
println!("error: {}", e);
}
}

/// Tests DnsIncoming::read_others()
#[test]
fn test_rr_too_short_after_name() {
let name = "test_rr_too_short._udp.local.";
let mut response = DnsOutgoing::new(FLAGS_QR_RESPONSE);
response.add_additional_answer(Box::new(DnsSrv::new(
name,
CLASS_IN | CLASS_UNIQUE,
1,
1,
1,
9000,
"instance1".to_string(),
)));
let data = response.to_packet_data();
let mut data_too_short = data.clone();

// verify the original data is good.
let incoming = DnsIncoming::new(data);
assert!(incoming.is_ok());

// verify that truncated data will cause an error.
data_too_short.truncate(DnsIncoming::HEADER_LEN + name.len() + 2);
let invalid = DnsIncoming::new(data_too_short);
assert!(invalid.is_err());
if let Err(e) = invalid {
println!("error: {}", e);
}
}
}
8 changes: 5 additions & 3 deletions src/service_daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,8 @@ impl Zeroconf {
return false;
}

buf.truncate(sz); // reduce potential processing errors

match DnsIncoming::new(buf) {
Ok(msg) => {
if msg.is_query() {
Expand Down Expand Up @@ -1532,7 +1534,7 @@ impl Zeroconf {

// resolve SRV record
if let Some(records) = self.cache.srv.get(fullname) {
if let Some(answer) = records.get(0) {
if let Some(answer) = records.first() {
if let Some(dns_srv) = answer.any().downcast_ref::<DnsSrv>() {
info.set_hostname(dns_srv.host.clone());
info.set_port(dns_srv.port);
Expand All @@ -1542,7 +1544,7 @@ impl Zeroconf {

// resolve TXT record
if let Some(records) = self.cache.txt.get(fullname) {
if let Some(record) = records.get(0) {
if let Some(record) = records.first() {
if let Some(dns_txt) = record.any().downcast_ref::<DnsTxt>() {
info.set_properties_from_txt(&dns_txt.text);
}
Expand Down Expand Up @@ -2001,7 +2003,7 @@ impl DnsCache {
self.srv
.iter()
.filter_map(|(instance, srv_list)| {
if let Some(item) = srv_list.get(0) {
if let Some(item) = srv_list.first() {
if let Some(dns_srv) = item.any().downcast_ref::<DnsSrv>() {
if dns_srv.host == host {
return Some(instance.clone());
Expand Down

0 comments on commit e5bda10

Please sign in to comment.