Skip to main content

smoltcp/wire/udp.rs

1use byteorder::{ByteOrder, NetworkEndian};
2use core::fmt;
3
4use super::{Error, Result};
5use crate::phy::ChecksumCapabilities;
6use crate::wire::ip::checksum;
7use crate::wire::{IpAddress, IpProtocol};
8
9/// A read/write wrapper around an User Datagram Protocol packet buffer.
/// A read/write wrapper around a User Datagram Protocol packet buffer.
///
/// The wrapper does not validate the buffer by itself; use `Packet::new_checked`
/// (or call `Packet::check_len`) before using accessors on untrusted data.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Packet<T: AsRef<[u8]>> {
    buffer: T,
}
14
15mod field {
16    #![allow(non_snake_case)]
17
18    use crate::wire::field::*;
19
20    pub const SRC_PORT: Field = 0..2;
21    pub const DST_PORT: Field = 2..4;
22    pub const LENGTH: Field = 4..6;
23    pub const CHECKSUM: Field = 6..8;
24
25    pub const fn PAYLOAD(length: u16) -> Field {
26        CHECKSUM.end..(length as usize)
27    }
28}
29
30pub const HEADER_LEN: usize = field::CHECKSUM.end;
31
32#[allow(clippy::len_without_is_empty)]
33impl<T: AsRef<[u8]>> Packet<T> {
34    /// Imbue a raw octet buffer with UDP packet structure.
35    pub const fn new_unchecked(buffer: T) -> Packet<T> {
36        Packet { buffer }
37    }
38
39    /// Shorthand for a combination of [new_unchecked] and [check_len].
40    ///
41    /// [new_unchecked]: #method.new_unchecked
42    /// [check_len]: #method.check_len
43    pub fn new_checked(buffer: T) -> Result<Packet<T>> {
44        let packet = Self::new_unchecked(buffer);
45        packet.check_len()?;
46        Ok(packet)
47    }
48
49    /// Ensure that no accessor method will panic if called.
50    /// Returns `Err(Error)` if the buffer is too short.
51    /// Returns `Err(Error)` if the length field has a value smaller
52    /// than the header length.
53    ///
54    /// The result of this check is invalidated by calling [set_len].
55    ///
56    /// [set_len]: #method.set_len
57    pub fn check_len(&self) -> Result<()> {
58        let buffer_len = self.buffer.as_ref().len();
59        if buffer_len < HEADER_LEN {
60            Err(Error)
61        } else {
62            let field_len = self.len() as usize;
63            if buffer_len < field_len || field_len < HEADER_LEN {
64                Err(Error)
65            } else {
66                Ok(())
67            }
68        }
69    }
70
71    /// Consume the packet, returning the underlying buffer.
72    pub fn into_inner(self) -> T {
73        self.buffer
74    }
75
76    /// Return the source port field.
77    #[inline]
78    pub fn src_port(&self) -> u16 {
79        let data = self.buffer.as_ref();
80        NetworkEndian::read_u16(&data[field::SRC_PORT])
81    }
82
83    /// Return the destination port field.
84    #[inline]
85    pub fn dst_port(&self) -> u16 {
86        let data = self.buffer.as_ref();
87        NetworkEndian::read_u16(&data[field::DST_PORT])
88    }
89
90    /// Return the length field.
91    #[inline]
92    pub fn len(&self) -> u16 {
93        let data = self.buffer.as_ref();
94        NetworkEndian::read_u16(&data[field::LENGTH])
95    }
96
97    /// Return the checksum field.
98    #[inline]
99    pub fn checksum(&self) -> u16 {
100        let data = self.buffer.as_ref();
101        NetworkEndian::read_u16(&data[field::CHECKSUM])
102    }
103
104    /// Validate the partial packet checksum.
105    ///
106    /// # Panics
107    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
108    /// and that family is IPv4 or IPv6.
109    ///
110    /// # Fuzzing
111    /// This function always returns `true` when fuzzing.
112    pub fn verify_partial_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
113        if cfg!(fuzzing) {
114            return true;
115        }
116
117        checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32)
118            == self.checksum()
119    }
120
121    /// Validate the packet checksum.
122    ///
123    /// # Panics
124    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
125    /// and that family is IPv4 or IPv6.
126    ///
127    /// # Fuzzing
128    /// This function always returns `true` when fuzzing.
129    pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool {
130        if cfg!(fuzzing) {
131            return true;
132        }
133
134        // From the RFC:
135        // > An all zero transmitted checksum value means that the transmitter
136        // > generated no checksum (for debugging or for higher level protocols
137        // > that don't care).
138        if self.checksum() == 0 {
139            return true;
140        }
141
142        let data = self.buffer.as_ref();
143        checksum::combine(&[
144            checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
145            checksum::data(&data[..self.len() as usize]),
146        ]) == !0
147    }
148}
149
150impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> {
151    /// Return a pointer to the payload.
152    #[inline]
153    pub fn payload(&self) -> &'a [u8] {
154        let length = self.len();
155        let data = self.buffer.as_ref();
156        &data[field::PAYLOAD(length)]
157    }
158}
159
160impl<T: AsRef<[u8]> + AsMut<[u8]>> Packet<T> {
161    /// Set the source port field.
162    #[inline]
163    pub fn set_src_port(&mut self, value: u16) {
164        let data = self.buffer.as_mut();
165        NetworkEndian::write_u16(&mut data[field::SRC_PORT], value)
166    }
167
168    /// Set the destination port field.
169    #[inline]
170    pub fn set_dst_port(&mut self, value: u16) {
171        let data = self.buffer.as_mut();
172        NetworkEndian::write_u16(&mut data[field::DST_PORT], value)
173    }
174
175    /// Set the length field.
176    #[inline]
177    pub fn set_len(&mut self, value: u16) {
178        let data = self.buffer.as_mut();
179        NetworkEndian::write_u16(&mut data[field::LENGTH], value)
180    }
181
182    /// Set the checksum field.
183    #[inline]
184    pub fn set_checksum(&mut self, value: u16) {
185        let data = self.buffer.as_mut();
186        NetworkEndian::write_u16(&mut data[field::CHECKSUM], value)
187    }
188
189    /// Compute and fill in the header checksum.
190    ///
191    /// # Panics
192    /// This function panics unless `src_addr` and `dst_addr` belong to the same family,
193    /// and that family is IPv4 or IPv6.
194    pub fn fill_checksum(&mut self, src_addr: &IpAddress, dst_addr: &IpAddress) {
195        self.set_checksum(0);
196        let checksum = {
197            let data = self.buffer.as_ref();
198            !checksum::combine(&[
199                checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32),
200                checksum::data(&data[..self.len() as usize]),
201            ])
202        };
203        // UDP checksum value of 0 means no checksum; if the checksum really is zero,
204        // use all-ones, which indicates that the remote end must verify the checksum.
205        // Arithmetically, RFC 1071 checksums of all-zeroes and all-ones behave identically,
206        // so no action is necessary on the remote end.
207        self.set_checksum(if checksum == 0 { 0xffff } else { checksum })
208    }
209
210    /// Return a mutable pointer to the payload.
211    #[inline]
212    pub fn payload_mut(&mut self) -> &mut [u8] {
213        let length = self.len();
214        let data = self.buffer.as_mut();
215        &mut data[field::PAYLOAD(length)]
216    }
217}
218
219impl<T: AsRef<[u8]>> AsRef<[u8]> for Packet<T> {
220    fn as_ref(&self) -> &[u8] {
221        self.buffer.as_ref()
222    }
223}
224
/// A high-level representation of a User Datagram Protocol header.
///
/// Only the ports are stored; the length and checksum fields are produced when
/// the representation is emitted.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Repr {
    /// Source port; `Repr::parse` accepts zero here.
    pub src_port: u16,
    /// Destination port; `Repr::parse` rejects packets where this is zero.
    pub dst_port: u16,
}
231
232impl Repr {
233    /// Parse an User Datagram Protocol packet and return a high-level representation.
234    pub fn parse<T>(
235        packet: &Packet<&T>,
236        src_addr: &IpAddress,
237        dst_addr: &IpAddress,
238        checksum_caps: &ChecksumCapabilities,
239    ) -> Result<Repr>
240    where
241        T: AsRef<[u8]> + ?Sized,
242    {
243        packet.check_len()?;
244
245        // Destination port cannot be omitted (but source port can be).
246        if packet.dst_port() == 0 {
247            return Err(Error);
248        }
249        // Valid checksum is expected...
250        if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) {
251            match (src_addr, dst_addr) {
252                // ... except on UDP-over-IPv4, where it can be omitted.
253                #[cfg(feature = "proto-ipv4")]
254                (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (),
255                _ => return Err(Error),
256            }
257        }
258
259        Ok(Repr {
260            src_port: packet.src_port(),
261            dst_port: packet.dst_port(),
262        })
263    }
264
265    /// Return the length of the packet header that will be emitted from this high-level representation.
266    pub const fn header_len(&self) -> usize {
267        HEADER_LEN
268    }
269
270    /// Emit a high-level representation into an User Datagram Protocol packet.
271    ///
272    /// This never calculates the checksum, and is intended for internal-use only,
273    /// not for packets that are going to be actually sent over the network. For
274    /// example, when decompressing 6lowpan.
275    pub(crate) fn emit_header<T>(&self, packet: &mut Packet<&mut T>, payload_len: usize)
276    where
277        T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
278    {
279        packet.set_src_port(self.src_port);
280        packet.set_dst_port(self.dst_port);
281        packet.set_len((HEADER_LEN + payload_len) as u16);
282        packet.set_checksum(0);
283    }
284
285    /// Emit a high-level representation into an User Datagram Protocol packet.
286    pub fn emit<T>(
287        &self,
288        packet: &mut Packet<&mut T>,
289        src_addr: &IpAddress,
290        dst_addr: &IpAddress,
291        payload_len: usize,
292        emit_payload: impl FnOnce(&mut [u8]),
293        checksum_caps: &ChecksumCapabilities,
294    ) where
295        T: AsRef<[u8]> + AsMut<[u8]> + ?Sized,
296    {
297        packet.set_src_port(self.src_port);
298        packet.set_dst_port(self.dst_port);
299        packet.set_len((HEADER_LEN + payload_len) as u16);
300        emit_payload(packet.payload_mut());
301
302        if checksum_caps.udp.tx() {
303            packet.fill_checksum(src_addr, dst_addr)
304        } else {
305            // make sure we get a consistently zeroed checksum,
306            // since implementations might rely on it
307            packet.set_checksum(0);
308        }
309    }
310}
311
312impl<T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&T> {
313    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
314        // Cannot use Repr::parse because we don't have the IP addresses.
315        write!(
316            f,
317            "UDP src={} dst={} len={}",
318            self.src_port(),
319            self.dst_port(),
320            self.payload().len()
321        )
322    }
323}
324
#[cfg(feature = "defmt")]
impl<T: AsRef<[u8]> + ?Sized> defmt::Format for Packet<&T> {
    fn format(&self, fmt: defmt::Formatter) {
        // Only the UDP header is available here; `Repr::parse` would also need
        // the IP addresses, so print the raw fields instead.
        let (src, dst, len) = (self.src_port(), self.dst_port(), self.payload().len());
        defmt::write!(fmt, "UDP src={} dst={} len={}", src, dst, len);
    }
}
338
339impl fmt::Display for Repr {
340    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
341        write!(f, "UDP src={} dst={}", self.src_port, self.dst_port)
342    }
343}
344
#[cfg(feature = "defmt")]
impl defmt::Format for Repr {
    fn format(&self, fmt: defmt::Formatter) {
        let Repr { src_port, dst_port } = *self;
        defmt::write!(fmt, "UDP src={} dst={}", src_port, dst_port);
    }
}
351
352use crate::wire::pretty_print::{PrettyIndent, PrettyPrint};
353
354impl<T: AsRef<[u8]>> PrettyPrint for Packet<T> {
355    fn pretty_print(
356        buffer: &dyn AsRef<[u8]>,
357        f: &mut fmt::Formatter,
358        indent: &mut PrettyIndent,
359    ) -> fmt::Result {
360        match Packet::new_checked(buffer) {
361            Err(err) => write!(f, "{indent}({err})"),
362            Ok(packet) => write!(f, "{indent}{packet}"),
363        }
364    }
365}
366
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "proto-ipv4")]
    use crate::wire::Ipv4Address;

    #[cfg(feature = "proto-ipv4")]
    const SRC_ADDR: Ipv4Address = Ipv4Address::new(192, 168, 1, 1);
    #[cfg(feature = "proto-ipv4")]
    const DST_ADDR: Ipv4Address = Ipv4Address::new(192, 168, 1, 2);

    // Ports 48896 -> 53, length 12, checksum 0x124d, followed by PAYLOAD_BYTES.
    #[cfg(feature = "proto-ipv4")]
    static PACKET_BYTES: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff,
    ];

    // Same datagram as PACKET_BYTES, but with the checksum omitted (zeroed).
    #[cfg(feature = "proto-ipv4")]
    static NO_CHECKSUM_PACKET: [u8; 12] = [
        0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff,
    ];

    #[cfg(feature = "proto-ipv4")]
    static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff];

    /// The test source/destination addresses as family-agnostic `IpAddress`es.
    #[cfg(feature = "proto-ipv4")]
    fn addrs() -> (IpAddress, IpAddress) {
        (SRC_ADDR.into(), DST_ADDR.into())
    }

    /// The `Repr` encoded by `PACKET_BYTES`.
    #[cfg(feature = "proto-ipv4")]
    fn expected_repr() -> Repr {
        Repr {
            src_port: 48896,
            dst_port: 53,
        }
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_deconstruct() {
        let (src, dst) = addrs();
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        assert_eq!(packet.src_port(), 48896);
        assert_eq!(packet.dst_port(), 53);
        assert_eq!(packet.len(), 12);
        assert_eq!(packet.checksum(), 0x124d);
        assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_construct() {
        let (src, dst) = addrs();
        let mut bytes = vec![0xa5; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(48896);
        packet.set_dst_port(53);
        packet.set_len(12);
        // A stale checksum value must not leak into the freshly computed one.
        packet.set_checksum(0xffff);
        packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]);
        packet.fill_checksum(&src, &dst);
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    fn test_impossible_len() {
        let mut bytes = vec![0; 12];
        let mut packet = Packet::new_unchecked(&mut bytes);
        // A length field shorter than the header itself can never be valid.
        packet.set_len(4);
        assert_eq!(packet.check_len(), Err(Error));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_zero_checksum() {
        let (src, dst) = addrs();
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.fill_checksum(&src, &dst);
        // A computed checksum of zero is transmitted as all ones.
        assert_eq!(packet.checksum(), 0xffff);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_no_checksum() {
        let (src, dst) = addrs();
        let mut bytes = vec![0; 8];
        let mut packet = Packet::new_unchecked(&mut bytes);
        packet.set_src_port(1);
        packet.set_dst_port(31881);
        packet.set_len(8);
        packet.set_checksum(0);
        assert!(packet.verify_checksum(&src, &dst));
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_parse() {
        let (src, dst) = addrs();
        let packet = Packet::new_unchecked(&PACKET_BYTES[..]);
        let repr = Repr::parse(&packet, &src, &dst, &ChecksumCapabilities::default()).unwrap();
        assert_eq!(repr, expected_repr());
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_emit() {
        let (src, dst) = addrs();
        let repr = expected_repr();
        let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()];
        let mut packet = Packet::new_unchecked(&mut bytes);
        repr.emit(
            &mut packet,
            &src,
            &dst,
            PAYLOAD_BYTES.len(),
            |payload| payload.copy_from_slice(&PAYLOAD_BYTES),
            &ChecksumCapabilities::default(),
        );
        assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]);
    }

    #[test]
    #[cfg(feature = "proto-ipv4")]
    fn test_checksum_omitted() {
        let (src, dst) = addrs();
        let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]);
        let repr = Repr::parse(&packet, &src, &dst, &ChecksumCapabilities::default()).unwrap();
        assert_eq!(repr, expected_repr());
    }
}