1#[cfg(feature = "async")]
2use core::task::Waker;
3
4use heapless::Vec;
5use managed::ManagedSlice;
6
7use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT};
8use crate::socket::{Context, PollAt};
9use crate::time::{Duration, Instant};
10use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type};
11use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr};
12
13#[cfg(feature = "async")]
14use super::WakerRegistration;
15
16const DNS_PORT: u16 = 53;
17const MDNS_DNS_PORT: u16 = 5353;
18const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000);
19const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000);
20const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); #[cfg(feature = "proto-ipv6")]
23const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([
24 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb,
25]));
26
27#[cfg(feature = "proto-ipv4")]
28const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251]));
29
/// Reasons why a DNS query could not be started.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StartQueryError {
    /// Every query slot is occupied and no new one can be allocated.
    NoFreeSlot,
    /// The name is empty, has an empty label, or has a label longer than 63 bytes.
    InvalidName,
    /// The encoded name does not fit into the name buffer.
    NameTooLong,
}

impl core::fmt::Display for StartQueryError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = match *self {
            Self::NoFreeSlot => "No free slot",
            Self::InvalidName => "Invalid name",
            Self::NameTooLong => "Name too long",
        };
        f.write_str(text)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for StartQueryError {}
51
/// Reasons why no result could be returned for a query.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum GetQueryResultError {
    /// The query has not finished yet; its slot stays allocated.
    Pending,
    /// The query finished without a usable answer; its slot has been freed.
    Failed,
}

impl core::fmt::Display for GetQueryResultError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(match *self {
            Self::Pending => "Query is not done yet",
            Self::Failed => "Query failed",
        })
    }
}

