smoltcp/wire/
udp.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use super::{Error, Result};
5use crate::phy::ChecksumCapabilities;
6use crate::wire::ip::checksum;
7use crate::wire::{IpAddress, IpProtocol};
8
/// A read/write view over a buffer holding a User Datagram Protocol packet.
///
/// The wrapper only stores the buffer; all header fields are decoded from it
/// (or encoded into it) on demand by the accessor methods.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Packet<T>
where
    T: AsRef<[u8]>,
{
    /// Backing storage for the UDP header and payload octets.
    buffer: T,
}
14
15mod field {
16    #![allow(non_snake_case)]
17
18    use crate::wire::field::*;
19
20    pub const SRC_PORT: Field = 0..2;
21    pub const DST_PORT: Field = 2..4;
22    pub const LENGTH: Field = 4..6;
23    pub const CHECKSUM: Field = 6..8;
24
25    pub const fn PAYLOAD(length: u16) -> Field {
26        CHECKSUM.end..(length as usize)
27    }
28}
29
/// Length of the UDP header in octets (the end of the checksum field, i.e. 8).
pub const HEADER_LEN: usize = field::CHECKSUM.end;
31
32#[allow(clippy::len_without_is_empty)]
33impl<T: AsRef<[u8]>> Packet<T> {
34    /// Imbue a raw octet buffer with UDP packet structure.
35    pub const fn new_unchecked(buffer: T) -> Packet<T> {
36        Packet { buffer }
37    }
38
39    /// Shorthand for a combination of [new_unchecked] and [check_len].
40    ///
41    /// [new_unchecked]: #method.new_unchecked
42    /// [check_len]: #method.check_len
43    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
44        let packet = Self::new_unchecked(buffer);
45        packet.check_len()?;
46        Ok(packet)
47    }
48
49    /// Ensure that no accessor method will panic if called.
50    /// Returns `Err(Error)` if the buffer is too short.
51    /// Returns `Err(Error)` if the length field has a value smaller
52    /// than the header length.
53    ///
54    /// The result of this check is invalidated by calling [set_len].
55    ///
56    /// [set_len]: #method.set_len
57    pub fn check_len(&self) -> Result<()> {
58        let buffer_len = self.buffer.as_ref().len();
59        if buffer_len < HEADER_LEN {
60            Err(Error)
61        } else {
62            let field_len = self.len() as usize;
63            if buffer_len < field_len || field_len < HEADER_LEN {
64                Err(Error)
65            } else {
66                Ok(())
67            }
68        }
69    }
70
71    /// Consume the packet, returning the underlying buffer.
72    pub fn into_inner(self) -> T {
73        self.buffer
74    }
75
76    /// Return the source port field.
77    #[inline]
78    pub fn src_port(&self) -> u16 {
79        let data = self.buffer.as_ref();
80        NetworkEndian::read_u16(&data[field::SRC_PORT])
81    }
82
83    /// Return the destination port field.
84    #[inline]
85    pub fn dst_port(&self) -> u16 {
86        let data = self.buffer.as_ref();
87        NetworkEndian::read_u16(&data[field::DST_PORT])
88    }
89
90    /// Return the length field.
91    #[inline]
92    pub fn len(&self) -> u16 {
93        let data = self.buffer.as_ref();
94        NetworkEndian::read_u16(&data[field::LENGTH])
95    }
96
97    /// Return the checksum field.
98    #[inline]
99    pub fn checksum(&self) -> u16 {
100        let data = self.buffer.as_ref();
101        NetworkEndian::read_u16(&data[field::CHECKSUM])
102    }
103
104    /// Validate the packet checksum.
105    ///
106    /// # Panics
107    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
108    /// and that family is IPv4 or IPv6.
109    ///
110    /// # Fuzzing
111    /// This function always returns `true` when fuzzing.
112    pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
113        if cfg!(fuzzing) {
114            return true;
115        }
116
117        // From the RFC:
118        // > An all zero transmitted checksum value means that the transmitter
119        // > generated no checksum (for debugging or for higher level protocols
120        // > that don't care).
121        if self.checksum() == 0 {
122            return true;
123        }
124
125        let data = self.buffer.as_ref();
126        checksum::combine(&[
127            checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
128            checksum::data(&data[..self.len() as usize]),
129        ]) == !0
130    }
131}
132
133impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
134    /// Return a pointer to the payload.
135    #[inline]
136    pub fn payload(&self) -> &'a [u8] {
137        let length = self.len();
138        let data = self.buffer.as_ref();
139        &data[field::PAYLOAD(length)]
140    }
141}
142
143impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
144    /// Set the source port field.
145    #[inline]
146    pub fn set_src_port(&mut self, value: u16) {
147        let data = self.buffer.as_mut();
148        NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
149    }
150
151    /// Set the destination port field.
152    #[inline]
153    pub fn set_dst_port(&mut self, value: u16) {
154        let data = self.buffer.as_mut();
155        NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
156    }
157
158    /// Set the length field.
159    #[inline]
160    pub fn set_len(&mut self, value: u16) {
161        let data = self.buffer.as_mut();
162        NetworkEndian::write_u16(&mut data[field::LENGTH], value)
163    }
164
165    /// Set the checksum field.
166    #[inline]
167    pub fn set_checksum(&mut self, value: u16) {
168        let data = self.buffer.as_mut();
169        NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
170    }
171
172    /// Compute and fill in the header checksum.
173    ///
174    /// # Panics
175    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
176    /// and that family is IPv4 or IPv6.
177    pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
178        self.set_checksum(0);
179        let checksum = {
180            let data = self.buffer.as_ref();
181            !checksum::combine(&[
182                checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
183                checksum::data(&data[..self.len() as usize]),
184            ])
185        };
186        // UDP checksum value of 0 means no checksum; if the checksum really is zero,
187        // use all-ones, which indicates that the remote end must verify the checksum.
188        // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
189        // so no action is necessary on the remote end.
190        self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
191    }
192
193    /// Return a mutable pointer to the payload.
194    #[inline]
195    pub fn payload_mut(&mut self) -> &mut [u8] {
196        let length = self.len();
197        let data = self.buffer.as_mut();
198        &mut data[field::PAYLOAD(length)]
199    }
200}
201
202impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
203    fn as_ref(&self) -> &[u8] {
204        self.buffer.as_ref()
205    }
206}
207
/// A high-level representation of a User Datagram Protocol packet.
///
/// Only the two port numbers are kept; the length and checksum fields are
/// supplied when the representation is emitted into a packet.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct Repr {
    /// Source port of the datagram.
    pub src_port: u16,
    /// Destination port of the datagram.
    pub dst_port: u16,
}
214
215impl Repr {
216    /// Parse an User Datagram Protocol packet and return a high-level representation.
217    pub fn parse<T>(
218        packet: &Packet<&T>,
219        src_addr: &IpAddress,
220        dst_addr: &IpAddress,
221        checksum_caps: &ChecksumCapabilities,
222    ) -> Result<Repr>
223    where
224        T: AsRef<[u8]> + ?Sized,
225    {
226        // Destination port cannot be omitted (but source port can be).
227        if packet.dst_port() == 0 {
228            return Err(Error);
229        }
230        // Valid checksum is expected...
231        if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
232            match (src_addr, dst_addr) {
233                // ... except on UDP-over-IPv4, where it can be omitted.
234                #[cfg(feature = "proto-ipv4")]
235                (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
236                _ => return Err(Error),
237            }
238        }
239
240        Ok(Repr {
241            src_port: packet.src_port(),
242            dst_port: packet.dst_port(),
243        })
244    }
245
246    /// Return the length of the packet header that will be emitted from this high-level representation.
247    pub const fn header_len(&self) -> usize {
248        HEADER_LEN
249    }
250
251    /// Emit a high-level representation into an User Datagram Protocol packet.
252    ///
253    /// This never calculates the checksum, and is intended for internal-use only,
254    /// not for packets that are going to be actually sent over the network. For
255    /// example, when decompressing 6lowpan.
256    pub(crate) fn emit_header<T: ?Sized>(&self, packet: &mut Packet<&mut T>, payload_len: usize)
257    where
258        T: AsRef<[u8]> + AsMut<[u8]>,
259    {
260        packet.set_src_port(self.src_port);
261        packet.set_dst_port(self.dst_port);
262        packet.set_len((HEADER_LEN + payload_len) as u16);
263        packet.set_checksum(0);
264    }
265
266    /// Emit a high-level representation into an User Datagram Protocol packet.
267    pub fn emit<T: ?Sized>(
268        &self,
269        packet: &mut Packet<&mut T>,
270        src_addr: &IpAddress,
271        dst_addr: &IpAddress,
272        payload_len: usize,
273        emit_payload: impl FnOnce(&mut [u8]),
274        checksum_caps: &ChecksumCapabilities,
275    ) where
276        T: AsRef<[u8]> + AsMut<[u8]>,
277    {
278        packet.set_src_port(self.src_port);
279        packet.set_dst_port(self.dst_port);
280        packet.set_len((HEADER_LEN + payload_len) as u16);
281        emit_payload(packet.payload_mut());
282
283        if checksum_caps.udp.tx() {
284            packet.fill_checksum(src_addr, dst_addr)
285        } else {
286            // make sure we get a consistently zeroed checksum,
287            // since implementations might rely on it
288            packet.set_checksum(0);
289        }
290    }
291}
292
293impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> {
294    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
295        // Cannot use Repr::parse because we don't have the IP addresses.
296        write!(
297            f,
298            "UDP src={} dst={} len={}",
299            self.src_port(),
300            self.dst_port(),
301            self.payload().len()
302        )
303    }
304}
305
#[cfg(feature = "defmt")]
impl<'a, T: AsRef<[u8]> + ?Sized> defmt::Format for Packet<&'a T> {
    /// Format the ports and the payload length, all read straight from the packet.
    fn format(&self, fmt: defmt::Formatter) {
        // Repr::parse is not an option here: it needs the IP addresses.
        let (src, dst) = (self.src_port(), self.dst_port());
        let payload_len = self.payload().len();
        defmt::write!(fmt, "UDP src={} dst={} len={}", src, dst, payload_len);
    }
}
319
320impl fmt::Display for Repr {
321    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
322        write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
323    }
324}
325
#[cfg(feature = "defmt")]
impl defmt::Format for Repr {
    /// Format the source and destination ports.
    fn format(&self, fmt: defmt::Formatter) {
        let Repr { src_port, dst_port } = *self;
        defmt::write!(fmt, "UDP src={} dst={}", src_port, dst_port);
    }
}
332
333use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
334
335impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
336    fn pretty_print(
337        buffer: &dyn AsRef<[u8]>,
338        f: &mut fmt::Formatter,
339        indent: &mut PrettyIndent,
340    ) -> fmt::Result {
341        match Packet::new_checked(buffer) {
342            Err(err) => write!(f, "{indent}({err})"),
343            Ok(packet) => write!(f, "{indent}{packet}"),
344        }
345    }
346}
347
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]);

    // Reference datagram: ports 48896 -> 53, length 12, checksum 0x124d, 4 payload octets.
    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    // The same datagram with the checksum field zeroed.
    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    /// The test source and destination addresses as generic IP addresses.
    #[cfg(feature = "proto-ipv4")]
    fn addrs() -> (IpAddress, IpAddress) {
        (SRC_ADDR.into(), DST_ADDR.into())
    }

    /// The representation that the reference datagrams are expected to parse to.
    #[cfg(feature = "proto-ipv4")]
    fn packet_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    /// An eight-octet, header-only datagram with ports 1 -> 31881 and a zero checksum.
    #[cfg(feature = "proto-ipv4")]
    fn bare_header() -> Vec<u8> {
        let mut storage = vec![0; 8];
        {
            let mut header = Packet::new_unchecked(&mut storage);
            header.set_src_port(1);
            header.set_dst_port(31881);
            header.set_len(8);
        }
        storage
    }

    /// Parse `octets` between the test addresses with default checksum capabilities.
    #[cfg(feature = "proto-ipv4")]
    fn parse_default(octets: &[u8]) -> Result<Repr> {
        let (src, dst) = addrs();
        let packet = Packet::new_unchecked(octets);
        Repr::parse(&packet, &src, &dst, &ChecksumCapabilities::default())
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let (src, dst) = addrs();
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(packet.src_port(), 48896);
        assert_eq!(packet.dst_port(), 53);
        assert_eq!(packet.len(), 12);
        assert_eq!(packet.checksum(), 0x124d);
        assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let (src, dst) = addrs();
        let mut storage = vec![0xa5; 12];
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.set_src_port(48896);
        packet.set_dst_port(53);
        packet.set_len(12);
        packet.set_checksum(0xffff);
        packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        packet.fill_checksum(&src, &dst);
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut storage = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut storage);
        // A length field smaller than the header itself can never be valid.
        packet.set_len(4);
        assert_eq!(packet.check_len(), Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let (src, dst) = addrs();
        let mut storage = bare_header();
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.fill_checksum(&src, &dst);
        // The computed checksum is zero for this header, so all-ones is stored instead.
        assert_eq!(packet.checksum(), 0xffff);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_no_checksum() {
        let (src, dst) = addrs();
        let mut storage = bare_header();
        let mut packet = Packet::new_unchecked(&mut storage);
        packet.set_checksum(0);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        assert_eq!(parse_default(&PACKET_BYTES[..]).unwrap(), packet_repr());
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let (src, dst) = addrs();
        let repr = packet_repr();
        let mut storage = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut storage);
        repr.emit(
            &mut packet,
            &src,
            &dst,
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        assert_eq!(
            parse_default(&NO_CHECKSUM_PACKET[..]).unwrap(),
            packet_repr()
        );
    }
}