1#![no_std]
11
12extern crate alloc;
13
14use alloc::rc::Rc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::cell::RefCell;
18use core::future::poll_fn;
19use core::marker::PhantomData;
20use core::pin::Pin;
21use core::task::{self, Poll};
22
23use log::info;
24use smoltcp::{
25 iface::{Config, Context, Interface, SocketHandle, SocketSet},
26 phy::Device,
27 socket::{dhcpv4, dns, tcp, AnySocket},
28 time::{Duration, Instant},
29 wire::{DnsQueryType, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, Ipv4Address, Ipv4Cidr},
30};
31
32use sel4_async_io::{Error as AsyncIOError, ErrorKind, ErrorType, Read, Write};
33
/// Keep-alive interval, in milliseconds, that `accept` hands to
/// `accept_with_keep_alive` (it is converted with `Duration::from_millis`).
pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75000;
/// Length of both the receive and the transmit buffer allocated by `new_tcp_socket`.
pub(crate) const DEFAULT_TCP_SOCKET_BUFFER_SIZE: usize = 65535;
36
/// Cloneable handle to the shared interface state.
///
/// Cloning only clones the `Rc`, so every clone (including the one stored in each
/// `Socket`) refers to the same `ManagedInterfaceShared`.
#[derive(Clone)]
pub struct ManagedInterface {
    // Single-threaded shared ownership with runtime-checked borrows.
    inner: Rc<RefCell<ManagedInterfaceShared>>,
}
41
/// State shared by a `ManagedInterface` and all sockets created from it.
struct ManagedInterfaceShared {
    // The smoltcp interface being driven by `poll`.
    iface: Interface,
    // Every socket registered with the interface, including the DNS and DHCP sockets below.
    socket_set: SocketSet<'static>,
    // Handle of the DNS socket added in `ManagedInterface::new`.
    dns_socket_handle: SocketHandle,
    // Handle of the DHCPv4 socket added in `ManagedInterface::new`.
    dhcp_socket_handle: SocketHandle,
    // Settings that take precedence over whatever DHCP reports.
    dhcp_overrides: DhcpOverrides,
}
49
/// Static settings that take precedence over DHCP-provided configuration.
///
/// A field set to `Some(_)` is applied when the interface is created, and the
/// corresponding value from DHCP events is then ignored. A field left as `None`
/// is managed by DHCP.
#[derive(Default)]
pub struct DhcpOverrides {
    /// Fixed IPv4 address (with prefix) for the interface.
    pub address: Option<Ipv4Cidr>,
    /// Fixed default gateway. `Some(None)` means "explicitly no default route".
    pub router: Option<Option<Ipv4Address>>,
    /// Fixed list of DNS servers.
    pub dns_servers: Option<Vec<Ipv4Address>>,
}
56
/// A `Socket` wrapping a smoltcp TCP socket.
pub type TcpSocket = Socket<tcp::Socket<'static>>;
58
/// Owning handle to one socket stored in the shared socket set.
///
/// The socket itself lives inside the `ManagedInterface`; `T` only records which
/// smoltcp socket type the handle refers to.
pub struct Socket<T> {
    // Key of the socket inside the shared `SocketSet`.
    handle: SocketHandle,
    // Interface whose socket set holds the socket.
    shared: ManagedInterface,
    // Ties the handle to the socket type without storing a `T`.
    _phantom: PhantomData<T>,
}
64
65impl<T> Drop for Socket<T> {
66 fn drop(&mut self) {
67 self.shared
68 .inner
69 .borrow_mut()
70 .socket_set
71 .remove(self.handle);
72 }
73}
74
/// Errors reported by the TCP operations on `TcpSocket`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum TcpSocketError {
    /// The socket was in an unsuitable state.
    // NOTE(review): not constructed anywhere in the visible code — confirm it is still needed.
    InvalidState(tcp::State),
    /// `recv_slice` failed while reading (end-of-stream is reported as `Ok(0)` instead).
    RecvError(tcp::RecvError),
    /// `send_slice` failed while writing.
    SendError(tcp::SendError),
    /// `listen` failed at the start of `accept`.
    ListenError(tcp::ListenError),
    /// `connect` could not be initiated.
    ConnectError(tcp::ConnectError),
    /// The socket was seen in `Closed` or `TimeWait` while waiting for the handshake.
    ConnectionResetDuringConnect,
}
84
impl AsyncIOError for TcpSocketError {
    /// Every variant is reported as `ErrorKind::Other`; no finer classification is made.
    fn kind(&self) -> ErrorKind {
        ErrorKind::Other
    }
}
90
/// Errors reported by `ManagedInterface::dns_query`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DnsError {
    /// The DNS socket refused to start the query.
    StartQueryError(dns::StartQueryError),
    /// The query finished with an error other than "still pending".
    GetQueryResultError(dns::GetQueryResultError),
}
96
97impl ManagedInterface {
98 pub fn new<D: Device + ?Sized>(
99 config: Config,
100 dhcp_overrides: DhcpOverrides,
101 device: &mut D,
102 instant: Instant,
103 ) -> Self {
104 let iface = Interface::new(config, device, instant);
105 let mut socket_set = SocketSet::new(vec![]);
106 let dns_socket_handle = socket_set.add(dns::Socket::new(&[], vec![]));
107 let dhcp_socket_handle = socket_set.add(dhcpv4::Socket::new());
108
109 let mut this = ManagedInterfaceShared {
110 iface,
111 socket_set,
112 dns_socket_handle,
113 dhcp_socket_handle,
114 dhcp_overrides,
115 };
116
117 this.apply_dhcp_overrides();
118
119 Self {
120 inner: Rc::new(RefCell::new(this)),
121 }
122 }
123
124 fn inner(&self) -> &Rc<RefCell<ManagedInterfaceShared>> {
125 &self.inner
126 }
127
128 pub fn new_tcp_socket(&self) -> TcpSocket {
129 self.new_tcp_socket_with_buffer_sizes(
130 DEFAULT_TCP_SOCKET_BUFFER_SIZE,
131 DEFAULT_TCP_SOCKET_BUFFER_SIZE,
132 )
133 }
134
135 pub fn new_tcp_socket_with_buffer_sizes(
136 &self,
137 rx_buffer_size: usize,
138 tx_buffer_size: usize,
139 ) -> TcpSocket {
140 let rx_buffer = tcp::SocketBuffer::new(vec![0; rx_buffer_size]);
141 let tx_buffer = tcp::SocketBuffer::new(vec![0; tx_buffer_size]);
142 self.new_socket(tcp::Socket::new(rx_buffer, tx_buffer))
143 }
144
145 pub fn new_socket<T: AnySocket<'static>>(&self, socket: T) -> Socket<T> {
146 let handle = self.inner().borrow_mut().socket_set.add(socket);
147 Socket {
148 handle,
149 shared: self.clone(),
150 _phantom: PhantomData,
151 }
152 }
153
154 pub fn poll_at(&self, timestamp: Instant) -> Option<Instant> {
155 self.inner().borrow_mut().poll_at(timestamp)
156 }
157
158 pub fn poll_delay(&self, timestamp: Instant) -> Option<Duration> {
159 self.inner().borrow_mut().poll_delay(timestamp)
160 }
161
162 pub fn poll<D: Device + ?Sized>(&self, timestamp: Instant, device: &mut D) -> bool {
163 self.inner().borrow_mut().poll(timestamp, device)
164 }
165
166 pub async fn dns_query(
167 &self,
168 name: &str,
169 query_type: DnsQueryType,
170 ) -> Result<Vec<IpAddress>, DnsError> {
171 let query_handle = {
172 let inner = &mut *self.inner().borrow_mut();
173 inner
174 .socket_set
175 .get_mut::<dns::Socket>(inner.dns_socket_handle)
176 .start_query(inner.iface.context(), name, query_type)
177 .map_err(DnsError::StartQueryError)?
178 };
179 poll_fn(|cx| {
180 let inner = &mut *self.inner().borrow_mut();
181 let socket = inner
182 .socket_set
183 .get_mut::<dns::Socket>(inner.dns_socket_handle);
184 match socket.get_query_result(query_handle) {
185 Err(dns::GetQueryResultError::Pending) => {
186 socket.register_query_waker(query_handle, cx.waker());
187 Poll::Pending
188 }
189 r => Poll::Ready(
190 r.map(|heapless_vec| heapless_vec.to_vec())
191 .map_err(DnsError::GetQueryResultError),
192 ),
193 }
194 })
195 .await
196 }
197}
198
199impl<T: AnySocket<'static>> Socket<T> {
200 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
201 let network = self.shared.inner().borrow();
202 let socket = network.socket_set.get(self.handle);
203 f(socket)
204 }
205
206 pub fn with_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
207 let mut network = self.shared.inner().borrow_mut();
208 let socket = network.socket_set.get_mut(self.handle);
209 f(socket)
210 }
211
212 pub fn with_context_mut<R>(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R {
213 let network = &mut *self.shared.inner().borrow_mut();
214 let context = network.iface.context();
215 let socket = network.socket_set.get_mut(self.handle);
216 f(context, socket)
217 }
218}
219
220impl Socket<tcp::Socket<'static>> {
221 pub async fn connect<T, U>(
222 &mut self,
223 remote_endpoint: T,
224 local_endpoint: U,
225 ) -> Result<(), TcpSocketError>
226 where
227 T: Into<IpEndpoint>,
228 U: Into<IpListenEndpoint>,
229 {
230 self.with_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint))
231 .map_err(TcpSocketError::ConnectError)?;
232
233 poll_fn(|cx| {
234 self.with_mut(|socket| {
235 let state = socket.state();
236 match state {
237 tcp::State::Closed | tcp::State::TimeWait => {
238 Poll::Ready(Err(TcpSocketError::ConnectionResetDuringConnect))
239 }
240 tcp::State::Listen => unreachable!(), tcp::State::SynSent | tcp::State::SynReceived => {
242 socket.register_send_waker(cx.waker());
243 Poll::Pending
244 }
245 _ => Poll::Ready(Ok(())),
246 }
247 })
248 })
249 .await
250 }
251
252 pub async fn accept_with_keep_alive(
253 &mut self,
254 local_endpoint: impl Into<IpListenEndpoint>,
255 keep_alive_interval: Option<Duration>,
256 ) -> Result<(), TcpSocketError> {
257 self.with_mut(|socket| {
258 socket
259 .listen(local_endpoint)
260 .map_err(TcpSocketError::ListenError)
261 })?;
262
263 poll_fn(|cx| {
264 self.with_mut(|socket| match socket.state() {
265 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
266 socket.register_recv_waker(cx.waker());
267 Poll::Pending
268 }
269 _ => Poll::Ready(Ok(())),
270 })
271 })
272 .await?;
273
274 self.with_mut(|socket| socket.set_keep_alive(keep_alive_interval));
275
276 Ok(())
277 }
278
279 pub async fn accept(
280 &mut self,
281 local_endpoint: impl Into<IpListenEndpoint>,
282 ) -> Result<(), TcpSocketError> {
283 self.accept_with_keep_alive(
284 local_endpoint,
285 Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)),
286 )
287 .await
288 }
289
290 pub fn close(&mut self) {
291 self.with_mut(|socket| socket.close())
292 }
293
294 pub fn abort(&mut self) {
295 self.with_mut(|socket| socket.abort())
296 }
297}
298
// The `Read` and `Write` implementations for TCP sockets report `TcpSocketError`.
impl ErrorType for Socket<tcp::Socket<'static>> {
    type Error = TcpSocketError;
}
302
303impl Read for Socket<tcp::Socket<'static>> {
304 fn poll_read(
305 mut self: Pin<&mut Self>,
306 cx: &mut task::Context<'_>,
307 buf: &mut [u8],
308 ) -> Poll<Result<usize, Self::Error>> {
309 self.with_mut(|socket| match socket.recv_slice(buf) {
310 Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
311 Ok(0) => {
312 socket.register_recv_waker(cx.waker());
313 Poll::Pending
314 }
315 Ok(n) => Poll::Ready(Ok(n)),
316 Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
317 Err(err) => Poll::Ready(Err(TcpSocketError::RecvError(err))),
318 })
319 }
320}
321
322impl Write for Socket<tcp::Socket<'static>> {
323 fn poll_write(
324 mut self: Pin<&mut Self>,
325 cx: &mut task::Context<'_>,
326 buf: &[u8],
327 ) -> Poll<Result<usize, Self::Error>> {
328 self.with_mut(|socket| match socket.send_slice(buf) {
329 Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
330 Ok(0) => {
331 socket.register_send_waker(cx.waker());
332 Poll::Pending
333 }
334 Ok(n) => Poll::Ready(Ok(n)),
335 Err(err) => Poll::Ready(Err(TcpSocketError::SendError(err))),
336 })
337 }
338
339 fn poll_flush(
340 mut self: Pin<&mut Self>,
341 cx: &mut task::Context<'_>,
342 ) -> Poll<Result<(), Self::Error>> {
343 self.with_mut(|socket| {
344 let waiting_close =
345 socket.state() == tcp::State::Closed && socket.remote_endpoint().is_some();
346 if socket.send_queue() > 0 || waiting_close {
347 socket.register_send_waker(cx.waker());
348 Poll::Pending
349 } else {
350 Poll::Ready(Ok(()))
351 }
352 })
353 }
354}
355
356impl ManagedInterfaceShared {
357 fn dhcp_socket_mut(&mut self) -> &mut dhcpv4::Socket<'static> {
358 self.socket_set.get_mut(self.dhcp_socket_handle)
359 }
360
361 fn dns_socket_mut(&mut self) -> &mut dns::Socket<'static> {
362 self.socket_set.get_mut(self.dns_socket_handle)
363 }
364
365 fn poll_at(&mut self, timestamp: Instant) -> Option<Instant> {
366 self.iface.poll_at(timestamp, &self.socket_set)
367 }
368
369 fn poll_delay(&mut self, timestamp: Instant) -> Option<Duration> {
370 self.iface.poll_delay(timestamp, &self.socket_set)
371 }
372
373 fn poll<D: Device + ?Sized>(&mut self, timestamp: Instant, device: &mut D) -> bool {
374 let activity = self.iface.poll(timestamp, device, &mut self.socket_set);
375 if activity {
376 self.poll_dhcp();
377 }
378 activity
379 }
380
381 fn poll_dhcp(&mut self) {
383 if let Some(event) = self.dhcp_socket_mut().poll() {
384 let event = free_dhcp_event(event);
385 match event {
386 dhcpv4::Event::Configured(config) => {
387 info!("DHCP config acquired");
388 if self.dhcp_overrides.address.is_none() {
389 self.set_address(config.address);
390 }
391 if self.dhcp_overrides.router.is_none() {
392 self.set_router(config.router);
393 }
394 if self.dhcp_overrides.dns_servers.is_none() {
395 self.set_dns_servers(&convert_dns_servers(&config.dns_servers));
396 }
397 }
398 dhcpv4::Event::Deconfigured => {
399 info!("DHCP config lost");
400 if self.dhcp_overrides.address.is_none() {
401 self.clear_address();
402 }
403 if self.dhcp_overrides.router.is_none() {
404 self.clear_router();
405 }
406 if self.dhcp_overrides.dns_servers.is_none() {
407 self.clear_dns_servers();
408 }
409 }
410 }
411 }
412 }
413
414 fn set_address(&mut self, address: Ipv4Cidr) {
415 let address = IpCidr::Ipv4(address);
416 info!("IP address: {}", address);
417 self.iface.update_ip_addrs(|addrs| {
418 if let Some(dest) = addrs.iter_mut().next() {
419 *dest = address;
420 } else {
421 addrs.push(address).unwrap();
422 }
423 });
424 }
425
426 fn clear_address(&mut self) {
427 let cidr = Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0);
428 self.iface.update_ip_addrs(|addrs| {
429 if let Some(dest) = addrs.iter_mut().next() {
430 *dest = IpCidr::Ipv4(cidr);
431 }
432 });
433 }
434
435 fn set_router(&mut self, router: Option<Ipv4Address>) {
436 if let Some(router) = router {
437 info!("Default gateway: {}", router);
438 self.iface
439 .routes_mut()
440 .add_default_ipv4_route(router)
441 .unwrap();
442 } else {
443 info!("Default gateway: (none)");
444 self.iface.routes_mut().remove_default_ipv4_route();
445 }
446 }
447
448 fn clear_router(&mut self) {
449 self.iface.routes_mut().remove_default_ipv4_route();
450 }
451
452 fn set_dns_servers(&mut self, dns_servers: &[IpAddress]) {
453 for (i, s) in dns_servers.iter().enumerate() {
454 info!("DNS server {}: {}", i, s);
455 }
456 self.dns_socket_mut().update_servers(dns_servers);
457 }
458
459 fn clear_dns_servers(&mut self) {
460 self.dns_socket_mut().update_servers(&[]);
461 }
462
463 fn apply_dhcp_overrides(&mut self) {
464 if let Some(address) = self.dhcp_overrides.address {
465 self.set_address(address);
466 }
467 if let Some(router) = self.dhcp_overrides.router {
468 self.set_router(router);
469 }
470 if let Some(dns_servers) = self
471 .dhcp_overrides
472 .dns_servers
473 .as_deref()
474 .map(convert_dns_servers)
475 {
476 self.set_dns_servers(&dns_servers);
477 }
478 }
479}
480
481fn free_dhcp_event(event: dhcpv4::Event) -> dhcpv4::Event<'static> {
482 match event {
483 dhcpv4::Event::Deconfigured => dhcpv4::Event::Deconfigured,
484 dhcpv4::Event::Configured(config) => dhcpv4::Event::Configured(free_dhcp_config(config)),
485 }
486}
487
488fn free_dhcp_config(config: dhcpv4::Config) -> dhcpv4::Config<'static> {
489 dhcpv4::Config {
490 server: config.server,
491 address: config.address,
492 router: config.router,
493 dns_servers: config.dns_servers,
494 packet: None,
495 }
496}
497
498fn convert_dns_servers(dns_servers: &[Ipv4Address]) -> Vec<IpAddress> {
499 dns_servers
500 .iter()
501 .copied()
502 .map(From::from)
503 .collect::<Vec<_>>()
504}