smoltcp/socket/
dns.rs

1#[cfg(feature = "async")]
2use core::task::Waker;
3
4use heapless::Vec;
5use managed::ManagedSlice;
6
7use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
8use crate::socket::{Context, PollAt};
9use crate::time::{Duration, Instant};
10use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
11use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
12
13#[cfg(feature = "async")]
14use super::WakerRegistration;
15
/// UDP port used as the destination of unicast DNS queries, and expected as the
/// source port of the responses.
const DNS_PORT: u16 = 53;
/// UDP port used for multicast DNS (mDNS) queries and responses.
const MDNS_DNS_PORT: u16 = 5353;
/// Delay before the first retransmission of a query. The delay is doubled after
/// every transmission, up to `MAX_RETRANSMIT_DELAY`.
const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
/// Upper bound for the retransmission delay.
const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
/// How long a single server is tried before the query moves on to the next one.
const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs

/// IPv6 multicast group mDNS queries are sent to (`ff02::fb`).
#[cfg(feature = "proto-ipv6")]
const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
    0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
]));

/// IPv4 multicast group mDNS queries are sent to (`224.0.0.251`).
#[cfg(feature = "proto-ipv4")]
const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
29
/// Error returned by [`Socket::start_query`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// No query slot is available to hold the new query.
    NoFreeSlot,
    /// The name is empty, or contains an empty label or a label longer than 63 bytes.
    InvalidName,
    /// The wire-format name does not fit into the fixed-size name buffer.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the message first, then emit it with a single write.
        let message = match self {
            Self::NoFreeSlot => "No free slot",
            Self::InvalidName => "Invalid name",
            Self::NameTooLong => "Name too long",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for StartQueryError {}
51
/// Error returned by [`Socket::get_query_result`]
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// Query is not done yet.
    Pending,
    /// Query failed.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the message first, then emit it with a single write.
        let message = match self {
            Self::Pending => "Query is not done yet",
            Self::Failed => "Query failed",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for GetQueryResultError {}
73
/// State for an in-progress DNS query.
///
/// The only reason this struct is public is to allow the socket state
/// to be allocated externally.
#[derive(Debug)]
pub struct DnsQuery {
    /// Where the query currently stands: pending, completed or failed.
    state: State,

    /// Waker notified whenever the state of this query changes.
    #[cfg(feature = "async")]
    waker: WakerRegistration,
}

impl DnsQuery {
    /// Replace the query state and, with the `async` feature, wake the
    /// registered waker so that a task waiting on this query observes the change.
    fn set_state(&mut self, state: State) {
        self.state = state;
        #[cfg(feature = "async")]
        self.waker.wake();
    }
}
93
/// Lifecycle of a single query slot.
#[derive(Debug)]
// The variants differ a lot in size; the lint is silenced rather than boxing a variant.
#[allow(clippy::large_enum_variant)]
enum State {
    /// The query has not been answered yet and may still be (re)transmitted.
    Pending(PendingQuery),
    /// A response carrying at least one address was received.
    Completed(CompletedQuery),
    /// The query failed; see the places that set this state for the possible causes.
    Failure,
}
101
/// Everything needed to (re)transmit a query and to match a response against it.
#[derive(Debug)]
struct PendingQuery {
    /// Wire-format name being asked for. While a response is processed this is
    /// overwritten with the target of any CNAME record encountered.
    name: Vec<u8, DNS_MAX_NAME_SIZE>,
    /// Record type being asked for.
    type_: Type,

    port: u16, // UDP port (src for request, dst for response)
    txid: u16, // transaction ID

    /// When the current server is given up on; `None` until the query is first dispatched.
    timeout_at: Option<Instant>,
    /// Earliest time at which the query may be (re)transmitted.
    retransmit_at: Instant,
    /// Delay applied after the next transmission; doubled on every transmission.
    delay: Duration,

    /// Index of the server currently being tried.
    server_idx: usize,
    /// Whether the query is sent to the mDNS multicast groups instead of the configured servers.
    mdns: MulticastDns,
}
117
/// Selects whether a query is resolved through multicast DNS (mDNS).
#[derive(Debug)]
pub enum MulticastDns {
    /// Send the query to the configured DNS servers.
    Disabled,
    /// Send the query to the mDNS multicast groups. Only available with the
    /// `socket-mdns` feature.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
124
/// Result of a successfully answered query.
#[derive(Debug)]
struct CompletedQuery {
    /// Addresses collected from the response, at most `DNS_MAX_RESULT_COUNT` of them.
    addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
}
129
/// A handle to an in-progress DNS query.
///
/// Wraps the index of the query's slot inside the socket.
#[derive(Clone, Copy)]
pub struct QueryHandle(usize);
133
/// A Domain Name System socket.
///
/// A DNS socket holds a list of DNS servers and a set of query slots. Each slot
/// tracks one query from the moment it is started until its result is fetched
/// or the query is cancelled.
#[derive(Debug)]
pub struct Socket<'a> {
    /// Servers that unicast queries are sent to, tried in order.
    servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
    /// Query slots; `None` marks a free slot.
    queries: ManagedSlice<'a, Option<DnsQuery>>,

    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
    hop_limit: Option<u8>,
}
146
147impl<'a> Socket<'a> {
148    /// Create a DNS socket.
149    ///
150    /// # Panics
151    ///
152    /// Panics if `servers.len() > MAX_SERVER_COUNT`
153    pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
154    where
155        Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
156    {
157        Socket {
158            servers: Vec::from_slice(servers).unwrap(),
159            queries: queries.into(),
160            hop_limit: None,
161        }
162    }
163
164    /// Update the list of DNS servers, will replace all existing servers
165    ///
166    /// # Panics
167    ///
168    /// Panics if `servers.len() > MAX_SERVER_COUNT`
169    pub fn update_servers(&mut self, servers: &[IpAddress]) {
170        self.servers = Vec::from_slice(servers).unwrap();
171    }
172
    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
    ///
    /// `None` means that no value has been set explicitly.
    ///
    /// See also the [set_hop_limit](#method.set_hop_limit) method
    pub fn hop_limit(&self) -> Option<u8> {
        self.hop_limit
    }
179
180    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
181    ///
182    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
183    /// value (64).
184    ///
185    /// # Panics
186    ///
187    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
188    ///
189    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
190    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
191    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
192        // A host MUST NOT send a datagram with a hop limit value of 0
193        if let Some(0) = hop_limit {
194            panic!("the time-to-live value of a packet must not be zero")
195        }
196
197        self.hop_limit = hop_limit
198    }
199
200    fn find_free_query(&mut self) -> Option<QueryHandle> {
201        for (i, q) in self.queries.iter().enumerate() {
202            if q.is_none() {
203                return Some(QueryHandle(i));
204            }
205        }
206
207        match &mut self.queries {
208            ManagedSlice::Borrowed(_) => None,
209            #[cfg(feature = "alloc")]
210            ManagedSlice::Owned(queries) => {
211                queries.push(None);
212                let index = queries.len() - 1;
213                Some(QueryHandle(index))
214            }
215        }
216    }
217
218    /// Start a query.
219    ///
220    /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
221    /// It accepts names both with and without trailing dot, and they're treated
222    /// the same (there's no support for DNS search path).
223    pub fn start_query(
224        &mut self,
225        cx: &mut Context,
226        name: &str,
227        query_type: Type,
228    ) -> Result<QueryHandle, StartQueryError> {
229        let mut name = name.as_bytes();
230
231        if name.is_empty() {
232            net_trace!("invalid name: zero length");
233            return Err(StartQueryError::InvalidName);
234        }
235
236        // Remove trailing dot, if any
237        if name[name.len() - 1] == b'.' {
238            name = &name[..name.len() - 1];
239        }
240
241        let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
242
243        let mut mdns = MulticastDns::Disabled;
244        #[cfg(feature = "socket-mdns")]
245        if name.split(|&c| c == b'.').last().unwrap() == b"local" {
246            net_trace!("Starting a mDNS query");
247            mdns = MulticastDns::Enabled;
248        }
249
250        for s in name.split(|&c| c == b'.') {
251            if s.len() > 63 {
252                net_trace!("invalid name: too long label");
253                return Err(StartQueryError::InvalidName);
254            }
255            if s.is_empty() {
256                net_trace!("invalid name: zero length label");
257                return Err(StartQueryError::InvalidName);
258            }
259
260            // Push label
261            raw_name
262                .push(s.len() as u8)
263                .map_err(|_| StartQueryError::NameTooLong)?;
264            raw_name
265                .extend_from_slice(s)
266                .map_err(|_| StartQueryError::NameTooLong)?;
267        }
268
269        // Push terminator.
270        raw_name
271            .push(0x00)
272            .map_err(|_| StartQueryError::NameTooLong)?;
273
274        self.start_query_raw(cx, &raw_name, query_type, mdns)
275    }
276
277    /// Start a query with a raw (wire-format) DNS name.
278    /// `b"\x09rust-lang\x03org\x00"`
279    ///
280    /// You probably want to use [`start_query`] instead.
281    pub fn start_query_raw(
282        &mut self,
283        cx: &mut Context,
284        raw_name: &[u8],
285        query_type: Type,
286        mdns: MulticastDns,
287    ) -> Result<QueryHandle, StartQueryError> {
288        let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
289
290        self.queries[handle.0] = Some(DnsQuery {
291            state: State::Pending(PendingQuery {
292                name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
293                type_: query_type,
294                txid: cx.rand().rand_u16(),
295                port: cx.rand().rand_source_port(),
296                delay: RETRANSMIT_DELAY,
297                timeout_at: None,
298                retransmit_at: Instant::ZERO,
299                server_idx: 0,
300                mdns,
301            }),
302            #[cfg(feature = "async")]
303            waker: WakerRegistration::new(),
304        });
305        Ok(handle)
306    }
307
308    /// Get the result of a query.
309    ///
310    /// If the query is completed, the query slot is automatically freed.
311    ///
312    /// # Panics
313    /// Panics if the QueryHandle corresponds to a free slot.
314    pub fn get_query_result(
315        &mut self,
316        handle: QueryHandle,
317    ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
318        let slot = &mut self.queries[handle.0];
319        let q = slot.as_mut().unwrap();
320        match &mut q.state {
321            // Query is not done yet.
322            State::Pending(_) => Err(GetQueryResultError::Pending),
323            // Query is done
324            State::Completed(q) => {
325                let res = q.addresses.clone();
326                *slot = None; // Free up the slot for recycling.
327                Ok(res)
328            }
329            State::Failure => {
330                *slot = None; // Free up the slot for recycling.
331                Err(GetQueryResultError::Failed)
332            }
333        }
334    }
335
336    /// Cancels a query, freeing the slot.
337    ///
338    /// # Panics
339    ///
340    /// Panics if the QueryHandle corresponds to an already free slot.
341    pub fn cancel_query(&mut self, handle: QueryHandle) {
342        let slot = &mut self.queries[handle.0];
343        if slot.is_none() {
344            panic!("Canceling query in a free slot.")
345        }
346        *slot = None; // Free up the slot for recycling.
347    }
348
349    /// Assign a waker to a query slot
350    ///
351    /// The waker will be woken when the query completes, either successfully or failed.
352    ///
353    /// # Panics
354    ///
355    /// Panics if the QueryHandle corresponds to an already free slot.
356    #[cfg(feature = "async")]
357    pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
358        self.queries[handle.0]
359            .as_mut()
360            .unwrap()
361            .waker
362            .register(waker);
363    }
364
365    pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
366        (udp_repr.src_port == DNS_PORT
367            && self
368                .servers
369                .iter()
370                .any(|server| *server == ip_repr.src_addr()))
371            || (udp_repr.src_port == MDNS_DNS_PORT)
372    }
373
374    pub(crate) fn process(
375        &mut self,
376        _cx: &mut Context,
377        ip_repr: &IpRepr,
378        udp_repr: &UdpRepr,
379        payload: &[u8],
380    ) {
381        debug_assert!(self.accepts(ip_repr, udp_repr));
382
383        let size = payload.len();
384
385        net_trace!(
386            "receiving {} octets from {:?}:{}",
387            size,
388            ip_repr.src_addr(),
389            udp_repr.dst_port
390        );
391
392        let p = match Packet::new_checked(payload) {
393            Ok(x) => x,
394            Err(_) => {
395                net_trace!("dns packet malformed");
396                return;
397            }
398        };
399        if p.opcode() != Opcode::Query {
400            net_trace!("unwanted opcode {:?}", p.opcode());
401            return;
402        }
403
404        if !p.flags().contains(Flags::RESPONSE) {
405            net_trace!("packet doesn't have response bit set");
406            return;
407        }
408
409        if p.question_count() != 1 {
410            net_trace!("bad question count {:?}", p.question_count());
411            return;
412        }
413
414        // Find pending query
415        for q in self.queries.iter_mut().flatten() {
416            if let State::Pending(pq) = &mut q.state {
417                if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
418                    continue;
419                }
420
421                if p.rcode() == Rcode::NXDomain {
422                    net_trace!("rcode NXDomain");
423                    q.set_state(State::Failure);
424                    continue;
425                }
426
427                let payload = p.payload();
428                let (mut payload, question) = match Question::parse(payload) {
429                    Ok(x) => x,
430                    Err(_) => {
431                        net_trace!("question malformed");
432                        return;
433                    }
434                };
435
436                if question.type_ != pq.type_ {
437                    net_trace!("question type mismatch");
438                    return;
439                }
440
441                match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
442                    Ok(true) => {}
443                    Ok(false) => {
444                        net_trace!("question name mismatch");
445                        return;
446                    }
447                    Err(_) => {
448                        net_trace!("dns question name malformed");
449                        return;
450                    }
451                }
452
453                let mut addresses = Vec::new();
454
455                for _ in 0..p.answer_record_count() {
456                    let (payload2, r) = match Record::parse(payload) {
457                        Ok(x) => x,
458                        Err(_) => {
459                            net_trace!("dns answer record malformed");
460                            return;
461                        }
462                    };
463                    payload = payload2;
464
465                    match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
466                        Ok(true) => {}
467                        Ok(false) => {
468                            net_trace!("answer name mismatch: {:?}", r);
469                            continue;
470                        }
471                        Err(_) => {
472                            net_trace!("dns answer record name malformed");
473                            return;
474                        }
475                    }
476
477                    match r.data {
478                        #[cfg(feature = "proto-ipv4")]
479                        RecordData::A(addr) => {
480                            net_trace!("A: {:?}", addr);
481                            if addresses.push(addr.into()).is_err() {
482                                net_trace!("too many addresses in response, ignoring {:?}", addr);
483                            }
484                        }
485                        #[cfg(feature = "proto-ipv6")]
486                        RecordData::Aaaa(addr) => {
487                            net_trace!("AAAA: {:?}", addr);
488                            if addresses.push(addr.into()).is_err() {
489                                net_trace!("too many addresses in response, ignoring {:?}", addr);
490                            }
491                        }
492                        RecordData::Cname(name) => {
493                            net_trace!("CNAME: {:?}", name);
494
495                            // When faced with a CNAME, recursive resolvers are supposed to
496                            // resolve the CNAME and append the results for it.
497                            //
498                            // We update the query with the new name, so that we pick up the A/AAAA
499                            // records for the CNAME when we parse them later.
500                            // I believe it's mandatory the CNAME results MUST come *after* in the
501                            // packet, so it's enough to do one linear pass over it.
502                            if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
503                                net_trace!("dns answer cname malformed");
504                                return;
505                            }
506                        }
507                        RecordData::Other(type_, data) => {
508                            net_trace!("unknown: {:?} {:?}", type_, data)
509                        }
510                    }
511                }
512
513                q.set_state(if addresses.is_empty() {
514                    State::Failure
515                } else {
516                    State::Completed(CompletedQuery { addresses })
517                });
518
519                // If we get here, packet matched the current query, stop processing.
520                return;
521            }
522        }
523
524        // If we get here, packet matched with no query.
525        net_trace!("no query matched");
526    }
527
528    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
529    where
530        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
531    {
532        let hop_limit = self.hop_limit.unwrap_or(64);
533
534        for q in self.queries.iter_mut().flatten() {
535            if let State::Pending(pq) = &mut q.state {
536                // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
537                // so we internally overwrite the servers for any of those queries
538                // in this function.
539                let servers = match pq.mdns {
540                    #[cfg(feature = "socket-mdns")]
541                    MulticastDns::Enabled => &[
542                        #[cfg(feature = "proto-ipv6")]
543                        MDNS_IPV6_ADDR,
544                        #[cfg(feature = "proto-ipv4")]
545                        MDNS_IPV4_ADDR,
546                    ],
547                    MulticastDns::Disabled => self.servers.as_slice(),
548                };
549
550                let timeout = if let Some(timeout) = pq.timeout_at {
551                    timeout
552                } else {
553                    let v = cx.now() + RETRANSMIT_TIMEOUT;
554                    pq.timeout_at = Some(v);
555                    v
556                };
557
558                // Check timeout
559                if timeout < cx.now() {
560                    // DNS timeout
561                    pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
562                    pq.retransmit_at = Instant::ZERO;
563                    pq.delay = RETRANSMIT_DELAY;
564
565                    // Try next server. We check below whether we've tried all servers.
566                    pq.server_idx += 1;
567                }
568                // Check if we've run out of servers to try.
569                if pq.server_idx >= servers.len() {
570                    net_trace!("already tried all servers.");
571                    q.set_state(State::Failure);
572                    continue;
573                }
574
575                // Check so the IP address is valid
576                if servers[pq.server_idx].is_unspecified() {
577                    net_trace!("invalid unspecified DNS server addr.");
578                    q.set_state(State::Failure);
579                    continue;
580                }
581
582                if pq.retransmit_at > cx.now() {
583                    // query is waiting for retransmit
584                    continue;
585                }
586
587                let repr = Repr {
588                    transaction_id: pq.txid,
589                    flags: Flags::RECURSION_DESIRED,
590                    opcode: Opcode::Query,
591                    question: Question {
592                        name: &pq.name,
593                        type_: pq.type_,
594                    },
595                };
596
597                let mut payload = [0u8; 512];
598                let payload = &mut payload[..repr.buffer_len()];
599                repr.emit(&mut Packet::new_unchecked(payload));
600
601                let dst_port = match pq.mdns {
602                    #[cfg(feature = "socket-mdns")]
603                    MulticastDns::Enabled => MDNS_DNS_PORT,
604                    MulticastDns::Disabled => DNS_PORT,
605                };
606
607                let udp_repr = UdpRepr {
608                    src_port: pq.port,
609                    dst_port,
610                };
611
612                let dst_addr = servers[pq.server_idx];
613                let src_addr = cx.get_source_address(&dst_addr).unwrap(); // TODO remove unwrap
614                let ip_repr = IpRepr::new(
615                    src_addr,
616                    dst_addr,
617                    IpProtocol::Udp,
618                    udp_repr.header_len() + payload.len(),
619                    hop_limit,
620                );
621
622                net_trace!(
623                    "sending {} octets to {} from port {}",
624                    payload.len(),
625                    ip_repr.dst_addr(),
626                    udp_repr.src_port
627                );
628
629                emit(cx, (ip_repr, udp_repr, payload))?;
630
631                pq.retransmit_at = cx.now() + pq.delay;
632                pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
633
634                return Ok(());
635            }
636        }
637
638        // Nothing to dispatch
639        Ok(())
640    }
641
642    pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
643        self.queries
644            .iter()
645            .flatten()
646            .filter_map(|q| match &q.state {
647                State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
648                State::Completed(_) => None,
649                State::Failure => None,
650            })
651            .min()
652            .unwrap_or(PollAt::Ingress)
653    }
654}
655
656fn eq_names<'a>(
657    mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
658    mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
659) -> wire::Result<bool> {
660    loop {
661        match (a.next(), b.next()) {
662            // Handle errors
663            (Some(Err(e)), _) => return Err(e),
664            (_, Some(Err(e))) => return Err(e),
665
666            // Both finished -> equal
667            (None, None) => return Ok(true),
668
669            // One finished before the other -> not equal
670            (None, _) => return Ok(false),
671            (_, None) => return Ok(false),
672
673            // Got two labels, check if they're equal
674            (Some(Ok(la)), Some(Ok(lb))) => {
675                if la != lb {
676                    return Ok(false);
677                }
678            }
679        }
680    }
681}
682
683fn copy_name<'a, const N: usize>(
684    dest: &mut Vec<u8, N>,
685    name: impl Iterator<Item = wire::Result<&'a [u8]>>,
686) -> Result<(), wire::Error> {
687    dest.truncate(0);
688
689    for label in name {
690        let label = label?;
691        dest.push(label.len() as u8).map_err(|_| wire::Error)?;
692        dest.extend_from_slice(label).map_err(|_| wire::Error)?;
693    }
694
695    // Write terminator 0x00
696    dest.push(0).map_err(|_| wire::Error)?;
697
698    Ok(())
699}