diff --git a/Cargo.lock b/Cargo.lock index 866f8e719..c542522c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,6 +484,12 @@ version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" +[[package]] +name = "hex-literal" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" + [[package]] name = "home" version = "0.5.5" @@ -1441,6 +1447,7 @@ dependencies = [ "derive_more", "dns-lookup", "etcetera", + "hex-literal", "humantime", "indexmap", "itertools", diff --git a/Cargo.toml b/Cargo.toml index 89c5fda38..23732fe3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ windows-sys = { version = "0.48.0", features = [ ] } [dev-dependencies] +hex-literal = "0.4.1" rand = "0.8.5" test-case = "3.2.1" diff --git a/src/tracing/net/ipv4.rs b/src/tracing/net/ipv4.rs index 4ae744fe6..21b1e32a3 100644 --- a/src/tracing/net/ipv4.rs +++ b/src/tracing/net/ipv4.rs @@ -8,12 +8,14 @@ use crate::tracing::packet::checksum::{icmp_ipv4_checksum, udp_ipv4_checksum}; use crate::tracing::packet::icmpv4::destination_unreachable::DestinationUnreachablePacket; use crate::tracing::packet::icmpv4::echo_reply::EchoReplyPacket; use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; +use crate::tracing::packet::icmpv4::extension::extension_header::ExtensionHeader; +use crate::tracing::packet::icmpv4::extension::extension_structure::ExtensionStructure; use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpPacket, IcmpType}; use crate::tracing::packet::ipv4::Ipv4Packet; use crate::tracing::packet::tcp::TcpPacket; use crate::tracing::packet::udp::UdpPacket; -use crate::tracing::packet::IpProtocol; +use crate::tracing::packet::{fmt_payload, IpProtocol}; use crate::tracing::probe::{ ProbeResponse, ProbeResponseData, ProbeResponseSeq, ProbeResponseSeqIcmp, ProbeResponseSeqTcp, ProbeResponseSeqUdp, @@ -332,14 +334,42 @@ fn extract_probe_resp( Ok(match icmp_v4.get_icmp_type() { IcmpType::TimeExceeded => { let packet = TimeExceededPacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; + let payload = packet.payload(); + let extension = packet.extension(); + if let Some(ext) = extension { + let extensions = ExtensionStructure::new_view(ext).req()?; + let ext_header = ExtensionHeader::new_view(extensions.header()).req()?; + println!( + "extension header: version={}, checksum={}", + ext_header.get_version(), + ext_header.get_checksum() + ); + for obj in extensions.iter() { + println!("extension object: length={:?}, class_num={:?}, class_subtype={:?}, payload={}", obj.get_length(), obj.get_class_num(), obj.get_class_subtype(), fmt_payload(obj.payload())); + } + } + let resp_seq = extract_probe_resp_seq(payload, protocol)?; Some(ProbeResponse::TimeExceeded(ProbeResponseData::new( recv, src, resp_seq, ))) } IcmpType::DestinationUnreachable => { let packet = DestinationUnreachablePacket::new_view(icmp_v4.packet()).req()?; - let resp_seq = extract_probe_resp_seq(packet.payload(), protocol)?; + let payload = packet.payload(); + let extension = packet.extension(); + if let Some(ext) = extension { + let extensions = ExtensionStructure::new_view(ext).req()?; + let ext_header = ExtensionHeader::new_view(extensions.header()).req()?; + println!( + "extension header: version={}, checksum={}", + ext_header.get_version(), + ext_header.get_checksum() + ); + for obj in extensions.iter() { + println!("extension object: length={:?}, class_num={:?}, class_subtype={:?}, payload={}", obj.get_length(), obj.get_class_num(), obj.get_class_subtype(), fmt_payload(obj.payload())); + } + } + let resp_seq = extract_probe_resp_seq(payload, protocol)?; Some(ProbeResponse::DestinationUnreachable( ProbeResponseData::new(recv, src, resp_seq), )) diff --git a/src/tracing/packet.rs b/src/tracing/packet.rs index ab9f56160..1177dfdf4 100644 --- a/src/tracing/packet.rs +++ b/src/tracing/packet.rs @@ -21,7 +21,8 @@ pub mod udp; /// `TCP` packets. pub mod tcp; -fn fmt_payload(bytes: &[u8]) -> String { +#[must_use] +pub fn fmt_payload(bytes: &[u8]) -> String { use itertools::Itertools as _; format!("{:02x}", bytes.iter().format(" ")) } diff --git a/src/tracing/packet/icmpv4.rs b/src/tracing/packet/icmpv4.rs index 577e41ec0..4cd2f7264 100644 --- a/src/tracing/packet/icmpv4.rs +++ b/src/tracing/packet/icmpv4.rs @@ -631,12 +631,14 @@ pub mod echo_reply { pub mod time_exceeded { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmpv4::extension::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; + const LENGTH_OFFSET: usize = 5; /// Represents an ICMP `TimeExceeded` packet. /// @@ -689,6 +691,11 @@ pub mod time_exceeded { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + #[must_use] + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) + } + pub fn set_icmp_type(&mut self, val: IcmpType) { *self.buf.write(TYPE_OFFSET) = val.id(); } @@ -701,6 +708,10 @@ pub mod time_exceeded { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; + } + pub fn set_payload(&mut self, vals: &[u8]) { let current_offset = Self::minimum_packet_size(); self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] @@ -714,7 +725,20 @@ pub mod time_exceeded { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -724,6 +748,7 @@ pub mod time_exceeded { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) + .field("length", &self.get_length()) .field("payload", &fmt_payload(self.payload())) .finish() } @@ -799,13 +824,15 @@ pub mod time_exceeded { pub mod destination_unreachable { use crate::tracing::packet::buffer::Buffer; use crate::tracing::packet::fmt_payload; + use crate::tracing::packet::icmpv4::extension::split; use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; use std::fmt::{Debug, Formatter}; const TYPE_OFFSET: usize = 0; const CODE_OFFSET: usize = 1; const CHECKSUM_OFFSET: usize = 2; - const UNUSED_OFFSET: usize = 4; + const LENGTH_OFFSET: usize = 5; + // const UNUSED_OFFSET: usize = 4; // TODO const NEXT_HOP_MTU_OFFSET: usize = 6; /// Represents an ICMP `DestinationUnreachable` packet. @@ -859,9 +886,15 @@ pub mod destination_unreachable { u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) } + // TODO + // #[must_use] + // pub fn get_unused(&self) -> u16 { + // u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + // } + #[must_use] - pub fn get_unused(&self) -> u16 { - u16::from_be_bytes(self.buf.get_bytes(UNUSED_OFFSET)) + pub fn get_length(&self) -> u8 { + self.buf.read(LENGTH_OFFSET) } #[must_use] @@ -881,8 +914,13 @@ pub mod destination_unreachable { self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); } - pub fn set_unused(&mut self, val: u16) { - self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + // TODO + // pub fn set_unused(&mut self, val: u16) { + // self.buf.set_bytes(UNUSED_OFFSET, val.to_be_bytes()); + // } + + pub fn set_length(&mut self, val: u8) { + *self.buf.write(LENGTH_OFFSET) = val; } pub fn set_next_hop_mtu(&mut self, val: u16) { @@ -902,7 +940,20 @@ pub mod destination_unreachable { #[must_use] pub fn payload(&self) -> &[u8] { - &self.buf.as_slice()[Self::minimum_packet_size()..] + let (payload, _) = self.split_payload_extension(); + payload + } + + #[must_use] + pub fn extension(&self) -> Option<&[u8]> { + let (_, extension) = self.split_payload_extension(); + extension + } + + fn split_payload_extension(&self) -> (&[u8], Option<&[u8]>) { + let rfc4884_length = self.get_length(); + let icmp_payload = &self.buf.as_slice()[Self::minimum_packet_size()..]; + split(rfc4884_length, icmp_payload) } } @@ -912,7 +963,7 @@ pub mod destination_unreachable { .field("icmp_type", &self.get_icmp_type()) .field("icmp_code", &self.get_icmp_code()) .field("checksum", &self.get_checksum()) - .field("unused", &self.get_unused()) + .field("length", &self.get_length()) .field("next_hop_mtu", &self.get_next_hop_mtu()) .field("payload", &fmt_payload(self.payload())) .finish() @@ -985,3 +1036,644 @@ pub mod destination_unreachable { } } } + +pub mod extension { + + pub mod extension_structure { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::icmpv4::extension::extension_object::ExtensionObject; + + /// Represents an ICMP `ExtensionStructure` pseudo object. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionStructure<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionStructure<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + // TODO return Option here or &[u8]? + #[must_use] + pub fn header(&self) -> &[u8] { + &self.buf.as_slice()[..Self::minimum_packet_size()] + } + + /// An iterator of Extension Objects contained within this `ExtensionStructure`. + #[must_use] + pub fn iter(&self) -> ExtensionObjectIter<'_> { + ExtensionObjectIter::new(&self.buf) + } + } + + pub struct ExtensionObjectIter<'a> { + buf: &'a Buffer<'a>, + offset: usize, + } + + impl<'a> ExtensionObjectIter<'a> { + #[must_use] + pub fn new(buf: &'a Buffer<'_>) -> Self { + Self { + buf, + offset: ExtensionStructure::minimum_packet_size(), + } + } + } + + impl<'a> Iterator for ExtensionObjectIter<'a> { + type Item = ExtensionObject<'a>; // TODO or return &[u8]? + + fn next(&mut self) -> Option { + if self.offset >= self.buf.as_slice().len() { + None + } else { + // TODO check for edge cases here + ExtensionObject::new_view(&self.buf.as_slice()[self.offset..]).map(|obj| { + self.offset += usize::from(obj.get_length()); + obj + }) + } + } + } + + #[cfg(test)] + mod tests { + use super::*; + use crate::tracing::packet::icmpv4::extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmpv4::extension::extension_object::{ + ClassNum, ClassSubType, + }; + + #[test] + fn test_header() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, header.get_version()); + assert_eq!(0x993A, header.get_checksum()); + } + + #[test] + fn test_object_iterator() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extensions = ExtensionStructure::new_view(&buf).unwrap(); + let mut object_iter = extensions.iter(); + let object = object_iter.next().unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + // assert_eq!(None, object_iter.next()); + } + } + } + + pub mod extension_header { + use crate::tracing::packet::buffer::Buffer; + use std::fmt::{Debug, Formatter}; + + const VERSION_OFFSET: usize = 0; + const CHECKSUM_OFFSET: usize = 2; + + /// Represents an ICMP `ExtensionHeader`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionHeader<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionHeader<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + #[must_use] + pub fn get_version(&self) -> u8 { + (self.buf.read(VERSION_OFFSET) & 0xf0) >> 4 + } + + #[must_use] + pub fn get_checksum(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(CHECKSUM_OFFSET)) + } + + pub fn set_version(&mut self, val: u8) { + *self.buf.write(VERSION_OFFSET) = + (self.buf.read(VERSION_OFFSET) & 0xf) | ((val & 0xf) << 4); + } + + pub fn set_checksum(&mut self, val: u16) { + self.buf.set_bytes(CHECKSUM_OFFSET, val.to_be_bytes()); + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + } + + impl Debug for ExtensionHeader<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionHeader") + .field("version", &self.get_version()) + .field("checksum", &self.get_checksum()) + // .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_version() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_version(0); + assert_eq!(0, extension.get_version()); + assert_eq!([0x00], extension.packet()[0..1]); + extension.set_version(2); + assert_eq!(2, extension.get_version()); + assert_eq!([0x20], extension.packet()[0..1]); + extension.set_version(15); + assert_eq!(15, extension.get_version()); + assert_eq!([0xF0], extension.packet()[0..1]); + } + + #[test] + fn test_checksum() { + let mut buf = [0_u8; ExtensionHeader::minimum_packet_size()]; + let mut extension = ExtensionHeader::new(&mut buf).unwrap(); + extension.set_checksum(0); + assert_eq!(0, extension.get_checksum()); + assert_eq!([0x00, 0x00], extension.packet()[2..=3]); + extension.set_checksum(1999); + assert_eq!(1999, extension.get_checksum()); + assert_eq!([0x07, 0xCF], extension.packet()[2..=3]); + extension.set_checksum(39226); + assert_eq!(39226, extension.get_checksum()); + assert_eq!([0x99, 0x3A], extension.packet()[2..=3]); + extension.set_checksum(u16::MAX); + assert_eq!(u16::MAX, extension.get_checksum()); + assert_eq!([0xFF, 0xFF], extension.packet()[2..=3]); + } + + #[test] + fn test_extension_header_view() { + let buf = [ + 0x20, 0x00, 0x99, 0x3a, 0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01, + ]; + let extension = ExtensionHeader::new_view(&buf).unwrap(); + assert_eq!(2, extension.get_version()); + assert_eq!(0x993A, extension.get_checksum()); + } + } + } + + pub mod extension_object { + use crate::tracing::packet::buffer::Buffer; + use crate::tracing::packet::fmt_payload; + use std::fmt::{Debug, Formatter}; + + /// The ICMP Extension Object Class Num. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub enum ClassNum { + MultiProtocolLabelSwitchingLabelStack, + InterfaceInformationObject, + InterfaceIdentificationObject, + ExtendedInformation, + Other(u8), + } + + impl ClassNum { + #[must_use] + pub fn id(&self) -> u8 { + match self { + Self::MultiProtocolLabelSwitchingLabelStack => 1, + Self::InterfaceInformationObject => 2, + Self::InterfaceIdentificationObject => 3, + Self::ExtendedInformation => 4, + Self::Other(id) => *id, + } + } + } + + impl From for ClassNum { + fn from(val: u8) -> Self { + match val { + 1 => Self::MultiProtocolLabelSwitchingLabelStack, + 2 => Self::InterfaceInformationObject, + 3 => Self::InterfaceIdentificationObject, + 4 => Self::ExtendedInformation, + id => Self::Other(id), + } + } + } + + /// The ICMP Extension Object Class Sub-type. + #[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] + pub struct ClassSubType(pub u8); + + impl From for ClassSubType { + fn from(val: u8) -> Self { + Self(val) + } + } + + const LENGTH_OFFSET: usize = 0; + const CLASS_NUM_OFFSET: usize = 2; + const CLASS_SUBTYPE_OFFSET: usize = 3; + + /// Represents an ICMP `ExtensionObject`. + /// + /// The internal representation is held in network byte order (big-endian) and all accessor + /// methods take and return data in host byte order, converting as necessary for the given + /// architecture. + pub struct ExtensionObject<'a> { + buf: Buffer<'a>, + } + + impl<'a> ExtensionObject<'a> { + pub fn new(packet: &'a mut [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Mutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub fn new_view(packet: &'a [u8]) -> Option> { + if packet.len() >= Self::minimum_packet_size() { + Some(Self { + buf: Buffer::Immutable(packet), + }) + } else { + None + } + } + + #[must_use] + pub const fn minimum_packet_size() -> usize { + 4 + } + + pub fn set_length(&mut self, val: u16) { + self.buf.set_bytes(LENGTH_OFFSET, val.to_be_bytes()); + } + + pub fn set_class_num(&mut self, val: ClassNum) { + *self.buf.write(CLASS_NUM_OFFSET) = val.id(); + } + + pub fn set_class_subtype(&mut self, val: ClassSubType) { + *self.buf.write(CLASS_SUBTYPE_OFFSET) = val.0; + } + + pub fn set_payload(&mut self, vals: &[u8]) { + let current_offset = Self::minimum_packet_size(); + self.buf.as_slice_mut()[current_offset..current_offset + vals.len()] + .copy_from_slice(vals); + } + + #[must_use] + pub fn get_length(&self) -> u16 { + u16::from_be_bytes(self.buf.get_bytes(LENGTH_OFFSET)) + } + + #[must_use] + pub fn get_class_num(&self) -> ClassNum { + ClassNum::from(self.buf.read(CLASS_NUM_OFFSET)) + } + + #[must_use] + pub fn get_class_subtype(&self) -> ClassSubType { + ClassSubType::from(self.buf.read(CLASS_SUBTYPE_OFFSET)) + } + + #[must_use] + pub fn packet(&self) -> &[u8] { + self.buf.as_slice() + } + + // TODO should use the length here to get the payload for this object only + #[must_use] + pub fn payload(&self) -> &[u8] { + &self.buf.as_slice()[Self::minimum_packet_size()..] + } + } + + impl Debug for ExtensionObject<'_> { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExtensionObject") + .field("length", &self.get_length()) + .field("class_num", &self.get_class_num()) + .field("class_subtype", &self.get_class_subtype()) + .field("payload", &fmt_payload(self.payload())) + .finish() + } + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_length() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_length(0); + assert_eq!(0, extension.get_length()); + assert_eq!([0x00, 0x00], extension.packet()[0..=1]); + extension.set_length(8); + assert_eq!(8, extension.get_length()); + assert_eq!([0x00, 0x08], extension.packet()[0..=1]); + extension.set_length(u16::MAX); + assert_eq!(u16::MAX, extension.get_length()); + assert_eq!([0xFF, 0xFF], extension.packet()[0..=1]); + } + + #[test] + fn test_class_num() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_num(ClassNum::MultiProtocolLabelSwitchingLabelStack); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension.get_class_num() + ); + assert_eq!([0x01], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceInformationObject); + assert_eq!( + ClassNum::InterfaceInformationObject, + extension.get_class_num() + ); + assert_eq!([0x02], extension.packet()[2..3]); + extension.set_class_num(ClassNum::InterfaceIdentificationObject); + assert_eq!( + ClassNum::InterfaceIdentificationObject, + extension.get_class_num() + ); + assert_eq!([0x03], extension.packet()[2..3]); + extension.set_class_num(ClassNum::ExtendedInformation); + assert_eq!(ClassNum::ExtendedInformation, extension.get_class_num()); + assert_eq!([0x04], extension.packet()[2..3]); + extension.set_class_num(ClassNum::Other(255)); + assert_eq!(ClassNum::Other(255), extension.get_class_num()); + assert_eq!([0xFF], extension.packet()[2..3]); + } + + #[test] + fn test_class_subtype() { + let mut buf = [0_u8; ExtensionObject::minimum_packet_size()]; + let mut extension = ExtensionObject::new(&mut buf).unwrap(); + extension.set_class_subtype(ClassSubType(0)); + assert_eq!(ClassSubType(0), extension.get_class_subtype()); + assert_eq!([0x00], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(1)); + assert_eq!(ClassSubType(1), extension.get_class_subtype()); + assert_eq!([0x01], extension.packet()[3..4]); + extension.set_class_subtype(ClassSubType(255)); + assert_eq!(ClassSubType(255), extension.get_class_subtype()); + assert_eq!([0xff], extension.packet()[3..4]); + } + + #[test] + fn test_extension_header_view() { + let buf = [0x00, 0x08, 0x01, 0x01, 0x04, 0xbb, 0x41, 0x01]; + let object = ExtensionObject::new_view(&buf).unwrap(); + assert_eq!(8, object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + object.get_class_num() + ); + assert_eq!(ClassSubType(1), object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], object.payload()); + } + } + } + + const ICMP_ORIG_DATAGRAM_MIN_LENGTH: usize = 128; + + /// Separate an ICMP payload from ICMP extensions as defined in rfc4884. + /// + /// Applies to `TimeExceeded` and `DestinationUnreachable` ICMP messages only. + #[must_use] + pub fn split(rfc4884_length: u8, icmp_payload: &[u8]) -> (&[u8], Option<&[u8]>) { + let orig_datagram_length = usize::from(rfc4884_length * 4); + + // TODO what to do if the claimed orig_datagram_length is bigger than the actual payload? + // we could truncate or we can err or we could return empty? + if orig_datagram_length > icmp_payload.len() { + return (&[], None); + } + + if orig_datagram_length > 0 { + // compliant message case + if icmp_payload.len() > orig_datagram_length { + // extension case (untested): the icmp_payload is longer than the orig_datagram and so whatever remains must be an extension + let extension_len = icmp_payload.len() - orig_datagram_length; + let extension = + &icmp_payload[orig_datagram_length..orig_datagram_length + extension_len]; + ( + &icmp_payload[..orig_datagram_length - extension_len], + Some(extension), + ) + } else { + (&icmp_payload[..orig_datagram_length], None) + } + // "Specifically, when a TRACEROUTE application operating in non- + // compliant mode receives a sufficiently long ICMP message that does + // not specify a length attribute, it will parse for a valid extension + // header at a fixed location, assuming a 128-octet "original datagram" + // field." + // TODO - have to include length of the extension header here? MTR does + } else if orig_datagram_length == 0 && icmp_payload.len() > ICMP_ORIG_DATAGRAM_MIN_LENGTH { + // extension present, non-compliant message + let extension_len = icmp_payload.len() - ICMP_ORIG_DATAGRAM_MIN_LENGTH; + let extension = &icmp_payload + [ICMP_ORIG_DATAGRAM_MIN_LENGTH..ICMP_ORIG_DATAGRAM_MIN_LENGTH + extension_len]; + ( + &icmp_payload[..icmp_payload.len() - extension_len], + Some(extension), + ) + } else { + // no extension present + (icmp_payload, None) + } + } + + #[cfg(test)] + mod tests { + use crate::tracing::packet::icmpv4::echo_request::EchoRequestPacket; + use crate::tracing::packet::icmpv4::extension::extension_header::ExtensionHeader; + use crate::tracing::packet::icmpv4::extension::extension_object::{ClassNum, ClassSubType}; + use crate::tracing::packet::icmpv4::extension::extension_structure::ExtensionStructure; + use crate::tracing::packet::icmpv4::time_exceeded::TimeExceededPacket; + use crate::tracing::packet::icmpv4::{IcmpCode, IcmpType}; + use crate::tracing::packet::ipv4::Ipv4Packet; + use std::net::Ipv4Addr; + + // This ICMP TimeExceeded packet does not have a `length` field and is therefore rfc4884 non-complaint and has a + // single `MPLS` extension object. + #[test] + fn test_split_extension_ipv4_time_exceeded_non_compliant_mpls() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ff 00 00 00 00 45 00 00 54 cc 1c 40 00 + 01 01 b5 f4 c0 a8 01 15 5d b8 d8 22 08 00 0f e3 + 65 da 82 42 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 20 00 99 3a 00 08 01 01 + 04 bb 41 01 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62719, time_exceeded_packet.get_checksum()); + assert_eq!(0, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..136], time_exceeded_packet.payload()); + assert_eq!(Some(&buf[136..]), time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..136]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..136], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE3, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33346, nested_echo.get_sequence()); + assert_eq!(&buf[36..136], nested_echo.payload()); + + let extensions = + ExtensionStructure::new_view(time_exceeded_packet.extension().unwrap()).unwrap(); + + let extension_header = ExtensionHeader::new_view(extensions.header()).unwrap(); + assert_eq!(2, extension_header.get_version()); + assert_eq!(0x993A, extension_header.get_checksum()); + + let extension_object = extensions.iter().next().unwrap(); + assert_eq!(8, extension_object.get_length()); + assert_eq!( + ClassNum::MultiProtocolLabelSwitchingLabelStack, + extension_object.get_class_num() + ); + assert_eq!(ClassSubType(1), extension_object.get_class_subtype()); + assert_eq!([0x04, 0xbb, 0x41, 0x01], extension_object.payload()); + } + + // This ICMP TimeExceeded packet has a rfc4884 complaint `length` field and does not have any ICMP extensions. + #[test] + fn test_split_extension_ipv4_time_exceeded_compliant_no_extension() { + let buf = hex_literal::hex!( + " + 0b 00 f4 ee 00 11 00 00 45 00 00 54 a2 ee 40 00 + 01 01 df 22 c0 a8 01 15 5d b8 d8 22 08 00 0f e1 + 65 da 82 44 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 + " + ); + let time_exceeded_packet = TimeExceededPacket::new_view(&buf).unwrap(); + assert_eq!(IcmpType::TimeExceeded, time_exceeded_packet.get_icmp_type()); + assert_eq!(IcmpCode(0), time_exceeded_packet.get_icmp_code()); + assert_eq!(62702, time_exceeded_packet.get_checksum()); + assert_eq!(17, time_exceeded_packet.get_length()); + assert_eq!(&buf[8..76], time_exceeded_packet.payload()); + assert_eq!(None, time_exceeded_packet.extension()); + + let nested_ipv4 = Ipv4Packet::new_view(&buf[8..76]).unwrap(); + assert_eq!(Ipv4Addr::from([192, 168, 1, 21]), nested_ipv4.get_source()); + assert_eq!( + Ipv4Addr::from([93, 184, 216, 34]), + nested_ipv4.get_destination() + ); + assert_eq!(&buf[28..76], nested_ipv4.payload()); + + let nested_echo = EchoRequestPacket::new_view(nested_ipv4.payload()).unwrap(); + assert_eq!(IcmpCode(0), nested_echo.get_icmp_code()); + assert_eq!(IcmpType::EchoRequest, nested_echo.get_icmp_type()); + assert_eq!(0x0FE1, nested_echo.get_checksum()); + assert_eq!(26074, nested_echo.get_identifier()); + assert_eq!(33348, nested_echo.get_sequence()); + assert_eq!(&buf[36..76], nested_echo.payload()); + } + } +}