1use core::cmp::min;
2#[cfg(feature = "async")]
3use core::task::Waker;
4
5use heapless::Vec;
6use managed::ManagedSlice;
7
8use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
9use crate::socket::{Context, PollAt};
10use crate::time::{Duration, Instant};
11use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
12use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
13
14#[cfg(feature = "async")]
15use super::WakerRegistration;
16
17const DNS_PORT: u16 = 53;
18const MDNS_DNS_PORT: u16 = 5353;
19const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
20const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
21const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); #[cfg(feature = "proto-ipv6")]
24#[allow(unused)]
25const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address::new(
26 0xff02, 0, 0, 0, 0, 0, 0, 0xfb,
27));
28
29#[cfg(feature = "proto-ipv4")]
30#[allow(unused)]
31const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address::new(224, 0, 0, 251));
32
/// Why `start_query` / `start_query_raw` refused to start a query.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// All query slots are in use and the query storage cannot grow.
    NoFreeSlot,
    /// The name is empty, has an empty label, or has a label longer than 63 bytes.
    InvalidName,
    /// The encoded name does not fit in `DNS_MAX_NAME_SIZE` bytes.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match self {
            Self::NoFreeSlot => "No free slot",
            Self::InvalidName => "Invalid name",
            Self::NameTooLong => "Name too long",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for StartQueryError {}
54
/// Why `get_query_result` did not return addresses.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// The query is still in progress; ask again later.
    Pending,
    /// The query finished without producing any address.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match self {
            Self::Pending => "Query is not done yet",
            Self::Failed => "Query failed",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for GetQueryResultError {}
