sel4_async_network/
lib.rs

1//
2// Copyright 2023, Colias Group, LLC
3//
4// SPDX-License-Identifier: BSD-2-Clause
5//
6
7// Ideas for implementing operations on TCP sockets taken from:
8// https://github.com/embassy-rs/embassy/blob/main/embassy-net/src/tcp.rs
9
10#![no_std]
11
12extern crate alloc;
13
14use alloc::rc::Rc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::cell::RefCell;
18use core::future::poll_fn;
19use core::marker::PhantomData;
20use core::pin::Pin;
21use core::task::{self, Poll};
22
23use log::info;
24use smoltcp::{
25    iface::{Config, Context, Interface, SocketHandle, SocketSet},
26    phy::Device,
27    socket::{dhcpv4, dns, tcp, AnySocket},
28    time::{Duration, Instant},
29    wire::{DnsQueryType, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, Ipv4Address, Ipv4Cidr},
30};
31
32use sel4_async_io::{Error as AsyncIOError, ErrorKind, ErrorType, Read, Write};
33
/// Keep-alive interval, in milliseconds, that `Socket::accept` applies to accepted connections.
pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75_000;

/// Size in bytes of both the receive and the transmit buffer allocated by
/// `ManagedInterface::new_tcp_socket`.
pub(crate) const DEFAULT_TCP_SOCKET_BUFFER_SIZE: usize = 65_535;
36
37#[derive(Clone)]
38pub struct ManagedInterface {
39    inner: Rc<RefCell<ManagedInterfaceShared>>,
40}
41
42struct ManagedInterfaceShared {
43    iface: Interface,
44    socket_set: SocketSet<'static>,
45    dns_socket_handle: SocketHandle,
46    dhcp_socket_handle: SocketHandle,
47    dhcp_overrides: DhcpOverrides,
48}
49
50#[derive(Default)]
51pub struct DhcpOverrides {
52    pub address: Option<Ipv4Cidr>,
53    pub router: Option<Option<Ipv4Address>>,
54    pub dns_servers: Option<Vec<Ipv4Address>>,
55}
56
57pub type TcpSocket = Socket<tcp::Socket<'static>>;
58
59pub struct Socket<T> {
60    handle: SocketHandle,
61    shared: ManagedInterface,
62    _phantom: PhantomData<T>,
63}
64
65impl<T> Drop for Socket<T> {
66    fn drop(&mut self) {
67        self.shared
68            .inner
69            .borrow_mut()
70            .socket_set
71            .remove(self.handle);
72    }
73}
74
75#[derive(Copy, Clone, Debug, PartialEq, Eq)]
76pub enum TcpSocketError {
77    InvalidState(tcp::State), // TODO just use InvalidState variants of below errors?
78    RecvError(tcp::RecvError),
79    SendError(tcp::SendError),
80    ListenError(tcp::ListenError),
81    ConnectError(tcp::ConnectError),
82    ConnectionResetDuringConnect,
83}
84
85impl AsyncIOError for TcpSocketError {
86    fn kind(&self) -> ErrorKind {
87        ErrorKind::Other
88    }
89}
90
91#[derive(Copy, Clone, Debug, PartialEq, Eq)]
92pub enum DnsError {
93    StartQueryError(dns::StartQueryError),
94    GetQueryResultError(dns::GetQueryResultError),
95}
96
97impl ManagedInterface {
98    pub fn new<D: Device + ?Sized>(
99        config: Config,
100        dhcp_overrides: DhcpOverrides,
101        device: &mut D,
102        instant: Instant,
103    ) -> Self {
104        let iface = Interface::new(config, device, instant);
105        let mut socket_set = SocketSet::new(vec![]);
106        let dns_socket_handle = socket_set.add(dns::Socket::new(&[], vec![]));
107        let dhcp_socket_handle = socket_set.add(dhcpv4::Socket::new());
108
109        let mut this = ManagedInterfaceShared {
110            iface,
111            socket_set,
112            dns_socket_handle,
113            dhcp_socket_handle,
114            dhcp_overrides,
115        };
116
117        this.apply_dhcp_overrides();
118
119        Self {
120            inner: Rc::new(RefCell::new(this)),
121        }
122    }
123
124    fn inner(&self) -> &Rc<RefCell<ManagedInterfaceShared>> {
125        &self.inner
126    }
127
128    pub fn new_tcp_socket(&self) -> TcpSocket {
129        self.new_tcp_socket_with_buffer_sizes(
130            DEFAULT_TCP_SOCKET_BUFFER_SIZE,
131            DEFAULT_TCP_SOCKET_BUFFER_SIZE,
132        )
133    }
134
135    pub fn new_tcp_socket_with_buffer_sizes(
136        &self,
137        rx_buffer_size: usize,
138        tx_buffer_size: usize,
139    ) -> TcpSocket {
140        let rx_buffer = tcp::SocketBuffer::new(vec![0; rx_buffer_size]);
141        let tx_buffer = tcp::SocketBuffer::new(vec![0; tx_buffer_size]);
142        self.new_socket(tcp::Socket::new(rx_buffer, tx_buffer))
143    }
144
145    pub fn new_socket<T: AnySocket<'static>>(&self, socket: T) -> Socket<T> {
146        let handle = self.inner().borrow_mut().socket_set.add(socket);
147        Socket {
148            handle,
149            shared: self.clone(),
150            _phantom: PhantomData,
151        }
152    }
153
154    pub fn poll_at(&self, timestamp: Instant) -> Option<Instant> {
155        self.inner().borrow_mut().poll_at(timestamp)
156    }
157
158    pub fn poll_delay(&self, timestamp: Instant) -> Option<Duration> {
159        self.inner().borrow_mut().poll_delay(timestamp)
160    }
161
162    pub fn poll<D: Device + ?Sized>(&self, timestamp: Instant, device: &mut D) -> bool {
163        self.inner().borrow_mut().poll(timestamp, device)
164    }
165
166    pub async fn dns_query(
167        &self,
168        name: &str,
169        query_type: DnsQueryType,
170    ) -> Result<Vec<IpAddress>, DnsError> {
171        let query_handle = {
172            let inner = &mut *self.inner().borrow_mut();
173            inner
174                .socket_set
175                .get_mut::<dns::Socket>(inner.dns_socket_handle)
176                .start_query(inner.iface.context(), name, query_type)
177                .map_err(DnsError::StartQueryError)?
178        };
179        poll_fn(|cx| {
180            let inner = &mut *self.inner().borrow_mut();
181            let socket = inner
182                .socket_set
183                .get_mut::<dns::Socket>(inner.dns_socket_handle);
184            match socket.get_query_result(query_handle) {
185                Err(dns::GetQueryResultError::Pending) => {
186                    socket.register_query_waker(query_handle, cx.waker());
187                    Poll::Pending
188                }
189                r => Poll::Ready(
190                    r.map(|heapless_vec| heapless_vec.to_vec())
191                        .map_err(DnsError::GetQueryResultError),
192                ),
193            }
194        })
195        .await
196    }
197}
198
199impl<T: AnySocket<'static>> Socket<T> {
200    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
201        let network = self.shared.inner().borrow();
202        let socket = network.socket_set.get(self.handle);
203        f(socket)
204    }
205
206    pub fn with_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
207        let mut network = self.shared.inner().borrow_mut();
208        let socket = network.socket_set.get_mut(self.handle);
209        f(socket)
210    }
211
212    pub fn with_context_mut<R>(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R {
213        let network = &mut *self.shared.inner().borrow_mut();
214        let context = network.iface.context();
215        let socket = network.socket_set.get_mut(self.handle);
216        f(context, socket)
217    }
218}
219
220impl Socket<tcp::Socket<'static>> {
221    pub async fn connect<T, U>(
222        &mut self,
223        remote_endpoint: T,
224        local_endpoint: U,
225    ) -> Result<(), TcpSocketError>
226    where
227        T: Into<IpEndpoint>,
228        U: Into<IpListenEndpoint>,
229    {
230        self.with_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint))
231            .map_err(TcpSocketError::ConnectError)?;
232
233        poll_fn(|cx| {
234            self.with_mut(|socket| {
235                let state = socket.state();
236                match state {
237                    tcp::State::Closed | tcp::State::TimeWait => {
238                        Poll::Ready(Err(TcpSocketError::ConnectionResetDuringConnect))
239                    }
240                    tcp::State::Listen => unreachable!(), // because future holds &mut self
241                    tcp::State::SynSent | tcp::State::SynReceived => {
242                        socket.register_send_waker(cx.waker());
243                        Poll::Pending
244                    }
245                    _ => Poll::Ready(Ok(())),
246                }
247            })
248        })
249        .await
250    }
251
252    pub async fn accept_with_keep_alive(
253        &mut self,
254        local_endpoint: impl Into<IpListenEndpoint>,
255        keep_alive_interval: Option<Duration>,
256    ) -> Result<(), TcpSocketError> {
257        self.with_mut(|socket| {
258            socket
259                .listen(local_endpoint)
260                .map_err(TcpSocketError::ListenError)
261        })?;
262
263        poll_fn(|cx| {
264            self.with_mut(|socket| match socket.state() {
265                tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
266                    socket.register_recv_waker(cx.waker());
267                    Poll::Pending
268                }
269                _ => Poll::Ready(Ok(())),
270            })
271        })
272        .await?;
273
274        self.with_mut(|socket| socket.set_keep_alive(keep_alive_interval));
275
276        Ok(())
277    }
278
279    pub async fn accept(
280        &mut self,
281        local_endpoint: impl Into<IpListenEndpoint>,
282    ) -> Result<(), TcpSocketError> {
283        self.accept_with_keep_alive(
284            local_endpoint,
285            Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)),
286        )
287        .await
288    }
289
290    pub fn close(&mut self) {
291        self.with_mut(|socket| socket.close())
292    }
293
294    pub fn abort(&mut self) {
295        self.with_mut(|socket| socket.abort())
296    }
297}
298
299impl ErrorType for Socket<tcp::Socket<'static>> {
300    type Error = TcpSocketError;
301}
302
303impl Read for Socket<tcp::Socket<'static>> {
304    fn poll_read(
305        mut self: Pin<&mut Self>,
306        cx: &mut task::Context<'_>,
307        buf: &mut [u8],
308    ) -> Poll<Result<usize, Self::Error>> {
309        self.with_mut(|socket| match socket.recv_slice(buf) {
310            Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
311            Ok(0) => {
312                socket.register_recv_waker(cx.waker());
313                Poll::Pending
314            }
315            Ok(n) => Poll::Ready(Ok(n)),
316            Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
317            Err(err) => Poll::Ready(Err(TcpSocketError::RecvError(err))),
318        })
319    }
320}
321
322impl Write for Socket<tcp::Socket<'static>> {
323    fn poll_write(
324        mut self: Pin<&mut Self>,
325        cx: &mut task::Context<'_>,
326        buf: &[u8],
327    ) -> Poll<Result<usize, Self::Error>> {
328        self.with_mut(|socket| match socket.send_slice(buf) {
329            Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
330            Ok(0) => {
331                socket.register_send_waker(cx.waker());
332                Poll::Pending
333            }
334            Ok(n) => Poll::Ready(Ok(n)),
335            Err(err) => Poll::Ready(Err(TcpSocketError::SendError(err))),
336        })
337    }
338
339    fn poll_flush(
340        mut self: Pin<&mut Self>,
341        cx: &mut task::Context<'_>,
342    ) -> Poll<Result<(), Self::Error>> {
343        self.with_mut(|socket| {
344            let waiting_close =
345                socket.state() == tcp::State::Closed && socket.remote_endpoint().is_some();
346            if socket.send_queue() > 0 || waiting_close {
347                socket.register_send_waker(cx.waker());
348                Poll::Pending
349            } else {
350                Poll::Ready(Ok(()))
351            }
352        })
353    }
354}
355
356impl ManagedInterfaceShared {
357    fn dhcp_socket_mut(&mut self) -> &mut dhcpv4::Socket<'static> {
358        self.socket_set.get_mut(self.dhcp_socket_handle)
359    }
360
361    fn dns_socket_mut(&mut self) -> &mut dns::Socket<'static> {
362        self.socket_set.get_mut(self.dns_socket_handle)
363    }
364
365    fn poll_at(&mut self, timestamp: Instant) -> Option<Instant> {
366        self.iface.poll_at(timestamp, &self.socket_set)
367    }
368
369    fn poll_delay(&mut self, timestamp: Instant) -> Option<Duration> {
370        self.iface.poll_delay(timestamp, &self.socket_set)
371    }
372
373    fn poll<D: Device + ?Sized>(&mut self, timestamp: Instant, device: &mut D) -> bool {
374        let activity = self.iface.poll(timestamp, device, &mut self.socket_set);
375        if activity {
376            self.poll_dhcp();
377        }
378        activity
379    }
380
381    // TODO should dhcp events instead just be monitored in a task?
382    fn poll_dhcp(&mut self) {
383        if let Some(event) = self.dhcp_socket_mut().poll() {
384            let event = free_dhcp_event(event);
385            match event {
386                dhcpv4::Event::Configured(config) => {
387                    info!("DHCP config acquired");
388                    if self.dhcp_overrides.address.is_none() {
389                        self.set_address(config.address);
390                    }
391                    if self.dhcp_overrides.router.is_none() {
392                        self.set_router(config.router);
393                    }
394                    if self.dhcp_overrides.dns_servers.is_none() {
395                        self.set_dns_servers(&convert_dns_servers(&config.dns_servers));
396                    }
397                }
398                dhcpv4::Event::Deconfigured => {
399                    info!("DHCP config lost");
400                    if self.dhcp_overrides.address.is_none() {
401                        self.clear_address();
402                    }
403                    if self.dhcp_overrides.router.is_none() {
404                        self.clear_router();
405                    }
406                    if self.dhcp_overrides.dns_servers.is_none() {
407                        self.clear_dns_servers();
408                    }
409                }
410            }
411        }
412    }
413
414    fn set_address(&mut self, address: Ipv4Cidr) {
415        let address = IpCidr::Ipv4(address);
416        info!("IP address: {}", address);
417        self.iface.update_ip_addrs(|addrs| {
418            if let Some(dest) = addrs.iter_mut().next() {
419                *dest = address;
420            } else {
421                addrs.push(address).unwrap();
422            }
423        });
424    }
425
426    fn clear_address(&mut self) {
427        let cidr = Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0);
428        self.iface.update_ip_addrs(|addrs| {
429            if let Some(dest) = addrs.iter_mut().next() {
430                *dest = IpCidr::Ipv4(cidr);
431            }
432        });
433    }
434
435    fn set_router(&mut self, router: Option<Ipv4Address>) {
436        if let Some(router) = router {
437            info!("Default gateway: {}", router);
438            self.iface
439                .routes_mut()
440                .add_default_ipv4_route(router)
441                .unwrap();
442        } else {
443            info!("Default gateway: (none)");
444            self.iface.routes_mut().remove_default_ipv4_route();
445        }
446    }
447
448    fn clear_router(&mut self) {
449        self.iface.routes_mut().remove_default_ipv4_route();
450    }
451
452    fn set_dns_servers(&mut self, dns_servers: &[IpAddress]) {
453        for (i, s) in dns_servers.iter().enumerate() {
454            info!("DNS server {}: {}", i, s);
455        }
456        self.dns_socket_mut().update_servers(dns_servers);
457    }
458
459    fn clear_dns_servers(&mut self) {
460        self.dns_socket_mut().update_servers(&[]);
461    }
462
463    fn apply_dhcp_overrides(&mut self) {
464        if let Some(address) = self.dhcp_overrides.address {
465            self.set_address(address);
466        }
467        if let Some(router) = self.dhcp_overrides.router {
468            self.set_router(router);
469        }
470        if let Some(dns_servers) = self
471            .dhcp_overrides
472            .dns_servers
473            .as_deref()
474            .map(convert_dns_servers)
475        {
476            self.set_dns_servers(&dns_servers);
477        }
478    }
479}
480
481fn free_dhcp_event(event: dhcpv4::Event) -> dhcpv4::Event<'static> {
482    match event {
483        dhcpv4::Event::Deconfigured => dhcpv4::Event::Deconfigured,
484        dhcpv4::Event::Configured(config) => dhcpv4::Event::Configured(free_dhcp_config(config)),
485    }
486}
487
488fn free_dhcp_config(config: dhcpv4::Config) -> dhcpv4::Config<'static> {
489    dhcpv4::Config {
490        server: config.server,
491        address: config.address,
492        router: config.router,
493        dns_servers: config.dns_servers,
494        packet: None,
495    }
496}
497
498fn convert_dns_servers(dns_servers: &[Ipv4Address]) -> Vec<IpAddress> {
499    dns_servers
500        .iter()
501        .copied()
502        .map(From::from)
503        .collect::<Vec<_>>()
504}