smoltcp/socket/
dns.rs

1use core::cmp::min;
2#[cfg(feature = "async")]
3use core::task::Waker;
4
5use heapless::Vec;
6use managed::ManagedSlice;
7
8use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
9use crate::socket::{Context, PollAt};
10use crate::time::{Duration, Instant};
11use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
12use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
13
14#[cfg(feature = "async")]
15use super::WakerRegistration;
16
17const DNS_PORT: u16 = 53;
18const MDNS_DNS_PORT: u16 = 5353;
19const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
20const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
21const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs
22
23#[cfg(feature = "proto-ipv6")]
24#[allow(unused)]
25const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address::new(
26    0xff02, 0, 0, 0, 0, 0, 0, 0xfb,
27));
28
29#[cfg(feature = "proto-ipv4")]
30#[allow(unused)]
31const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address::new(224, 0, 0, 251));
32
/// Error returned by [`Socket::start_query`] and [`Socket::start_query_raw`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// Every query slot is occupied and the slot storage cannot grow.
    NoFreeSlot,
    /// The name is empty, contains an empty label, or has a label longer than 63 bytes.
    InvalidName,
    /// The wire-format name does not fit in `DNS_MAX_NAME_SIZE` bytes.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            Self::NoFreeSlot => "No free slot",
            Self::InvalidName => "Invalid name",
            Self::NameTooLong => "Name too long",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for StartQueryError {}
54
/// Error returned by [`Socket::get_query_result`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// Query is not done yet; its slot is kept, so the call can be repeated.
    Pending,
    /// Query failed; its slot has been freed.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match *self {
            Self::Pending => "Query is not done yet",
            Self::Failed => "Query failed",
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for GetQueryResultError {}
76
77/// State for an in-progress DNS query.
78///
79/// The only reason this struct is public is to allow the socket state
80/// to be allocated externally.
81#[derive(Debug)]
82pub struct DnsQuery {
83    state: State,
84
85    #[cfg(feature = "async")]
86    waker: WakerRegistration,
87}
88
89impl DnsQuery {
90    fn set_state(&mut self, state: State) {
91        self.state = state;
92        #[cfg(feature = "async")]
93        self.waker.wake();
94    }
95}
96
97#[derive(Debug)]
98#[allow(clippy::large_enum_variant)]
99enum State {
100    Pending(PendingQuery),
101    Completed(CompletedQuery),
102    Failure,
103}
104
105#[derive(Debug)]
106struct PendingQuery {
107    name: Vec<u8, DNS_MAX_NAME_SIZE>,
108    type_: Type,
109
110    port: u16, // UDP port (src for request, dst for response)
111    txid: u16, // transaction ID
112
113    timeout_at: Option<Instant>,
114    retransmit_at: Instant,
115    delay: Duration,
116
117    server_idx: usize,
118    mdns: MulticastDns,
119}
120
/// Transport used for a query.
///
/// Passed to [`Socket::start_query_raw`]. [`Socket::start_query`] selects the
/// `Enabled` variant (available with the `socket-mdns` feature) on its own for
/// names whose last label is `local`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MulticastDns {
    /// Send the query as unicast DNS to the socket's configured servers (port 53).
    Disabled,
    /// Send the query as multicast DNS (RFC 6762) to the mDNS group addresses
    /// (port 5353), ignoring the configured servers.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