76
77#[derive(Debug)]
82pub struct DnsQuery {
83 state: State,
84
85 #[cfg(feature = "async")]
86 waker: WakerRegistration,
87}
88
89impl DnsQuery {
90 fn set_state(&mut self, state: State) {
91 self.state = state;
92 #[cfg(feature = "async")]
93 self.waker.wake();
94 }
95}
96
97#[derive(Debug)]
98#[allow(clippy::large_enum_variant)]
99enum State {
100 Pending(PendingQuery),
101 Completed(CompletedQuery),
102 Failure,
103}
104
105#[derive(Debug)]
106struct PendingQuery {
107 name: Vec<u8, DNS_MAX_NAME_SIZE>,
108 type_: Type,
109
110 port: u16, txid: u16, timeout_at: Option<Instant>,
114 retransmit_at: Instant,
115 delay: Duration,
116
117 server_idx: usize,
118 mdns: MulticastDns,
119}
120
/// Transport selection for a query.
#[derive(Debug)]
pub enum MulticastDns {
    /// Ask the socket's configured servers on UDP port 53.
    Disabled,
    /// Ask the mDNS multicast group(s) on UDP port 5353.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
127
128#[derive(Debug)]
129struct CompletedQuery {
130 addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
131}
132
/// Handle to a query started on a DNS socket.
///
/// Returned by `start_query` / `start_query_raw` and passed back to
/// `get_query_result`, `cancel_query` and `register_query_waker`. It wraps the index of
/// the query's slot, so it is only meaningful for the socket that issued it.
// `Debug`/`PartialEq`/`Eq`/`Hash` let callers store and compare handles (and derive
// `Debug` on their own structs that hold them).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct QueryHandle(usize);
136
137#[derive(Debug)]
142pub struct Socket<'a> {
143 servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
144 queries: ManagedSlice<'a, Option<DnsQuery>>,
145
146 hop_limit: Option<u8>,
148}
149
150impl<'a> Socket<'a> {
151 pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
155 where
156 Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
157 {
158 let truncated_servers = &servers[..min(servers.len(), DNS_MAX_SERVER_COUNT)];
159
160 Socket {
161 servers: Vec::from_slice(truncated_servers).unwrap(),
162 queries: queries.into(),
163 hop_limit: None,
164 }
165 }
166
167 pub fn update_servers(&mut self, servers: &[IpAddress]) {
171 if servers.len() > DNS_MAX_SERVER_COUNT {
172 net_trace!("Max DNS Servers exceeded. Increase MAX_SERVER_COUNT");
173 self.servers = Vec::from_slice(&servers[..DNS_MAX_SERVER_COUNT]).unwrap();
174 } else {
175 self.servers = Vec::from_slice(servers).unwrap();
176 }
177 }
178
179 pub fn hop_limit(&self) -> Option<u8> {
183 self.hop_limit
184 }
185
186 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
198 if let Some(0) = hop_limit {
200 panic!("the time-to-live value of a packet must not be zero")
201 }
202
203 self.hop_limit = hop_limit
204 }
205
206 fn find_free_query(&mut self) -> Option<QueryHandle> {
207 for (i, q) in self.queries.iter().enumerate() {
208 if q.is_none() {
209 return Some(QueryHandle(i));
210 }
211 }
212
213 match &mut self.queries {
214 ManagedSlice::Borrowed(_) => None,
215 #[cfg(feature = "alloc")]
216 ManagedSlice::Owned(queries) => {
217 queries.push(None);
218 let index = queries.len() - 1;
219 Some(QueryHandle(index))
220 }
221 }
222 }
223
224 pub fn start_query(
230 &mut self,
231 cx: &mut Context,
232 name: &str,
233 query_type: Type,
234 ) -> Result<QueryHandle, StartQueryError> {
235 let mut name = name.as_bytes();
236
237 if name.is_empty() {
238 net_trace!("invalid name: zero length");
239 return Err(StartQueryError::InvalidName);
240 }
241
242 if name[name.len() - 1] == b'.' {
244 name = &name[..name.len() - 1];
245 }
246
247 let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
248
249 let mut mdns = MulticastDns::Disabled;
250 #[cfg(feature = "socket-mdns")]
251 if name.split(|&c| c == b'.').last().unwrap() == b"local" {
252 net_trace!("Starting a mDNS query");
253 mdns = MulticastDns::Enabled;
254 }
255
256 for s in name.split(|&c| c == b'.') {
257 if s.len() > 63 {
258 net_trace!("invalid name: too long label");
259 return Err(StartQueryError::InvalidName);
260 }
261 if s.is_empty() {
262 net_trace!("invalid name: zero length label");
263 return Err(StartQueryError::InvalidName);
264 }
265
266 raw_name
268 .push(s.len() as u8)
269 .map_err(|_| StartQueryError::NameTooLong)?;
270 raw_name
271 .extend_from_slice(s)
272 .map_err(|_| StartQueryError::NameTooLong)?;
273 }
274
275 raw_name
277 .push(0x00)
278 .map_err(|_| StartQueryError::NameTooLong)?;
279
280 self.start_query_raw(cx, &raw_name, query_type, mdns)
281 }
282
283 pub fn start_query_raw(
288 &mut self,
289 cx: &mut Context,
290 raw_name: &[u8],
291 query_type: Type,
292 mdns: MulticastDns,
293 ) -> Result<QueryHandle, StartQueryError> {
294 let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
295
296 self.queries[handle.0] = Some(DnsQuery {
297 state: State::Pending(PendingQuery {
298 name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
299 type_: query_type,
300 txid: cx.rand().rand_u16(),
301 port: cx.rand().rand_source_port(),
302 delay: RETRANSMIT_DELAY,
303 timeout_at: None,
304 retransmit_at: Instant::ZERO,
305 server_idx: 0,
306 mdns,
307 }),
308 #[cfg(feature = "async")]
309 waker: WakerRegistration::new(),
310 });
311 Ok(handle)
312 }
313
314 pub fn get_query_result(
321 &mut self,
322 handle: QueryHandle,
323 ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
324 let slot = &mut self.queries[handle.0];
325 let q = slot.as_mut().unwrap();
326 match &mut q.state {
327 State::Pending(_) => Err(GetQueryResultError::Pending),
329 State::Completed(q) => {
331 let res = q.addresses.clone();
332 *slot = None; Ok(res)
334 }
335 State::Failure => {
336 *slot = None; Err(GetQueryResultError::Failed)
338 }
339 }
340 }
341
342 pub fn cancel_query(&mut self, handle: QueryHandle) {
348 let slot = &mut self.queries[handle.0];
349 if slot.is_none() {
350 panic!("Canceling query in a free slot.")
351 }
352 *slot = None; }
354
355 #[cfg(feature = "async")]
363 pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
364 self.queries[handle.0]
365 .as_mut()
366 .unwrap()
367 .waker
368 .register(waker);
369 }
370
371 pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
372 (udp_repr.src_port == DNS_PORT
373 && self
374 .servers
375 .iter()
376 .any(|server| *server == ip_repr.src_addr()))
377 || (udp_repr.src_port == MDNS_DNS_PORT)
378 }
379
380 pub(crate) fn process(
381 &mut self,
382 _cx: &mut Context,
383 ip_repr: &IpRepr,
384 udp_repr: &UdpRepr,
385 payload: &[u8],
386 ) {
387 debug_assert!(self.accepts(ip_repr, udp_repr));
388
389 let size = payload.len();
390
391 net_trace!(
392 "receiving {} octets from {:?}:{}",
393 size,
394 ip_repr.src_addr(),
395 udp_repr.dst_port
396 );
397
398 let p = match Packet::new_checked(payload) {
399 Ok(x) => x,
400 Err(_) => {
401 net_trace!("dns packet malformed");
402 return;
403 }
404 };
405 if p.opcode() != Opcode::Query {
406 net_trace!("unwanted opcode {:?}", p.opcode());
407 return;
408 }
409
410 if !p.flags().contains(Flags::RESPONSE) {
411 net_trace!("packet doesn't have response bit set");
412 return;
413 }
414
415 if p.question_count() != 1 {
416 net_trace!("bad question count {:?}", p.question_count());
417 return;
418 }
419
420 for q in self.queries.iter_mut().flatten() {
422 if let State::Pending(pq) = &mut q.state {
423 if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
424 continue;
425 }
426
427 if p.rcode() == Rcode::NXDomain {
428 net_trace!("rcode NXDomain");
429 q.set_state(State::Failure);
430 continue;
431 }
432
433 let payload = p.payload();
434 let (mut payload, question) = match Question::parse(payload) {
435 Ok(x) => x,
436 Err(_) => {
437 net_trace!("question malformed");
438 return;
439 }
440 };
441
442 if question.type_ != pq.type_ {
443 net_trace!("question type mismatch");
444 return;
445 }
446
447 match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
448 Ok(true) => {}
449 Ok(false) => {
450 net_trace!("question name mismatch");
451 return;
452 }
453 Err(_) => {
454 net_trace!("dns question name malformed");
455 return;
456 }
457 }
458
459 let mut addresses = Vec::new();
460
461 for _ in 0..p.answer_record_count() {
462 let (payload2, r) = match Record::parse(payload) {
463 Ok(x) => x,
464 Err(_) => {
465 net_trace!("dns answer record malformed");
466 return;
467 }
468 };
469 payload = payload2;
470
471 match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
472 Ok(true) => {}
473 Ok(false) => {
474 net_trace!("answer name mismatch: {:?}", r);
475 continue;
476 }
477 Err(_) => {
478 net_trace!("dns answer record name malformed");
479 return;
480 }
481 }
482
483 match r.data {
484 #[cfg(feature = "proto-ipv4")]
485 RecordData::A(addr) => {
486 net_trace!("A: {:?}", addr);
487 if addresses.push(addr.into()).is_err() {
488 net_trace!("too many addresses in response, ignoring {:?}", addr);
489 }
490 }
491 #[cfg(feature = "proto-ipv6")]
492 RecordData::Aaaa(addr) => {
493 net_trace!("AAAA: {:?}", addr);
494 if addresses.push(addr.into()).is_err() {
495 net_trace!("too many addresses in response, ignoring {:?}", addr);
496 }
497 }
498 RecordData::Cname(name) => {
499 net_trace!("CNAME: {:?}", name);
500
501 if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
509 net_trace!("dns answer cname malformed");
510 return;
511 }
512 }
513 RecordData::Other(type_, data) => {
514 net_trace!("unknown: {:?} {:?}", type_, data)
515 }
516 }
517 }
518
519 q.set_state(if addresses.is_empty() {
520 State::Failure
521 } else {
522 State::Completed(CompletedQuery { addresses })
523 });
524
525 return;
527 }
528 }
529
530 net_trace!("no query matched");
532 }
533
534 pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
535 where
536 F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
537 {
538 let hop_limit = self.hop_limit.unwrap_or(64);
539
540 for q in self.queries.iter_mut().flatten() {
541 if let State::Pending(pq) = &mut q.state {
542 let servers = match pq.mdns {
546 #[cfg(feature = "socket-mdns")]
547 MulticastDns::Enabled => &[
548 #[cfg(feature = "proto-ipv6")]
549 MDNS_IPV6_ADDR,
550 #[cfg(feature = "proto-ipv4")]
551 MDNS_IPV4_ADDR,
552 ],
553 MulticastDns::Disabled => self.servers.as_slice(),
554 };
555
556 let timeout = if let Some(timeout) = pq.timeout_at {
557 timeout
558 } else {
559 let v = cx.now() + RETRANSMIT_TIMEOUT;
560 pq.timeout_at = Some(v);
561 v
562 };
563
564 if timeout < cx.now() {
566 pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
568 pq.retransmit_at = Instant::ZERO;
569 pq.delay = RETRANSMIT_DELAY;
570
571 pq.server_idx += 1;
573 }
574 if pq.server_idx >= servers.len() {
576 net_trace!("already tried all servers.");
577 q.set_state(State::Failure);
578 continue;
579 }
580
581 if servers[pq.server_idx].is_unspecified() {
583 net_trace!("invalid unspecified DNS server addr.");
584 q.set_state(State::Failure);
585 continue;
586 }
587
588 if pq.retransmit_at > cx.now() {
589 continue;
591 }
592
593 let repr = Repr {
594 transaction_id: pq.txid,
595 flags: Flags::RECURSION_DESIRED,
596 opcode: Opcode::Query,
597 question: Question {
598 name: &pq.name,
599 type_: pq.type_,
600 },
601 };
602
603 let mut payload = [0u8; 512];
604 let payload = &mut payload[..repr.buffer_len()];
605 repr.emit(&mut Packet::new_unchecked(payload));
606
607 let dst_port = match pq.mdns {
608 #[cfg(feature = "socket-mdns")]
609 MulticastDns::Enabled => MDNS_DNS_PORT,
610 MulticastDns::Disabled => DNS_PORT,
611 };
612
613 let udp_repr = UdpRepr {
614 src_port: pq.port,
615 dst_port,
616 };
617
618 let dst_addr = servers[pq.server_idx];
619 let src_addr = match cx.get_source_address(&dst_addr) {
620 Some(src_addr) => src_addr,
621 None => {
622 net_trace!("no source address for destination {}", dst_addr);
623 q.set_state(State::Failure);
624 continue;
625 }
626 };
627
628 let ip_repr = IpRepr::new(
629 src_addr,
630 dst_addr,
631 IpProtocol::Udp,
632 udp_repr.header_len() + payload.len(),
633 hop_limit,
634 );
635
636 net_trace!(
637 "sending {} octets to {} from port {}",
638 payload.len(),
639 ip_repr.dst_addr(),
640 udp_repr.src_port
641 );
642
643 emit(cx, (ip_repr, udp_repr, payload))?;
644
645 pq.retransmit_at = cx.now() + pq.delay;
646 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
647
648 return Ok(());
649 }
650 }
651
652 Ok(())
654 }
655
656 pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
657 self.queries
658 .iter()
659 .flatten()
660 .filter_map(|q| match &q.state {
661 State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
662 State::Completed(_) => None,
663 State::Failure => None,
664 })
665 .min()
666 .unwrap_or(PollAt::Ingress)
667 }
668}
669
670fn eq_names<'a>(
671 mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
672 mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
673) -> wire::Result<bool> {
674 loop {
675 match (a.next(), b.next()) {
676 (Some(Err(e)), _) => return Err(e),
678 (_, Some(Err(e))) => return Err(e),
679
680 (None, None) => return Ok(true),
682
683 (None, _) => return Ok(false),
685 (_, None) => return Ok(false),
686
687 (Some(Ok(la)), Some(Ok(lb))) => {
689 if la != lb {
690 return Ok(false);
691 }
692 }
693 }
694 }
695}
696
697fn copy_name<'a, const N: usize>(
698 dest: &mut Vec<u8, N>,
699 name: impl Iterator<Item = wire::Result<&'a [u8]>>,
700) -> Result<(), wire::Error> {
701 dest.truncate(0);
702
703 for label in name {
704 let label = label?;
705 dest.push(label.len() as u8).map_err(|_| wire::Error)?;
706 dest.extend_from_slice(label).map_err(|_| wire::Error)?;
707 }
708
709 dest.push(0).map_err(|_| wire::Error)?;
711
712 Ok(())
713}