sel4_async_network/
lib.rs

1//
2// Copyright 2023, Colias Group, LLC
3//
4// SPDX-License-Identifier: BSD-2-Clause
5//
6
7// Ideas for implementing operations on TCP sockets taken from:
8// https://github.com/embassy-rs/embassy/blob/main/embassy-net/src/tcp.rs
9
10#![no_std]
11
12extern crate alloc;
13
14use alloc::rc::Rc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::cell::RefCell;
18use core::future::poll_fn;
19use core::marker::PhantomData;
20use core::pin::Pin;
21use core::task::{self, Poll};
22
23use thiserror::Error;
24
25use log::info;
26use smoltcp::{
27    iface::{Config, Context, Interface, PollResult, SocketHandle, SocketSet},
28    phy::Device,
29    socket::{AnySocket, dhcpv4, dns, tcp},
30    time::{Duration, Instant},
31    wire::{DnsQueryType, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, Ipv4Address, Ipv4Cidr},
32};
33
34use sel4_async_io::{Error as AsyncIOError, ErrorKind, ErrorType, Read, Write};
35
/// Keep-alive interval, in milliseconds, that `accept` applies to accepted connections
/// (75 seconds).
pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75_000;
/// Size, in bytes, of both the receive and the transmit buffer allocated by
/// `new_tcp_socket`.
pub(crate) const DEFAULT_TCP_SOCKET_BUFFER_SIZE: usize = 65_535;
38
39#[derive(Clone)]
40pub struct ManagedInterface {
41    inner: Rc<RefCell<ManagedInterfaceShared>>,
42}
43
44struct ManagedInterfaceShared {
45    iface: Interface,
46    socket_set: SocketSet<'static>,
47    dns_socket_handle: SocketHandle,
48    dhcp_socket_handle: SocketHandle,
49    dhcp_overrides: DhcpOverrides,
50}
51
52#[derive(Default)]
53pub struct DhcpOverrides {
54    pub address: Option<Ipv4Cidr>,
55    pub router: Option<Option<Ipv4Address>>,
56    pub dns_servers: Option<Vec<Ipv4Address>>,
57}
58
59pub type TcpSocket = Socket<tcp::Socket<'static>>;
60
61pub struct Socket<T> {
62    handle: SocketHandle,
63    shared: ManagedInterface,
64    _phantom: PhantomData<T>,
65}
66
67impl<T> Drop for Socket<T> {
68    fn drop(&mut self) {
69        self.shared
70            .inner
71            .borrow_mut()
72            .socket_set
73            .remove(self.handle);
74    }
75}
76
77#[derive(Copy, Clone, Debug, PartialEq, Eq, Error)]
78pub enum TcpSocketError {
79    #[error("invalid state")]
80    InvalidState(tcp::State), // TODO just use InvalidState variants of below errors?
81    #[error("recv error: {0}")]
82    RecvError(tcp::RecvError),
83    #[error("send error: {0}")]
84    SendError(tcp::SendError),
85    #[error("listen error: {0}")]
86    ListenError(tcp::ListenError),
87    #[error("connect error: {0}")]
88    ConnectError(tcp::ConnectError),
89    #[error("connection reset during connect")]
90    ConnectionResetDuringConnect,
91}
92
93impl AsyncIOError for TcpSocketError {
94    fn kind(&self) -> ErrorKind {
95        ErrorKind::Other
96    }
97}
98
99#[derive(Copy, Clone, Debug, PartialEq, Eq)]
100pub enum DnsError {
101    StartQueryError(dns::StartQueryError),
102    GetQueryResultError(dns::GetQueryResultError),
103}
104
105impl ManagedInterface {
106    pub fn new<D: Device + ?Sized>(
107        config: Config,
108        dhcp_overrides: DhcpOverrides,
109        device: &mut D,
110        instant: Instant,
111    ) -> Self {
112        let iface = Interface::new(config, device, instant);
113        let mut socket_set = SocketSet::new(vec![]);
114        let dns_socket_handle = socket_set.add(dns::Socket::new(&[], vec![]));
115        let dhcp_socket_handle = socket_set.add(dhcpv4::Socket::new());
116
117        let mut this = ManagedInterfaceShared {
118            iface,
119            socket_set,
120            dns_socket_handle,
121            dhcp_socket_handle,
122            dhcp_overrides,
123        };
124
125        this.apply_dhcp_overrides();
126
127        Self {
128            inner: Rc::new(RefCell::new(this)),
129        }
130    }
131
132    fn inner(&self) -> &Rc<RefCell<ManagedInterfaceShared>> {
133        &self.inner
134    }
135
136    pub fn new_tcp_socket(&self) -> TcpSocket {
137        self.new_tcp_socket_with_buffer_sizes(
138            DEFAULT_TCP_SOCKET_BUFFER_SIZE,
139            DEFAULT_TCP_SOCKET_BUFFER_SIZE,
140        )
141    }
142
143    pub fn new_tcp_socket_with_buffer_sizes(
144        &self,
145        rx_buffer_size: usize,
146        tx_buffer_size: usize,
147    ) -> TcpSocket {
148        let rx_buffer = tcp::SocketBuffer::new(vec![0; rx_buffer_size]);
149        let tx_buffer = tcp::SocketBuffer::new(vec![0; tx_buffer_size]);
150        self.new_socket(tcp::Socket::new(rx_buffer, tx_buffer))
151    }
152
153    pub fn new_socket<T: AnySocket<'static>>(&self, socket: T) -> Socket<T> {
154        let handle = self.inner().borrow_mut().socket_set.add(socket);
155        Socket {
156            handle,
157            shared: self.clone(),
158            _phantom: PhantomData,
159        }
160    }
161
162    pub fn poll_at(&self, timestamp: Instant) -> Option<Instant> {
163        self.inner().borrow_mut().poll_at(timestamp)
164    }
165
166    pub fn poll_delay(&self, timestamp: Instant) -> Option<Duration> {
167        self.inner().borrow_mut().poll_delay(timestamp)
168    }
169
170    pub fn poll<D: Device + ?Sized>(&self, timestamp: Instant, device: &mut D) -> bool {
171        self.inner().borrow_mut().poll(timestamp, device)
172    }
173
174    pub async fn dns_query(
175        &self,
176        name: &str,
177        query_type: DnsQueryType,
178    ) -> Result<Vec<IpAddress>, DnsError> {
179        let query_handle = {
180            let inner = &mut *self.inner().borrow_mut();
181            inner
182                .socket_set
183                .get_mut::<dns::Socket>(inner.dns_socket_handle)
184                .start_query(inner.iface.context(), name, query_type)
185                .map_err(DnsError::StartQueryError)?
186        };
187        poll_fn(|cx| {
188            let inner = &mut *self.inner().borrow_mut();
189            let socket = inner
190                .socket_set
191                .get_mut::<dns::Socket>(inner.dns_socket_handle);
192            match socket.get_query_result(query_handle) {
193                Err(dns::GetQueryResultError::Pending) => {
194                    socket.register_query_waker(query_handle, cx.waker());
195                    Poll::Pending
196                }
197                r => Poll::Ready(
198                    r.map(|heapless_vec| heapless_vec.to_vec())
199                        .map_err(DnsError::GetQueryResultError),
200                ),
201            }
202        })
203        .await
204    }
205}
206
207impl<T: AnySocket<'static>> Socket<T> {
208    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
209        let network = self.shared.inner().borrow();
210        let socket = network.socket_set.get(self.handle);
211        f(socket)
212    }
213
214    pub fn with_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
215        let mut network = self.shared.inner().borrow_mut();
216        let socket = network.socket_set.get_mut(self.handle);
217        f(socket)
218    }
219
220    pub fn with_context_mut<R>(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R {
221        let network = &mut *self.shared.inner().borrow_mut();
222        let context = network.iface.context();
223        let socket = network.socket_set.get_mut(self.handle);
224        f(context, socket)
225    }
226}
227
228impl Socket<tcp::Socket<'static>> {
229    pub async fn connect<T, U>(
230        &mut self,
231        remote_endpoint: T,
232        local_endpoint: U,
233    ) -> Result<(), TcpSocketError>
234    where
235        T: Into<IpEndpoint>,
236        U: Into<IpListenEndpoint>,
237    {
238        self.with_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint))
239            .map_err(TcpSocketError::ConnectError)?;
240
241        poll_fn(|cx| {
242            self.with_mut(|socket| {
243                let state = socket.state();
244                match state {
245                    tcp::State::Closed | tcp::State::TimeWait => {
246                        Poll::Ready(Err(TcpSocketError::ConnectionResetDuringConnect))
247                    }
248                    tcp::State::Listen => unreachable!(), // because future holds &mut self
249                    tcp::State::SynSent | tcp::State::SynReceived => {
250                        socket.register_send_waker(cx.waker());
251                        Poll::Pending
252                    }
253                    _ => Poll::Ready(Ok(())),
254                }
255            })
256        })
257        .await
258    }
259
260    pub async fn accept_with_keep_alive(
261        &mut self,
262        local_endpoint: impl Into<IpListenEndpoint>,
263        keep_alive_interval: Option<Duration>,
264    ) -> Result<(), TcpSocketError> {
265        self.with_mut(|socket| {
266            socket
267                .listen(local_endpoint)
268                .map_err(TcpSocketError::ListenError)
269        })?;
270
271        poll_fn(|cx| {
272            self.with_mut(|socket| match socket.state() {
273                tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
274                    socket.register_recv_waker(cx.waker());
275                    Poll::Pending
276                }
277                _ => Poll::Ready(Ok(())),
278            })
279        })
280        .await?;
281
282        self.with_mut(|socket| socket.set_keep_alive(keep_alive_interval));
283
284        Ok(())
285    }
286
287    pub async fn accept(
288        &mut self,
289        local_endpoint: impl Into<IpListenEndpoint>,
290    ) -> Result<(), TcpSocketError> {
291        self.accept_with_keep_alive(
292            local_endpoint,
293            Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)),
294        )
295        .await
296    }
297
298    pub fn close(&mut self) {
299        self.with_mut(|socket| socket.close())
300    }
301
302    pub fn abort(&mut self) {
303        self.with_mut(|socket| socket.abort())
304    }
305}
306
307impl ErrorType for Socket<tcp::Socket<'static>> {
308    type Error = TcpSocketError;
309}
310
311impl Read for Socket<tcp::Socket<'static>> {
312    fn poll_read(
313        mut self: Pin<&mut Self>,
314        cx: &mut task::Context<'_>,
315        buf: &mut [u8],
316    ) -> Poll<Result<usize, Self::Error>> {
317        self.with_mut(|socket| match socket.recv_slice(buf) {
318            Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
319            Ok(0) => {
320                socket.register_recv_waker(cx.waker());
321                Poll::Pending
322            }
323            Ok(n) => Poll::Ready(Ok(n)),
324            Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
325            Err(err) => Poll::Ready(Err(TcpSocketError::RecvError(err))),
326        })
327    }
328}
329
330impl Write for Socket<tcp::Socket<'static>> {
331    fn poll_write(
332        mut self: Pin<&mut Self>,
333        cx: &mut task::Context<'_>,
334        buf: &[u8],
335    ) -> Poll<Result<usize, Self::Error>> {
336        self.with_mut(|socket| match socket.send_slice(buf) {
337            Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
338            Ok(0) => {
339                socket.register_send_waker(cx.waker());
340                Poll::Pending
341            }
342            Ok(n) => Poll::Ready(Ok(n)),
343            Err(err) => Poll::Ready(Err(TcpSocketError::SendError(err))),
344        })
345    }
346
347    fn poll_flush(
348        mut self: Pin<&mut Self>,
349        cx: &mut task::Context<'_>,
350    ) -> Poll<Result<(), Self::Error>> {
351        self.with_mut(|socket| {
352            let waiting_close =
353                socket.state() == tcp::State::Closed && socket.remote_endpoint().is_some();
354            if socket.send_queue() > 0 || waiting_close {
355                socket.register_send_waker(cx.waker());
356                Poll::Pending
357            } else {
358                Poll::Ready(Ok(()))
359            }
360        })
361    }
362}
363
364impl ManagedInterfaceShared {
365    fn dhcp_socket_mut(&mut self) -> &mut dhcpv4::Socket<'static> {
366        self.socket_set.get_mut(self.dhcp_socket_handle)
367    }
368
369    fn dns_socket_mut(&mut self) -> &mut dns::Socket<'static> {
370        self.socket_set.get_mut(self.dns_socket_handle)
371    }
372
373    fn poll_at(&mut self, timestamp: Instant) -> Option<Instant> {
374        self.iface.poll_at(timestamp, &self.socket_set)
375    }
376
377    fn poll_delay(&mut self, timestamp: Instant) -> Option<Duration> {
378        self.iface.poll_delay(timestamp, &self.socket_set)
379    }
380
381    fn poll<D: Device + ?Sized>(&mut self, timestamp: Instant, device: &mut D) -> bool {
382        let activity = self.iface.poll(timestamp, device, &mut self.socket_set)
383            == PollResult::SocketStateChanged;
384        if activity {
385            self.poll_dhcp();
386        }
387        activity
388    }
389
390    // TODO should dhcp events instead just be monitored in a task?
391    fn poll_dhcp(&mut self) {
392        if let Some(event) = self.dhcp_socket_mut().poll() {
393            let event = free_dhcp_event(event);
394            match event {
395                dhcpv4::Event::Configured(config) => {
396                    info!("DHCP config acquired");
397                    if self.dhcp_overrides.address.is_none() {
398                        self.set_address(config.address);
399                    }
400                    if self.dhcp_overrides.router.is_none() {
401                        self.set_router(config.router);
402                    }
403                    if self.dhcp_overrides.dns_servers.is_none() {
404                        self.set_dns_servers(&convert_dns_servers(&config.dns_servers));
405                    }
406                }
407                dhcpv4::Event::Deconfigured => {
408                    info!("DHCP config lost");
409                    if self.dhcp_overrides.address.is_none() {
410                        self.clear_address();
411                    }
412                    if self.dhcp_overrides.router.is_none() {
413                        self.clear_router();
414                    }
415                    if self.dhcp_overrides.dns_servers.is_none() {
416                        self.clear_dns_servers();
417                    }
418                }
419            }
420        }
421    }
422
423    fn set_address(&mut self, address: Ipv4Cidr) {
424        let address = IpCidr::Ipv4(address);
425        info!("IP address: {address}");
426        self.iface.update_ip_addrs(|addrs| {
427            if let Some(dest) = addrs.iter_mut().next() {
428                *dest = address;
429            } else {
430                addrs.push(address).unwrap();
431            }
432        });
433    }
434
435    fn clear_address(&mut self) {
436        let cidr = Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0);
437        self.iface.update_ip_addrs(|addrs| {
438            if let Some(dest) = addrs.iter_mut().next() {
439                *dest = IpCidr::Ipv4(cidr);
440            }
441        });
442    }
443
444    fn set_router(&mut self, router: Option<Ipv4Address>) {
445        if let Some(router) = router {
446            info!("Default gateway: {router}");
447            self.iface
448                .routes_mut()
449                .add_default_ipv4_route(router)
450                .unwrap();
451        } else {
452            info!("Default gateway: (none)");
453            self.iface.routes_mut().remove_default_ipv4_route();
454        }
455    }
456
457    fn clear_router(&mut self) {
458        self.iface.routes_mut().remove_default_ipv4_route();
459    }
460
461    fn set_dns_servers(&mut self, dns_servers: &[IpAddress]) {
462        for (i, s) in dns_servers.iter().enumerate() {
463            info!("DNS server {i}: {s}");
464        }
465        self.dns_socket_mut().update_servers(dns_servers);
466    }
467
468    fn clear_dns_servers(&mut self) {
469        self.dns_socket_mut().update_servers(&[]);
470    }
471
472    fn apply_dhcp_overrides(&mut self) {
473        if let Some(address) = self.dhcp_overrides.address {
474            self.set_address(address);
475        }
476        if let Some(router) = self.dhcp_overrides.router {
477            self.set_router(router);
478        }
479        if let Some(dns_servers) = self
480            .dhcp_overrides
481            .dns_servers
482            .as_deref()
483            .map(convert_dns_servers)
484        {
485            self.set_dns_servers(&dns_servers);
486        }
487    }
488}
489
490fn free_dhcp_event(event: dhcpv4::Event) -> dhcpv4::Event<'static> {
491    match event {
492        dhcpv4::Event::Deconfigured => dhcpv4::Event::Deconfigured,
493        dhcpv4::Event::Configured(config) => dhcpv4::Event::Configured(free_dhcp_config(config)),
494    }
495}
496
497fn free_dhcp_config(config: dhcpv4::Config) -> dhcpv4::Config<'static> {
498    dhcpv4::Config {
499        server: config.server,
500        address: config.address,
501        router: config.router,
502        dns_servers: config.dns_servers,
503        packet: None,
504    }
505}
506
507fn convert_dns_servers(dns_servers: &[Ipv4Address]) -> Vec<IpAddress> {
508    dns_servers
509        .iter()
510        .copied()
511        .map(From::from)
512        .collect::<Vec<_>>()
513}