127
128#[derive(Debug)]
129struct CompletedQuery {
130    addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
131}
132
/// A handle to an in-progress DNS query.
///
/// Returned by [`Socket::start_query`] / [`Socket::start_query_raw`] and passed
/// back to [`Socket::get_query_result`] or [`Socket::cancel_query`]. It names a
/// slot in the socket's query storage and must not be used after that slot has
/// been freed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct QueryHandle(usize);
136
137/// A Domain Name System socket.
138///
139/// A UDP socket is bound to a specific endpoint, and owns transmit and receive
140/// packet buffers.
141#[derive(Debug)]
142pub struct Socket<'a> {
143    servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
144    queries: ManagedSlice<'a, Option<DnsQuery>>,
145
146    /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
147    hop_limit: Option<u8>,
148}
149
150impl<'a> Socket<'a> {
151    /// Create a DNS socket.
152    ///
153    /// Truncates the server list if `servers.len() > MAX_SERVER_COUNT`
154    pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
155    where
156        Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
157    {
158        let truncated_servers = &servers[..min(servers.len(), DNS_MAX_SERVER_COUNT)];
159
160        Socket {
161            servers: Vec::from_slice(truncated_servers).unwrap(),
162            queries: queries.into(),
163            hop_limit: None,
164        }
165    }
166
167    /// Update the list of DNS servers, will replace all existing servers
168    ///
169    /// Truncates the server list if `servers.len() > MAX_SERVER_COUNT`
170    pub fn update_servers(&mut self, servers: &[IpAddress]) {
171        if servers.len() > DNS_MAX_SERVER_COUNT {
172            net_trace!("Max DNS Servers exceeded. Increase MAX_SERVER_COUNT");
173            self.servers = Vec::from_slice(&servers[..DNS_MAX_SERVER_COUNT]).unwrap();
174        } else {
175            self.servers = Vec::from_slice(servers).unwrap();
176        }
177    }
178
179    /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
180    ///
181    /// See also the [set_hop_limit](#method.set_hop_limit) method
182    pub fn hop_limit(&self) -> Option<u8> {
183        self.hop_limit
184    }
185
186    /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets.
187    ///
188    /// A socket without an explicitly set hop limit value uses the default [IANA recommended]
189    /// value (64).
190    ///
191    /// # Panics
192    ///
193    /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7].
194    ///
195    /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml
196    /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7
197    pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
198        // A host MUST NOT send a datagram with a hop limit value of 0
199        if let Some(0) = hop_limit {
200            panic!("the time-to-live value of a packet must not be zero")
201        }
202
203        self.hop_limit = hop_limit
204    }
205
206    fn find_free_query(&mut self) -> Option<QueryHandle> {
207        for (i, q) in self.queries.iter().enumerate() {
208            if q.is_none() {
209                return Some(QueryHandle(i));
210            }
211        }
212
213        match &mut self.queries {
214            ManagedSlice::Borrowed(_) => None,
215            #[cfg(feature = "alloc")]
216            ManagedSlice::Owned(queries) => {
217                queries.push(None);
218                let index = queries.len() - 1;
219                Some(QueryHandle(index))
220            }
221        }
222    }
223
224    /// Start a query.
225    ///
226    /// `name` is specified in human-friendly format, such as `"rust-lang.org"`.
227    /// It accepts names both with and without trailing dot, and they're treated
228    /// the same (there's no support for DNS search path).
229    pub fn start_query(
230        &mut self,
231        cx: &mut Context,
232        name: &str,
233        query_type: Type,
234    ) -> Result<QueryHandle, StartQueryError> {
235        let mut name = name.as_bytes();
236
237        if name.is_empty() {
238            net_trace!("invalid name: zero length");
239            return Err(StartQueryError::InvalidName);
240        }
241
242        // Remove trailing dot, if any
243        if name[name.len() - 1] == b'.' {
244            name = &name[..name.len() - 1];
245        }
246
247        let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
248
249        let mut mdns = MulticastDns::Disabled;
250        #[cfg(feature = "socket-mdns")]
251        if name.split(|&c| c == b'.').last().unwrap() == b"local" {
252            net_trace!("Starting a mDNS query");
253            mdns = MulticastDns::Enabled;
254        }
255
256        for s in name.split(|&c| c == b'.') {
257            if s.len() > 63 {
258                net_trace!("invalid name: too long label");
259                return Err(StartQueryError::InvalidName);
260            }
261            if s.is_empty() {
262                net_trace!("invalid name: zero length label");
263                return Err(StartQueryError::InvalidName);
264            }
265
266            // Push label
267            raw_name
268                .push(s.len() as u8)
269                .map_err(|_| StartQueryError::NameTooLong)?;
270            raw_name
271                .extend_from_slice(s)
272                .map_err(|_| StartQueryError::NameTooLong)?;
273        }
274
275        // Push terminator.
276        raw_name
277            .push(0x00)
278            .map_err(|_| StartQueryError::NameTooLong)?;
279
280        self.start_query_raw(cx, &raw_name, query_type, mdns)
281    }
282
283    /// Start a query with a raw (wire-format) DNS name.
284    /// `b"\x09rust-lang\x03org\x00"`
285    ///
286    /// You probably want to use [`start_query`] instead.
287    pub fn start_query_raw(
288        &mut self,
289        cx: &mut Context,
290        raw_name: &[u8],
291        query_type: Type,
292        mdns: MulticastDns,
293    ) -> Result<QueryHandle, StartQueryError> {
294        let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
295
296        self.queries[handle.0] = Some(DnsQuery {
297            state: State::Pending(PendingQuery {
298                name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
299                type_: query_type,
300                txid: cx.rand().rand_u16(),
301                port: cx.rand().rand_source_port(),
302                delay: RETRANSMIT_DELAY,
303                timeout_at: None,
304                retransmit_at: Instant::ZERO,
305                server_idx: 0,
306                mdns,
307            }),
308            #[cfg(feature = "async")]
309            waker: WakerRegistration::new(),
310        });
311        Ok(handle)
312    }
313
314    /// Get the result of a query.
315    ///
316    /// If the query is completed, the query slot is automatically freed.
317    ///
318    /// # Panics
319    /// Panics if the QueryHandle corresponds to a free slot.
320    pub fn get_query_result(
321        &mut self,
322        handle: QueryHandle,
323    ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
324        let slot = &mut self.queries[handle.0];
325        let q = slot.as_mut().unwrap();
326        match &mut q.state {
327            // Query is not done yet.
328            State::Pending(_) => Err(GetQueryResultError::Pending),
329            // Query is done
330            State::Completed(q) => {
331                let res = q.addresses.clone();
332                *slot = None; // Free up the slot for recycling.
333                Ok(res)
334            }
335            State::Failure => {
336                *slot = None; // Free up the slot for recycling.
337                Err(GetQueryResultError::Failed)
338            }
339        }
340    }
341
342    /// Cancels a query, freeing the slot.
343    ///
344    /// # Panics
345    ///
346    /// Panics if the QueryHandle corresponds to an already free slot.
347    pub fn cancel_query(&mut self, handle: QueryHandle) {
348        let slot = &mut self.queries[handle.0];
349        if slot.is_none() {
350            panic!("Canceling query in a free slot.")
351        }
352        *slot = None; // Free up the slot for recycling.
353    }
354
355    /// Assign a waker to a query slot
356    ///
357    /// The waker will be woken when the query completes, either successfully or failed.
358    ///
359    /// # Panics
360    ///
361    /// Panics if the QueryHandle corresponds to an already free slot.
362    #[cfg(feature = "async")]
363    pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
364        self.queries[handle.0]
365            .as_mut()
366            .unwrap()
367            .waker
368            .register(waker);
369    }
370
371    pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
372        (udp_repr.src_port == DNS_PORT
373            && self
374                .servers
375                .iter()
376                .any(|server| *server == ip_repr.src_addr()))
377            || (udp_repr.src_port == MDNS_DNS_PORT)
378    }
379
380    pub(crate) fn process(
381        &mut self,
382        _cx: &mut Context,
383        ip_repr: &IpRepr,
384        udp_repr: &UdpRepr,
385        payload: &[u8],
386    ) {
387        debug_assert!(self.accepts(ip_repr, udp_repr));
388
389        let size = payload.len();
390
391        net_trace!(
392            "receiving {} octets from {:?}:{}",
393            size,
394            ip_repr.src_addr(),
395            udp_repr.dst_port
396        );
397
398        let p = match Packet::new_checked(payload) {
399            Ok(x) => x,
400            Err(_) => {
401                net_trace!("dns packet malformed");
402                return;
403            }
404        };
405        if p.opcode() != Opcode::Query {
406            net_trace!("unwanted opcode {:?}", p.opcode());
407            return;
408        }
409
410        if !p.flags().contains(Flags::RESPONSE) {
411            net_trace!("packet doesn't have response bit set");
412            return;
413        }
414
415        if p.question_count() != 1 {
416            net_trace!("bad question count {:?}", p.question_count());
417            return;
418        }
419
420        // Find pending query
421        for q in self.queries.iter_mut().flatten() {
422            if let State::Pending(pq) = &mut q.state {
423                if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
424                    continue;
425                }
426
427                if p.rcode() == Rcode::NXDomain {
428                    net_trace!("rcode NXDomain");
429                    q.set_state(State::Failure);
430                    continue;
431                }
432
433                let payload = p.payload();
434                let (mut payload, question) = match Question::parse(payload) {
435                    Ok(x) => x,
436                    Err(_) => {
437                        net_trace!("question malformed");
438                        return;
439                    }
440                };
441
442                if question.type_ != pq.type_ {
443                    net_trace!("question type mismatch");
444                    return;
445                }
446
447                match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
448                    Ok(true) => {}
449                    Ok(false) => {
450                        net_trace!("question name mismatch");
451                        return;
452                    }
453                    Err(_) => {
454                        net_trace!("dns question name malformed");
455                        return;
456                    }
457                }
458
459                let mut addresses = Vec::new();
460
461                for _ in 0..p.answer_record_count() {
462                    let (payload2, r) = match Record::parse(payload) {
463                        Ok(x) => x,
464                        Err(_) => {
465                            net_trace!("dns answer record malformed");
466                            return;
467                        }
468                    };
469                    payload = payload2;
470
471                    match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
472                        Ok(true) => {}
473                        Ok(false) => {
474                            net_trace!("answer name mismatch: {:?}", r);
475                            continue;
476                        }
477                        Err(_) => {
478                            net_trace!("dns answer record name malformed");
479                            return;
480                        }
481                    }
482
483                    match r.data {
484                        #[cfg(feature = "proto-ipv4")]
485                        RecordData::A(addr) => {
486                            net_trace!("A: {:?}", addr);
487                            if addresses.push(addr.into()).is_err() {
488                                net_trace!("too many addresses in response, ignoring {:?}", addr);
489                            }
490                        }
491                        #[cfg(feature = "proto-ipv6")]
492                        RecordData::Aaaa(addr) => {
493                            net_trace!("AAAA: {:?}", addr);
494                            if addresses.push(addr.into()).is_err() {
495                                net_trace!("too many addresses in response, ignoring {:?}", addr);
496                            }
497                        }
498                        RecordData::Cname(name) => {
499                            net_trace!("CNAME: {:?}", name);
500
501                            // When faced with a CNAME, recursive resolvers are supposed to
502                            // resolve the CNAME and append the results for it.
503                            //
504                            // We update the query with the new name, so that we pick up the A/AAAA
505                            // records for the CNAME when we parse them later.
506                            // I believe it's mandatory the CNAME results MUST come *after* in the
507                            // packet, so it's enough to do one linear pass over it.
508                            if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
509                                net_trace!("dns answer cname malformed");
510                                return;
511                            }
512                        }
513                        RecordData::Other(type_, data) => {
514                            net_trace!("unknown: {:?} {:?}", type_, data)
515                        }
516                    }
517                }
518
519                q.set_state(if addresses.is_empty() {
520                    State::Failure
521                } else {
522                    State::Completed(CompletedQuery { addresses })
523                });
524
525                // If we get here, packet matched the current query, stop processing.
526                return;
527            }
528        }
529
530        // If we get here, packet matched with no query.
531        net_trace!("no query matched");
532    }
533
534    pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
535    where
536        F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
537    {
538        let hop_limit = self.hop_limit.unwrap_or(64);
539
540        for q in self.queries.iter_mut().flatten() {
541            if let State::Pending(pq) = &mut q.state {
542                // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns
543                // so we internally overwrite the servers for any of those queries
544                // in this function.
545                let servers = match pq.mdns {
546                    #[cfg(feature = "socket-mdns")]
547                    MulticastDns::Enabled => &[
548                        #[cfg(feature = "proto-ipv6")]
549                        MDNS_IPV6_ADDR,
550                        #[cfg(feature = "proto-ipv4")]
551                        MDNS_IPV4_ADDR,
552                    ],
553                    MulticastDns::Disabled => self.servers.as_slice(),
554                };
555
556                let timeout = if let Some(timeout) = pq.timeout_at {
557                    timeout
558                } else {
559                    let v = cx.now() + RETRANSMIT_TIMEOUT;
560                    pq.timeout_at = Some(v);
561                    v
562                };
563
564                // Check timeout
565                if timeout < cx.now() {
566                    // DNS timeout
567                    pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
568                    pq.retransmit_at = Instant::ZERO;
569                    pq.delay = RETRANSMIT_DELAY;
570
571                    // Try next server. We check below whether we've tried all servers.
572                    pq.server_idx += 1;
573                }
574                // Check if we've run out of servers to try.
575                if pq.server_idx >= servers.len() {
576                    net_trace!("already tried all servers.");
577                    q.set_state(State::Failure);
578                    continue;
579                }
580
581                // Check so the IP address is valid
582                if servers[pq.server_idx].is_unspecified() {
583                    net_trace!("invalid unspecified DNS server addr.");
584                    q.set_state(State::Failure);
585                    continue;
586                }
587
588                if pq.retransmit_at > cx.now() {
589                    // query is waiting for retransmit
590                    continue;
591                }
592
593                let repr = Repr {
594                    transaction_id: pq.txid,
595                    flags: Flags::RECURSION_DESIRED,
596                    opcode: Opcode::Query,
597                    question: Question {
598                        name: &pq.name,
599                        type_: pq.type_,
600                    },
601                };
602
603                let mut payload = [0u8; 512];
604                let payload = &mut payload[..repr.buffer_len()];
605                repr.emit(&mut Packet::new_unchecked(payload));
606
607                let dst_port = match pq.mdns {
608                    #[cfg(feature = "socket-mdns")]
609                    MulticastDns::Enabled => MDNS_DNS_PORT,
610                    MulticastDns::Disabled => DNS_PORT,
611                };
612
613                let udp_repr = UdpRepr {
614                    src_port: pq.port,
615                    dst_port,
616                };
617
618                let dst_addr = servers[pq.server_idx];
619                let src_addr = match cx.get_source_address(&dst_addr) {
620                    Some(src_addr) => src_addr,
621                    None => {
622                        net_trace!("no source address for destination {}", dst_addr);
623                        q.set_state(State::Failure);
624                        continue;
625                    }
626                };
627
628                let ip_repr = IpRepr::new(
629                    src_addr,
630                    dst_addr,
631                    IpProtocol::Udp,
632                    udp_repr.header_len() + payload.len(),
633                    hop_limit,
634                );
635
636                net_trace!(
637                    "sending {} octets to {} from port {}",
638                    payload.len(),
639                    ip_repr.dst_addr(),
640                    udp_repr.src_port
641                );
642
643                emit(cx, (ip_repr, udp_repr, payload))?;
644
645                pq.retransmit_at = cx.now() + pq.delay;
646                pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
647
648                return Ok(());
649            }
650        }
651
652        // Nothing to dispatch
653        Ok(())
654    }
655
656    pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
657        self.queries
658            .iter()
659            .flatten()
660            .filter_map(|q| match &q.state {
661                State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
662                State::Completed(_) => None,
663                State::Failure => None,
664            })
665            .min()
666            .unwrap_or(PollAt::Ingress)
667    }
668}
669
/// Compare two DNS names given as iterators of labels.
///
/// Labels are compared ASCII case-insensitively, as DNS name comparison
/// requires (RFC 1035 §2.3.3, RFC 4343). The first error yielded by either
/// iterator is returned. Names of different label counts are unequal.
///
/// Generic over the error type. With the `wire::Result` iterators produced by
/// `Packet::parse_name`, this returns a `wire::Result<bool>`.
fn eq_names<'a, E>(
    mut a: impl Iterator<Item = Result<&'a [u8], E>>,
    mut b: impl Iterator<Item = Result<&'a [u8], E>>,
) -> Result<bool, E> {
    loop {
        match (a.next(), b.next()) {
            // Handle errors (those from `a` take precedence)
            (Some(Err(e)), _) | (_, Some(Err(e))) => return Err(e),

            // Both finished -> equal
            (None, None) => return Ok(true),

            // One finished before the other -> not equal
            (None, _) | (_, None) => return Ok(false),

            // Got two labels, check if they're equal ignoring ASCII case
            (Some(Ok(la)), Some(Ok(lb))) => {
                if !la.eq_ignore_ascii_case(lb) {
                    return Ok(false);
                }
            }
        }
    }
}
696
697fn copy_name<'a, const N: usize>(
698    dest: &mut Vec<u8, N>,
699    name: impl Iterator<Item = wire::Result<&'a [u8]>>,
700) -> Result<(), wire::Error> {
701    dest.truncate(0);
702
703    for label in name {
704        let label = label?;
705        dest.push(label.len() as u8).map_err(|_| wire::Error)?;
706        dest.extend_from_slice(label).map_err(|_| wire::Error)?;
707    }
708
709    // Write terminator 0x00
710    dest.push(0).map_err(|_| wire::Error)?;
711
712    Ok(())
713}