#[cfg(feature = "std")]
impl std::error::Error for GetQueryResultError {}
73
74#[derive(Debug)]
79pub struct DnsQuery {
80 state: State,
81
82 #[cfg(feature = "async")]
83 waker: WakerRegistration,
84}
85
86impl DnsQuery {
87 fn set_state(&mut self, state: State) {
88 self.state = state;
89 #[cfg(feature = "async")]
90 self.waker.wake();
91 }
92}
93
94#[derive(Debug)]
95#[allow(clippy::large_enum_variant)]
96enum State {
97 Pending(PendingQuery),
98 Completed(CompletedQuery),
99 Failure,
100}
101
102#[derive(Debug)]
103struct PendingQuery {
104 name: Vec<u8, DNS_MAX_NAME_SIZE>,
105 type_: Type,
106
107 port: u16, txid: u16, timeout_at: Option<Instant>,
111 retransmit_at: Instant,
112 delay: Duration,
113
114 server_idx: usize,
115 mdns: MulticastDns,
116}
117
/// Chooses where a query is sent.
#[derive(Debug)]
pub enum MulticastDns {
    /// Send to the socket's configured unicast servers on port 53.
    Disabled,
    /// Send to the mDNS multicast groups on port 5353.
    #[cfg(feature = "socket-mdns")]
    Enabled,
}
124
125#[derive(Debug)]
126struct CompletedQuery {
127 addresses: Vec<IpAddress, DNS_MAX_RESULT_COUNT>,
128}
129
/// Opaque handle to a query slot of a DNS [`Socket`].
///
/// Returned when a query is started and passed back to fetch its result, cancel it, or
/// register a waker. `Debug`, `PartialEq`, `Eq` and `Hash` are derived so callers can log
/// handles and keep them in collections alongside their own bookkeeping.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QueryHandle(usize);
133
134#[derive(Debug)]
139pub struct Socket<'a> {
140 servers: Vec<IpAddress, DNS_MAX_SERVER_COUNT>,
141 queries: ManagedSlice<'a, Option<DnsQuery>>,
142
143 hop_limit: Option<u8>,
145}
146
147impl<'a> Socket<'a> {
148 pub fn new<Q>(servers: &[IpAddress], queries: Q) -> Socket<'a>
154 where
155 Q: Into<ManagedSlice<'a, Option<DnsQuery>>>,
156 {
157 Socket {
158 servers: Vec::from_slice(servers).unwrap(),
159 queries: queries.into(),
160 hop_limit: None,
161 }
162 }
163
164 pub fn update_servers(&mut self, servers: &[IpAddress]) {
170 self.servers = Vec::from_slice(servers).unwrap();
171 }
172
173 pub fn hop_limit(&self) -> Option<u8> {
177 self.hop_limit
178 }
179
180 pub fn set_hop_limit(&mut self, hop_limit: Option<u8>) {
192 if let Some(0) = hop_limit {
194 panic!("the time-to-live value of a packet must not be zero")
195 }
196
197 self.hop_limit = hop_limit
198 }
199
200 fn find_free_query(&mut self) -> Option<QueryHandle> {
201 for (i, q) in self.queries.iter().enumerate() {
202 if q.is_none() {
203 return Some(QueryHandle(i));
204 }
205 }
206
207 match &mut self.queries {
208 ManagedSlice::Borrowed(_) => None,
209 #[cfg(feature = "alloc")]
210 ManagedSlice::Owned(queries) => {
211 queries.push(None);
212 let index = queries.len() - 1;
213 Some(QueryHandle(index))
214 }
215 }
216 }
217
218 pub fn start_query(
224 &mut self,
225 cx: &mut Context,
226 name: &str,
227 query_type: Type,
228 ) -> Result<QueryHandle, StartQueryError> {
229 let mut name = name.as_bytes();
230
231 if name.is_empty() {
232 net_trace!("invalid name: zero length");
233 return Err(StartQueryError::InvalidName);
234 }
235
236 if name[name.len() - 1] == b'.' {
238 name = &name[..name.len() - 1];
239 }
240
241 let mut raw_name: Vec<u8, DNS_MAX_NAME_SIZE> = Vec::new();
242
243 let mut mdns = MulticastDns::Disabled;
244 #[cfg(feature = "socket-mdns")]
245 if name.split(|&c| c == b'.').last().unwrap() == b"local" {
246 net_trace!("Starting a mDNS query");
247 mdns = MulticastDns::Enabled;
248 }
249
250 for s in name.split(|&c| c == b'.') {
251 if s.len() > 63 {
252 net_trace!("invalid name: too long label");
253 return Err(StartQueryError::InvalidName);
254 }
255 if s.is_empty() {
256 net_trace!("invalid name: zero length label");
257 return Err(StartQueryError::InvalidName);
258 }
259
260 raw_name
262 .push(s.len() as u8)
263 .map_err(|_| StartQueryError::NameTooLong)?;
264 raw_name
265 .extend_from_slice(s)
266 .map_err(|_| StartQueryError::NameTooLong)?;
267 }
268
269 raw_name
271 .push(0x00)
272 .map_err(|_| StartQueryError::NameTooLong)?;
273
274 self.start_query_raw(cx, &raw_name, query_type, mdns)
275 }
276
277 pub fn start_query_raw(
282 &mut self,
283 cx: &mut Context,
284 raw_name: &[u8],
285 query_type: Type,
286 mdns: MulticastDns,
287 ) -> Result<QueryHandle, StartQueryError> {
288 let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?;
289
290 self.queries[handle.0] = Some(DnsQuery {
291 state: State::Pending(PendingQuery {
292 name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?,
293 type_: query_type,
294 txid: cx.rand().rand_u16(),
295 port: cx.rand().rand_source_port(),
296 delay: RETRANSMIT_DELAY,
297 timeout_at: None,
298 retransmit_at: Instant::ZERO,
299 server_idx: 0,
300 mdns,
301 }),
302 #[cfg(feature = "async")]
303 waker: WakerRegistration::new(),
304 });
305 Ok(handle)
306 }
307
308 pub fn get_query_result(
315 &mut self,
316 handle: QueryHandle,
317 ) -> Result<Vec<IpAddress, DNS_MAX_RESULT_COUNT>, GetQueryResultError> {
318 let slot = &mut self.queries[handle.0];
319 let q = slot.as_mut().unwrap();
320 match &mut q.state {
321 State::Pending(_) => Err(GetQueryResultError::Pending),
323 State::Completed(q) => {
325 let res = q.addresses.clone();
326 *slot = None; Ok(res)
328 }
329 State::Failure => {
330 *slot = None; Err(GetQueryResultError::Failed)
332 }
333 }
334 }
335
336 pub fn cancel_query(&mut self, handle: QueryHandle) {
342 let slot = &mut self.queries[handle.0];
343 if slot.is_none() {
344 panic!("Canceling query in a free slot.")
345 }
346 *slot = None; }
348
349 #[cfg(feature = "async")]
357 pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) {
358 self.queries[handle.0]
359 .as_mut()
360 .unwrap()
361 .waker
362 .register(waker);
363 }
364
365 pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool {
366 (udp_repr.src_port == DNS_PORT
367 && self
368 .servers
369 .iter()
370 .any(|server| *server == ip_repr.src_addr()))
371 || (udp_repr.src_port == MDNS_DNS_PORT)
372 }
373
374 pub(crate) fn process(
375 &mut self,
376 _cx: &mut Context,
377 ip_repr: &IpRepr,
378 udp_repr: &UdpRepr,
379 payload: &[u8],
380 ) {
381 debug_assert!(self.accepts(ip_repr, udp_repr));
382
383 let size = payload.len();
384
385 net_trace!(
386 "receiving {} octets from {:?}:{}",
387 size,
388 ip_repr.src_addr(),
389 udp_repr.dst_port
390 );
391
392 let p = match Packet::new_checked(payload) {
393 Ok(x) => x,
394 Err(_) => {
395 net_trace!("dns packet malformed");
396 return;
397 }
398 };
399 if p.opcode() != Opcode::Query {
400 net_trace!("unwanted opcode {:?}", p.opcode());
401 return;
402 }
403
404 if !p.flags().contains(Flags::RESPONSE) {
405 net_trace!("packet doesn't have response bit set");
406 return;
407 }
408
409 if p.question_count() != 1 {
410 net_trace!("bad question count {:?}", p.question_count());
411 return;
412 }
413
414 for q in self.queries.iter_mut().flatten() {
416 if let State::Pending(pq) = &mut q.state {
417 if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid {
418 continue;
419 }
420
421 if p.rcode() == Rcode::NXDomain {
422 net_trace!("rcode NXDomain");
423 q.set_state(State::Failure);
424 continue;
425 }
426
427 let payload = p.payload();
428 let (mut payload, question) = match Question::parse(payload) {
429 Ok(x) => x,
430 Err(_) => {
431 net_trace!("question malformed");
432 return;
433 }
434 };
435
436 if question.type_ != pq.type_ {
437 net_trace!("question type mismatch");
438 return;
439 }
440
441 match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) {
442 Ok(true) => {}
443 Ok(false) => {
444 net_trace!("question name mismatch");
445 return;
446 }
447 Err(_) => {
448 net_trace!("dns question name malformed");
449 return;
450 }
451 }
452
453 let mut addresses = Vec::new();
454
455 for _ in 0..p.answer_record_count() {
456 let (payload2, r) = match Record::parse(payload) {
457 Ok(x) => x,
458 Err(_) => {
459 net_trace!("dns answer record malformed");
460 return;
461 }
462 };
463 payload = payload2;
464
465 match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) {
466 Ok(true) => {}
467 Ok(false) => {
468 net_trace!("answer name mismatch: {:?}", r);
469 continue;
470 }
471 Err(_) => {
472 net_trace!("dns answer record name malformed");
473 return;
474 }
475 }
476
477 match r.data {
478 #[cfg(feature = "proto-ipv4")]
479 RecordData::A(addr) => {
480 net_trace!("A: {:?}", addr);
481 if addresses.push(addr.into()).is_err() {
482 net_trace!("too many addresses in response, ignoring {:?}", addr);
483 }
484 }
485 #[cfg(feature = "proto-ipv6")]
486 RecordData::Aaaa(addr) => {
487 net_trace!("AAAA: {:?}", addr);
488 if addresses.push(addr.into()).is_err() {
489 net_trace!("too many addresses in response, ignoring {:?}", addr);
490 }
491 }
492 RecordData::Cname(name) => {
493 net_trace!("CNAME: {:?}", name);
494
495 if copy_name(&mut pq.name, p.parse_name(name)).is_err() {
503 net_trace!("dns answer cname malformed");
504 return;
505 }
506 }
507 RecordData::Other(type_, data) => {
508 net_trace!("unknown: {:?} {:?}", type_, data)
509 }
510 }
511 }
512
513 q.set_state(if addresses.is_empty() {
514 State::Failure
515 } else {
516 State::Completed(CompletedQuery { addresses })
517 });
518
519 return;
521 }
522 }
523
524 net_trace!("no query matched");
526 }
527
528 pub(crate) fn dispatch<F, E>(&mut self, cx: &mut Context, emit: F) -> Result<(), E>
529 where
530 F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>,
531 {
532 let hop_limit = self.hop_limit.unwrap_or(64);
533
534 for q in self.queries.iter_mut().flatten() {
535 if let State::Pending(pq) = &mut q.state {
536 let servers = match pq.mdns {
540 #[cfg(feature = "socket-mdns")]
541 MulticastDns::Enabled => &[
542 #[cfg(feature = "proto-ipv6")]
543 MDNS_IPV6_ADDR,
544 #[cfg(feature = "proto-ipv4")]
545 MDNS_IPV4_ADDR,
546 ],
547 MulticastDns::Disabled => self.servers.as_slice(),
548 };
549
550 let timeout = if let Some(timeout) = pq.timeout_at {
551 timeout
552 } else {
553 let v = cx.now() + RETRANSMIT_TIMEOUT;
554 pq.timeout_at = Some(v);
555 v
556 };
557
558 if timeout < cx.now() {
560 pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT);
562 pq.retransmit_at = Instant::ZERO;
563 pq.delay = RETRANSMIT_DELAY;
564
565 pq.server_idx += 1;
567 }
568 if pq.server_idx >= servers.len() {
570 net_trace!("already tried all servers.");
571 q.set_state(State::Failure);
572 continue;
573 }
574
575 if servers[pq.server_idx].is_unspecified() {
577 net_trace!("invalid unspecified DNS server addr.");
578 q.set_state(State::Failure);
579 continue;
580 }
581
582 if pq.retransmit_at > cx.now() {
583 continue;
585 }
586
587 let repr = Repr {
588 transaction_id: pq.txid,
589 flags: Flags::RECURSION_DESIRED,
590 opcode: Opcode::Query,
591 question: Question {
592 name: &pq.name,
593 type_: pq.type_,
594 },
595 };
596
597 let mut payload = [0u8; 512];
598 let payload = &mut payload[..repr.buffer_len()];
599 repr.emit(&mut Packet::new_unchecked(payload));
600
601 let dst_port = match pq.mdns {
602 #[cfg(feature = "socket-mdns")]
603 MulticastDns::Enabled => MDNS_DNS_PORT,
604 MulticastDns::Disabled => DNS_PORT,
605 };
606
607 let udp_repr = UdpRepr {
608 src_port: pq.port,
609 dst_port,
610 };
611
612 let dst_addr = servers[pq.server_idx];
613 let src_addr = cx.get_source_address(&dst_addr).unwrap(); let ip_repr = IpRepr::new(
615 src_addr,
616 dst_addr,
617 IpProtocol::Udp,
618 udp_repr.header_len() + payload.len(),
619 hop_limit,
620 );
621
622 net_trace!(
623 "sending {} octets to {} from port {}",
624 payload.len(),
625 ip_repr.dst_addr(),
626 udp_repr.src_port
627 );
628
629 emit(cx, (ip_repr, udp_repr, payload))?;
630
631 pq.retransmit_at = cx.now() + pq.delay;
632 pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2);
633
634 return Ok(());
635 }
636 }
637
638 Ok(())
640 }
641
642 pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt {
643 self.queries
644 .iter()
645 .flatten()
646 .filter_map(|q| match &q.state {
647 State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)),
648 State::Completed(_) => None,
649 State::Failure => None,
650 })
651 .min()
652 .unwrap_or(PollAt::Ingress)
653 }
654}
655
656fn eq_names<'a>(
657 mut a: impl Iterator<Item = wire::Result<&'a [u8]>>,
658 mut b: impl Iterator<Item = wire::Result<&'a [u8]>>,
659) -> wire::Result<bool> {
660 loop {
661 match (a.next(), b.next()) {
662 (Some(Err(e)), _) => return Err(e),
664 (_, Some(Err(e))) => return Err(e),
665
666 (None, None) => return Ok(true),
668
669 (None, _) => return Ok(false),
671 (_, None) => return Ok(false),
672
673 (Some(Ok(la)), Some(Ok(lb))) => {
675 if la != lb {
676 return Ok(false);
677 }
678 }
679 }
680 }
681}
682
683fn copy_name<'a, const N: usize>(
684 dest: &mut Vec<u8, N>,
685 name: impl Iterator<Item = wire::Result<&'a [u8]>>,
686) -> Result<(), wire::Error> {
687 dest.truncate(0);
688
689 for label in name {
690 let label = label?;
691 dest.push(label.len() as u8).map_err(|_| wire::Error)?;
692 dest.extend_from_slice(label).map_err(|_| wire::Error)?;
693 }
694
695 dest.push(0).map_err(|_| wire::Error)?;
697
698 Ok(())
699}