use byteorder::{ByteOrder, NetworkEndian};
use core::{cmp, fmt};
use super::{Error, Result};
use crate::phy::ChecksumCapabilities;
use crate::wire::ip::checksum;
use crate::wire::{Ipv4Packet, Ipv4Repr};
enum_with_unknown! {
    /// ICMPv4 message type, stored in the first octet of the header.
    ///
    /// Octet values without a named variant are kept as `Unknown`.
    pub enum Message(u8) {
        /// Echo reply.
        EchoReply = 0,
        /// Destination unreachable.
        DstUnreachable = 3,
        /// Redirect.
        Redirect = 5,
        /// Echo request.
        EchoRequest = 8,
        /// Router advertisement.
        RouterAdvert = 9,
        /// Router solicitation.
        RouterSolicit = 10,
        /// Time exceeded.
        TimeExceeded = 11,
        /// Parameter problem.
        ParamProblem = 12,
        /// Timestamp request.
        Timestamp = 13,
        /// Timestamp reply.
        TimestampReply = 14
    }
}
impl fmt::Display for Message {
    /// Write a lowercase, human-readable name for the message type, or the raw
    /// numeric value for types without a named variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Message::EchoReply => "echo reply",
            Message::DstUnreachable => "destination unreachable",
            Message::Redirect => "message redirect",
            Message::EchoRequest => "echo request",
            Message::RouterAdvert => "router advertisement",
            Message::RouterSolicit => "router solicitation",
            Message::TimeExceeded => "time exceeded",
            Message::ParamProblem => "parameter problem",
            Message::Timestamp => "timestamp",
            Message::TimestampReply => "timestamp reply",
            Message::Unknown(id) => return write!(f, "{id}"),
        };
        f.write_str(label)
    }
}
enum_with_unknown! {
    /// Code values for ICMPv4 destination-unreachable messages.
    ///
    /// Code octets without a named variant are kept as `Unknown`.
    pub enum DstUnreachable(u8) {
        /// Destination network unreachable.
        NetUnreachable = 0,
        /// Destination host unreachable.
        HostUnreachable = 1,
        /// Destination protocol unreachable.
        ProtoUnreachable = 2,
        /// Destination port unreachable.
        PortUnreachable = 3,
        /// Fragmentation required, and DF flag set.
        FragRequired = 4,
        /// Source route failed.
        SrcRouteFailed = 5,
        /// Destination network unknown.
        DstNetUnknown = 6,
        /// Destination host unknown.
        DstHostUnknown = 7,
        /// Source host isolated.
        SrcHostIsolated = 8,
        /// Network administratively prohibited.
        NetProhibited = 9,
        /// Host administratively prohibited.
        HostProhibited = 10,
        /// Network unreachable for ToS.
        NetUnreachToS = 11,
        /// Host unreachable for ToS.
        HostUnreachToS = 12,
        /// Communication administratively prohibited.
        CommProhibited = 13,
        /// Host precedence violation.
        HostPrecedViol = 14,
        /// Precedence cutoff in effect.
        PrecedCutoff = 15
    }
}
impl fmt::Display for DstUnreachable {
    /// Write a human-readable description of the unreachable code, or the raw
    /// numeric value for codes without a named variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            DstUnreachable::NetUnreachable => "destination network unreachable",
            DstUnreachable::HostUnreachable => "destination host unreachable",
            DstUnreachable::ProtoUnreachable => "destination protocol unreachable",
            DstUnreachable::PortUnreachable => "destination port unreachable",
            DstUnreachable::FragRequired => "fragmentation required, and DF flag set",
            DstUnreachable::SrcRouteFailed => "source route failed",
            DstUnreachable::DstNetUnknown => "destination network unknown",
            DstUnreachable::DstHostUnknown => "destination host unknown",
            DstUnreachable::SrcHostIsolated => "source host isolated",
            DstUnreachable::NetProhibited => "network administratively prohibited",
            DstUnreachable::HostProhibited => "host administratively prohibited",
            DstUnreachable::NetUnreachToS => "network unreachable for ToS",
            DstUnreachable::HostUnreachToS => "host unreachable for ToS",
            DstUnreachable::CommProhibited => "communication administratively prohibited",
            DstUnreachable::HostPrecedViol => "host precedence violation",
            DstUnreachable::PrecedCutoff => "precedence cutoff in effect",
            DstUnreachable::Unknown(id) => return write!(f, "{id}"),
        };
        f.write_str(description)
    }
}
enum_with_unknown! {
    /// Code values for ICMPv4 redirect messages.
    ///
    /// Code octets without a named variant are kept as `Unknown`.
    pub enum Redirect(u8) {
        /// Redirect for the network.
        Net = 0,
        /// Redirect for the host.
        Host = 1,
        /// Redirect for the type of service and network.
        NetToS = 2,
        /// Redirect for the type of service and host.
        HostToS = 3
    }
}
enum_with_unknown! {
    /// Code values for ICMPv4 time-exceeded messages.
    ///
    /// Code octets without a named variant are kept as `Unknown`.
    pub enum TimeExceeded(u8) {
        /// Time-to-live exceeded in transit.
        TtlExpired = 0,
        /// Fragment reassembly time exceeded.
        FragExpired = 1
    }
}
impl fmt::Display for TimeExceeded {
    /// Write a human-readable description of the time-exceeded code, or the
    /// raw numeric value for codes without a named variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let description = match self {
            TimeExceeded::TtlExpired => "time-to-live exceeded in transit",
            TimeExceeded::FragExpired => "fragment reassembly time exceeded",
            TimeExceeded::Unknown(id) => return write!(f, "{id}"),
        };
        f.write_str(description)
    }
}
enum_with_unknown! {
    /// Code values for ICMPv4 parameter-problem messages.
    ///
    /// Code octets without a named variant are kept as `Unknown`.
    pub enum ParamProblem(u8) {
        /// The pointer field indicates the offending octet.
        AtPointer = 0,
        /// A required option is missing.
        MissingOption = 1,
        /// Bad length.
        BadLength = 2
    }
}
/// A read/write wrapper around an ICMPv4 packet buffer.
///
/// Accessors index the buffer directly; construct with `Packet::new_checked`
/// (or call `check_len`) before reading untrusted data.
#[derive(Debug, PartialEq, Eq, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Packet<T: AsRef<[u8]>> {
    // Raw octets: the fixed 8-octet header followed by the message body.
    buffer: T,
}
/// Octet offsets of the fixed ICMPv4 header fields.
mod field {
    use crate::wire::field::*;

    /// Message type octet.
    pub const TYPE: usize = 0;
    /// Message code octet.
    pub const CODE: usize = 1;
    /// 16-bit checksum computed over the whole message.
    pub const CHECKSUM: Field = 2..4;
    /// Four octets after the checksum in non-echo messages; this module does
    /// not interpret their contents.
    pub const UNUSED: Field = 4..8;
    /// Echo identifier (occupies the same octets as the first half of `UNUSED`).
    pub const ECHO_IDENT: Field = 4..6;
    /// Echo sequence number (occupies the same octets as the second half of `UNUSED`).
    pub const ECHO_SEQNO: Field = 6..8;
    /// End of the fixed header, and the minimum length accepted by `check_len`.
    pub const HEADER_END: usize = 8;
}
impl<T: AsRef<[u8]>> Packet<T> {
    /// Wrap a raw octet buffer as an ICMPv4 packet without validating it.
    ///
    /// The field accessors index the buffer directly and panic if it is shorter
    /// than the fixed header; prefer [`Packet::new_checked`] for untrusted input.
    pub const fn new_unchecked(buffer: T) -> Packet<T> {
        Packet { buffer }
    }

    /// Wrap a raw octet buffer, rejecting it if it cannot hold the fixed header.
    ///
    /// # Errors
    ///
    /// Returns `Err(Error)` when [`Packet::check_len`] fails.
    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
        let wrapped = Self::new_unchecked(buffer);
        wrapped.check_len().map(|()| wrapped)
    }

    /// Ensure the buffer is long enough to hold the fixed 8-octet header.
    ///
    /// # Errors
    ///
    /// Returns `Err(Error)` if the buffer is shorter than `field::HEADER_END`.
    pub fn check_len(&self) -> Result<()> {
        if self.buffer.as_ref().len() >= field::HEADER_END {
            Ok(())
        } else {
            Err(Error)
        }
    }

    /// Consume the packet and return the underlying buffer.
    pub fn into_inner(self) -> T {
        self.buffer
    }

    /// Return the message type field.
    #[inline]
    pub fn msg_type(&self) -> Message {
        Message::from(self.buffer.as_ref()[field::TYPE])
    }

    /// Return the raw message code field.
    #[inline]
    pub fn msg_code(&self) -> u8 {
        self.buffer.as_ref()[field::CODE]
    }

    /// Return the checksum field.
    #[inline]
    pub fn checksum(&self) -> u16 {
        NetworkEndian::read_u16(&self.buffer.as_ref()[field::CHECKSUM])
    }

    /// Return the identifier field of an echo message.
    #[inline]
    pub fn echo_ident(&self) -> u16 {
        NetworkEndian::read_u16(&self.buffer.as_ref()[field::ECHO_IDENT])
    }

    /// Return the sequence-number field of an echo message.
    #[inline]
    pub fn echo_seq_no(&self) -> u16 {
        NetworkEndian::read_u16(&self.buffer.as_ref()[field::ECHO_SEQNO])
    }

    /// Return the header length for this packet's message type.
    ///
    /// Echo messages end after the sequence number; every other type ends
    /// after the four-octet unused field. Both currently come to 8 octets.
    pub fn header_len(&self) -> usize {
        if matches!(self.msg_type(), Message::EchoRequest | Message::EchoReply) {
            field::ECHO_SEQNO.end
        } else {
            field::UNUSED.end
        }
    }

    /// Validate the checksum over the whole buffer.
    ///
    /// A buffer whose stored checksum was produced by `fill_checksum` sums to
    /// all ones. Under `cfg(fuzzing)` this always returns `true`.
    pub fn verify_checksum(&self) -> bool {
        if cfg!(fuzzing) {
            return true;
        }
        checksum::data(self.buffer.as_ref()) == !0
    }
}
impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
    /// Return the message body that follows the ICMPv4 header.
    ///
    /// The slice borrows the underlying buffer for `'a`, not `self`.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is shorter than [`Packet::header_len`].
    #[inline]
    pub fn data(&self) -> &'a [u8] {
        let body_start = self.header_len();
        &self.buffer.as_ref()[body_start..]
    }
}
impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
    /// Set the message type field.
    #[inline]
    pub fn set_msg_type(&mut self, value: Message) {
        let octet: u8 = value.into();
        self.buffer.as_mut()[field::TYPE] = octet;
    }

    /// Set the raw message code field.
    #[inline]
    pub fn set_msg_code(&mut self, value: u8) {
        self.buffer.as_mut()[field::CODE] = value;
    }

    /// Set the checksum field verbatim (see `fill_checksum` to compute it).
    #[inline]
    pub fn set_checksum(&mut self, value: u16) {
        NetworkEndian::write_u16(&mut self.buffer.as_mut()[field::CHECKSUM], value);
    }

    /// Set the identifier field of an echo message.
    #[inline]
    pub fn set_echo_ident(&mut self, value: u16) {
        NetworkEndian::write_u16(&mut self.buffer.as_mut()[field::ECHO_IDENT], value);
    }

    /// Set the sequence-number field of an echo message.
    #[inline]
    pub fn set_echo_seq_no(&mut self, value: u16) {
        NetworkEndian::write_u16(&mut self.buffer.as_mut()[field::ECHO_SEQNO], value);
    }

    /// Compute the checksum over the whole buffer and store it.
    ///
    /// The checksum field is zeroed first so that it does not contribute to
    /// its own value.
    pub fn fill_checksum(&mut self) {
        self.set_checksum(0);
        let complemented = !checksum::data(self.buffer.as_ref());
        self.set_checksum(complemented);
    }
}
impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> {
    /// Return a mutable view of the message body that follows the ICMPv4 header.
    ///
    /// # Panics
    ///
    /// Panics if the buffer is shorter than [`Packet::header_len`].
    #[inline]
    pub fn data_mut(&mut self) -> &mut [u8] {
        // The header length is read before taking the mutable borrow.
        let body_start = self.header_len();
        &mut self.buffer.as_mut()[body_start..]
    }
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
    /// Expose the entire underlying buffer, header included.
    fn as_ref(&self) -> &[u8] {
        <T as AsRef<[u8]>>::as_ref(&self.buffer)
    }
}
/// A high-level representation of an ICMPv4 message.
///
/// Only the message kinds listed here are produced by `Repr::parse`; other
/// message types are rejected with `Err`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[non_exhaustive]
pub enum Repr<'a> {
    /// Echo request (only code 0 is accepted by `parse`).
    EchoRequest {
        /// Value of the identifier field.
        ident: u16,
        /// Value of the sequence-number field.
        seq_no: u16,
        /// Message body following the 8-octet header.
        data: &'a [u8],
    },
    /// Echo reply (only code 0 is accepted by `parse`).
    EchoReply {
        /// Value of the identifier field.
        ident: u16,
        /// Value of the sequence-number field.
        seq_no: u16,
        /// Message body following the 8-octet header.
        data: &'a [u8],
    },
    /// Destination unreachable, quoting part of the offending datagram.
    DstUnreachable {
        /// The message code.
        reason: DstUnreachable,
        /// IPv4 header of the quoted datagram.
        header: Ipv4Repr,
        /// Octets of the quoted datagram after its IPv4 header (`parse`
        /// requires at least eight).
        data: &'a [u8],
    },
    /// Time exceeded, quoting part of the offending datagram.
    TimeExceeded {
        /// The message code.
        reason: TimeExceeded,
        /// IPv4 header of the quoted datagram.
        header: Ipv4Repr,
        /// Octets of the quoted datagram after its IPv4 header (`parse`
        /// requires at least eight).
        data: &'a [u8],
    },
}
impl<'a> Repr<'a> {
pub fn parse<T>(
packet: &Packet<&'a T>,
checksum_caps: &ChecksumCapabilities,
) -> Result<Repr<'a>>
where
T: AsRef<[u8]> + ?Sized,
{
if checksum_caps.icmpv4.rx() && !packet.verify_checksum() {
return Err(Error);
}
match (packet.msg_type(), packet.msg_code()) {
(Message::EchoRequest, 0) => Ok(Repr::EchoRequest {
ident: packet.echo_ident(),
seq_no: packet.echo_seq_no(),
data: packet.data(),
}),
(Message::EchoReply, 0) => Ok(Repr::EchoReply {
ident: packet.echo_ident(),
seq_no: packet.echo_seq_no(),
data: packet.data(),
}),
(Message::DstUnreachable, code) => {
let ip_packet = Ipv4Packet::new_checked(packet.data())?;
let payload = &packet.data()[ip_packet.header_len() as usize..];
if payload.len() < 8 {
return Err(Error);
}
Ok(Repr::DstUnreachable {
reason: DstUnreachable::from(code),
header: Ipv4Repr {
src_addr: ip_packet.src_addr(),
dst_addr: ip_packet.dst_addr(),
next_header: ip_packet.next_header(),
payload_len: payload.len(),
hop_limit: ip_packet.hop_limit(),
},
data: payload,
})
}
(Message::TimeExceeded, code) => {
let ip_packet = Ipv4Packet::new_checked(packet.data())?;
let payload = &packet.data()[ip_packet.header_len() as usize..];
if payload.len() < 8 {
return Err(Error);
}
Ok(Repr::TimeExceeded {
reason: TimeExceeded::from(code),
header: Ipv4Repr {
src_addr: ip_packet.src_addr(),
dst_addr: ip_packet.dst_addr(),
next_header: ip_packet.next_header(),
payload_len: payload.len(),
hop_limit: ip_packet.hop_limit(),
},
data: payload,
})
}
_ => Err(Error),
}
}
pub const fn buffer_len(&self) -> usize {
match self {
&Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => {
field::ECHO_SEQNO.end + data.len()
}
&Repr::DstUnreachable { header, data, .. }
| &Repr::TimeExceeded { header, data, .. } => {
field::UNUSED.end + header.buffer_len() + data.len()
}
}
}
pub fn emit<T>(&self, packet: &mut Packet<&mut T>, checksum_caps: &ChecksumCapabilities)
where
T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
{
packet.set_msg_code(0);
match *self {
Repr::EchoRequest {
ident,
seq_no,
data,
} => {
packet.set_msg_type(Message::EchoRequest);
packet.set_msg_code(0);
packet.set_echo_ident(ident);
packet.set_echo_seq_no(seq_no);
let data_len = cmp::min(packet.data_mut().len(), data.len());
packet.data_mut()[..data_len].copy_from_slice(&data[..data_len])
}
Repr::EchoReply {
ident,
seq_no,
data,
} => {
packet.set_msg_type(Message::EchoReply);
packet.set_msg_code(0);
packet.set_echo_ident(ident);
packet.set_echo_seq_no(seq_no);
let data_len = cmp::min(packet.data_mut().len(), data.len());
packet.data_mut()[..data_len].copy_from_slice(&data[..data_len])
}
Repr::DstUnreachable {
reason,
header,
data,
} => {
packet.set_msg_type(Message::DstUnreachable);
packet.set_msg_code(reason.into());
let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut());
header.emit(&mut ip_packet, checksum_caps);
let payload = &mut ip_packet.into_inner()[header.buffer_len()..];
payload.copy_from_slice(data)
}
Repr::TimeExceeded {
reason,
header,
data,
} => {
packet.set_msg_type(Message::TimeExceeded);
packet.set_msg_code(reason.into());
let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut());
header.emit(&mut ip_packet, checksum_caps);
let payload = &mut ip_packet.into_inner()[header.buffer_len()..];
payload.copy_from_slice(data)
}
}
if checksum_caps.icmpv4.tx() {
packet.fill_checksum()
} else {
packet.set_checksum(0);
}
}
}
impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
    /// Format the packet through its high-level [`Repr`] when it parses.
    ///
    /// Otherwise, write the parse error followed by the raw type and code fields
    /// (named codes for destination-unreachable and time-exceeded messages).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The fallback branch reads the type and code octets directly, which would
        // panic on a buffer too short to hold the fixed header. Report the length
        // error instead of panicking inside a formatter.
        if let Err(err) = self.check_len() {
            return write!(f, "ICMPv4 ({err})");
        }

        match Repr::parse(self, &ChecksumCapabilities::default()) {
            Ok(repr) => write!(f, "{repr}"),
            Err(err) => {
                write!(f, "ICMPv4 ({err})")?;
                write!(f, " type={:?}", self.msg_type())?;
                match self.msg_type() {
                    Message::DstUnreachable => {
                        write!(f, " code={:?}", DstUnreachable::from(self.msg_code()))
                    }
                    Message::TimeExceeded => {
                        write!(f, " code={:?}", TimeExceeded::from(self.msg_code()))
                    }
                    _ => write!(f, " code={}", self.msg_code()),
                }
            }
        }
    }
}
impl<'a> fmt::Display for Repr<'a> {
    /// Write a one-line summary of the message.
    ///
    /// Echo messages include their identifier, sequence number and payload
    /// length. Error messages include their reason.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Repr::EchoRequest {
                ident,
                seq_no,
                data,
            } => {
                let len = data.len();
                write!(f, "ICMPv4 echo request id={} seq={} len={}", ident, seq_no, len)
            }
            Repr::EchoReply {
                ident,
                seq_no,
                data,
            } => {
                let len = data.len();
                write!(f, "ICMPv4 echo reply id={} seq={} len={}", ident, seq_no, len)
            }
            Repr::DstUnreachable { reason, .. } => {
                write!(f, "ICMPv4 destination unreachable ({reason})")
            }
            Repr::TimeExceeded { reason, .. } => write!(f, "ICMPv4 time exceeded ({reason})"),
        }
    }
}
use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
    /// Print a one-line summary of the ICMPv4 packet.
    ///
    /// For destination-unreachable and time-exceeded messages, the quoted IPv4
    /// datagram is also printed on an indented line.
    fn pretty_print(
        buffer: &dyn AsRef<[u8]>,
        f: &mut fmt::Formatter,
        indent: &mut PrettyIndent,
    ) -> fmt::Result {
        let checked = Packet::new_checked(buffer);
        let packet = match checked {
            Ok(inner) => inner,
            Err(err) => return write!(f, "{indent}({err})"),
        };
        write!(f, "{indent}{packet}")?;

        let quotes_datagram = matches!(
            packet.msg_type(),
            Message::DstUnreachable | Message::TimeExceeded
        );
        if !quotes_datagram {
            return Ok(());
        }
        indent.increase(f)?;
        super::Ipv4Packet::<&[u8]>::pretty_print(&packet.data(), f, indent)
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Echo request: type 8, code 0, checksum 0x8efe, ident 0x1234, seq 0xabcd,
    // followed by the four payload octets of `ECHO_DATA_BYTES`.
    static ECHO_PACKET_BYTES: [u8; 12] = [
        0x08, 0x00, 0x8e, 0xfe, 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff,
    ];

    // Payload carried by `ECHO_PACKET_BYTES`.
    static ECHO_DATA_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    // High-level form of `ECHO_PACKET_BYTES`.
    fn echo_packet_repr() -> Repr<'static> {
        Repr::EchoRequest {
            ident: 0x1234,
            seq_no: 0xabcd,
            data: &ECHO_DATA_BYTES,
        }
    }

    #[test]
    fn test_echo_deconstruct() {
        let view = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]);
        assert_eq!(view.msg_type(), Message::EchoRequest);
        assert_eq!(view.msg_code(), 0);
        assert_eq!(view.checksum(), 0x8efe);
        assert_eq!(view.echo_ident(), 0x1234);
        assert_eq!(view.echo_seq_no(), 0xabcd);
        assert_eq!(view.data(), &ECHO_DATA_BYTES[..]);
        assert!(view.verify_checksum());
    }

    #[test]
    fn test_echo_construct() {
        // Pre-fill with a non-zero pattern so any octet left unwritten shows up.
        let mut raw = vec![0xa5; ECHO_PACKET_BYTES.len()];
        let mut view = Packet::new_unchecked(&mut raw);
        view.set_msg_type(Message::EchoRequest);
        view.set_msg_code(0);
        view.set_echo_ident(0x1234);
        view.set_echo_seq_no(0xabcd);
        view.data_mut().copy_from_slice(&ECHO_DATA_BYTES[..]);
        view.fill_checksum();
        assert_eq!(&view.into_inner()[..], &ECHO_PACKET_BYTES[..]);
    }

    #[test]
    fn test_echo_parse() {
        let view = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]);
        let parsed = Repr::parse(&view, &ChecksumCapabilities::default()).unwrap();
        assert_eq!(parsed, echo_packet_repr());
    }

    #[test]
    fn test_echo_emit() {
        let expected = echo_packet_repr();
        let mut raw = vec![0xa5; expected.buffer_len()];
        let mut view = Packet::new_unchecked(&mut raw);
        expected.emit(&mut view, &ChecksumCapabilities::default());
        assert_eq!(&view.into_inner()[..], &ECHO_PACKET_BYTES[..]);
    }

    #[test]
    fn test_check_len() {
        // A minimal time-exceeded header; only its length matters here.
        let header = [0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert_eq!(Packet::new_checked(&[]), Err(Error));
        assert_eq!(Packet::new_checked(&header[..4]), Err(Error));
        assert!(Packet::new_checked(&header[..]).is_ok());
    }
}