1#![no_std]
11
12extern crate alloc;
13
14use alloc::rc::Rc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::cell::RefCell;
18use core::future::poll_fn;
19use core::marker::PhantomData;
20use core::pin::Pin;
21use core::task::{self, Poll};
22
23use thiserror::Error;
24
25use log::info;
26use smoltcp::{
27 iface::{Config, Context, Interface, PollResult, SocketHandle, SocketSet},
28 phy::Device,
29 socket::{AnySocket, dhcpv4, dns, tcp},
30 time::{Duration, Instant},
31 wire::{DnsQueryType, IpAddress, IpCidr, IpEndpoint, IpListenEndpoint, Ipv4Address, Ipv4Cidr},
32};
33
34use sel4_async_io::{Error as AsyncIOError, ErrorKind, ErrorType, Read, Write};
35
/// Default TCP keep-alive interval applied to accepted sockets, in milliseconds
/// (it is converted with `Duration::from_millis`).
pub(crate) const DEFAULT_KEEP_ALIVE_INTERVAL: u64 = 75_000;
/// Default size, in bytes, of each of a TCP socket's receive and transmit buffers.
pub(crate) const DEFAULT_TCP_SOCKET_BUFFER_SIZE: usize = 65_535;
38
39#[derive(Clone)]
40pub struct ManagedInterface {
41 inner: Rc<RefCell<ManagedInterfaceShared>>,
42}
43
44struct ManagedInterfaceShared {
45 iface: Interface,
46 socket_set: SocketSet<'static>,
47 dns_socket_handle: SocketHandle,
48 dhcp_socket_handle: SocketHandle,
49 dhcp_overrides: DhcpOverrides,
50}
51
52#[derive(Default)]
53pub struct DhcpOverrides {
54 pub address: Option<Ipv4Cidr>,
55 pub router: Option<Option<Ipv4Address>>,
56 pub dns_servers: Option<Vec<Ipv4Address>>,
57}
58
59pub type TcpSocket = Socket<tcp::Socket<'static>>;
60
61pub struct Socket<T> {
62 handle: SocketHandle,
63 shared: ManagedInterface,
64 _phantom: PhantomData<T>,
65}
66
67impl<T> Drop for Socket<T> {
68 fn drop(&mut self) {
69 self.shared
70 .inner
71 .borrow_mut()
72 .socket_set
73 .remove(self.handle);
74 }
75}
76
77#[derive(Copy, Clone, Debug, PartialEq, Eq, Error)]
78pub enum TcpSocketError {
79 #[error("invalid state")]
80 InvalidState(tcp::State), #[error("recv error: {0}")]
82 RecvError(tcp::RecvError),
83 #[error("send error: {0}")]
84 SendError(tcp::SendError),
85 #[error("listen error: {0}")]
86 ListenError(tcp::ListenError),
87 #[error("connect error: {0}")]
88 ConnectError(tcp::ConnectError),
89 #[error("connection reset during connect")]
90 ConnectionResetDuringConnect,
91}
92
93impl AsyncIOError for TcpSocketError {
94 fn kind(&self) -> ErrorKind {
95 ErrorKind::Other
96 }
97}
98
99#[derive(Copy, Clone, Debug, PartialEq, Eq)]
100pub enum DnsError {
101 StartQueryError(dns::StartQueryError),
102 GetQueryResultError(dns::GetQueryResultError),
103}
104
105impl ManagedInterface {
106 pub fn new<D: Device + ?Sized>(
107 config: Config,
108 dhcp_overrides: DhcpOverrides,
109 device: &mut D,
110 instant: Instant,
111 ) -> Self {
112 let iface = Interface::new(config, device, instant);
113 let mut socket_set = SocketSet::new(vec![]);
114 let dns_socket_handle = socket_set.add(dns::Socket::new(&[], vec![]));
115 let dhcp_socket_handle = socket_set.add(dhcpv4::Socket::new());
116
117 let mut this = ManagedInterfaceShared {
118 iface,
119 socket_set,
120 dns_socket_handle,
121 dhcp_socket_handle,
122 dhcp_overrides,
123 };
124
125 this.apply_dhcp_overrides();
126
127 Self {
128 inner: Rc::new(RefCell::new(this)),
129 }
130 }
131
132 fn inner(&self) -> &Rc<RefCell<ManagedInterfaceShared>> {
133 &self.inner
134 }
135
136 pub fn new_tcp_socket(&self) -> TcpSocket {
137 self.new_tcp_socket_with_buffer_sizes(
138 DEFAULT_TCP_SOCKET_BUFFER_SIZE,
139 DEFAULT_TCP_SOCKET_BUFFER_SIZE,
140 )
141 }
142
143 pub fn new_tcp_socket_with_buffer_sizes(
144 &self,
145 rx_buffer_size: usize,
146 tx_buffer_size: usize,
147 ) -> TcpSocket {
148 let rx_buffer = tcp::SocketBuffer::new(vec![0; rx_buffer_size]);
149 let tx_buffer = tcp::SocketBuffer::new(vec![0; tx_buffer_size]);
150 self.new_socket(tcp::Socket::new(rx_buffer, tx_buffer))
151 }
152
153 pub fn new_socket<T: AnySocket<'static>>(&self, socket: T) -> Socket<T> {
154 let handle = self.inner().borrow_mut().socket_set.add(socket);
155 Socket {
156 handle,
157 shared: self.clone(),
158 _phantom: PhantomData,
159 }
160 }
161
162 pub fn poll_at(&self, timestamp: Instant) -> Option<Instant> {
163 self.inner().borrow_mut().poll_at(timestamp)
164 }
165
166 pub fn poll_delay(&self, timestamp: Instant) -> Option<Duration> {
167 self.inner().borrow_mut().poll_delay(timestamp)
168 }
169
170 pub fn poll<D: Device + ?Sized>(&self, timestamp: Instant, device: &mut D) -> bool {
171 self.inner().borrow_mut().poll(timestamp, device)
172 }
173
174 pub async fn dns_query(
175 &self,
176 name: &str,
177 query_type: DnsQueryType,
178 ) -> Result<Vec<IpAddress>, DnsError> {
179 let query_handle = {
180 let inner = &mut *self.inner().borrow_mut();
181 inner
182 .socket_set
183 .get_mut::<dns::Socket>(inner.dns_socket_handle)
184 .start_query(inner.iface.context(), name, query_type)
185 .map_err(DnsError::StartQueryError)?
186 };
187 poll_fn(|cx| {
188 let inner = &mut *self.inner().borrow_mut();
189 let socket = inner
190 .socket_set
191 .get_mut::<dns::Socket>(inner.dns_socket_handle);
192 match socket.get_query_result(query_handle) {
193 Err(dns::GetQueryResultError::Pending) => {
194 socket.register_query_waker(query_handle, cx.waker());
195 Poll::Pending
196 }
197 r => Poll::Ready(
198 r.map(|heapless_vec| heapless_vec.to_vec())
199 .map_err(DnsError::GetQueryResultError),
200 ),
201 }
202 })
203 .await
204 }
205}
206
207impl<T: AnySocket<'static>> Socket<T> {
208 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
209 let network = self.shared.inner().borrow();
210 let socket = network.socket_set.get(self.handle);
211 f(socket)
212 }
213
214 pub fn with_mut<R>(&mut self, f: impl FnOnce(&mut T) -> R) -> R {
215 let mut network = self.shared.inner().borrow_mut();
216 let socket = network.socket_set.get_mut(self.handle);
217 f(socket)
218 }
219
220 pub fn with_context_mut<R>(&mut self, f: impl FnOnce(&mut Context, &mut T) -> R) -> R {
221 let network = &mut *self.shared.inner().borrow_mut();
222 let context = network.iface.context();
223 let socket = network.socket_set.get_mut(self.handle);
224 f(context, socket)
225 }
226}
227
228impl Socket<tcp::Socket<'static>> {
229 pub async fn connect<T, U>(
230 &mut self,
231 remote_endpoint: T,
232 local_endpoint: U,
233 ) -> Result<(), TcpSocketError>
234 where
235 T: Into<IpEndpoint>,
236 U: Into<IpListenEndpoint>,
237 {
238 self.with_context_mut(|cx, socket| socket.connect(cx, remote_endpoint, local_endpoint))
239 .map_err(TcpSocketError::ConnectError)?;
240
241 poll_fn(|cx| {
242 self.with_mut(|socket| {
243 let state = socket.state();
244 match state {
245 tcp::State::Closed | tcp::State::TimeWait => {
246 Poll::Ready(Err(TcpSocketError::ConnectionResetDuringConnect))
247 }
248 tcp::State::Listen => unreachable!(), tcp::State::SynSent | tcp::State::SynReceived => {
250 socket.register_send_waker(cx.waker());
251 Poll::Pending
252 }
253 _ => Poll::Ready(Ok(())),
254 }
255 })
256 })
257 .await
258 }
259
260 pub async fn accept_with_keep_alive(
261 &mut self,
262 local_endpoint: impl Into<IpListenEndpoint>,
263 keep_alive_interval: Option<Duration>,
264 ) -> Result<(), TcpSocketError> {
265 self.with_mut(|socket| {
266 socket
267 .listen(local_endpoint)
268 .map_err(TcpSocketError::ListenError)
269 })?;
270
271 poll_fn(|cx| {
272 self.with_mut(|socket| match socket.state() {
273 tcp::State::Listen | tcp::State::SynSent | tcp::State::SynReceived => {
274 socket.register_recv_waker(cx.waker());
275 Poll::Pending
276 }
277 _ => Poll::Ready(Ok(())),
278 })
279 })
280 .await?;
281
282 self.with_mut(|socket| socket.set_keep_alive(keep_alive_interval));
283
284 Ok(())
285 }
286
287 pub async fn accept(
288 &mut self,
289 local_endpoint: impl Into<IpListenEndpoint>,
290 ) -> Result<(), TcpSocketError> {
291 self.accept_with_keep_alive(
292 local_endpoint,
293 Some(Duration::from_millis(DEFAULT_KEEP_ALIVE_INTERVAL)),
294 )
295 .await
296 }
297
298 pub fn close(&mut self) {
299 self.with_mut(|socket| socket.close())
300 }
301
302 pub fn abort(&mut self) {
303 self.with_mut(|socket| socket.abort())
304 }
305}
306
307impl ErrorType for Socket<tcp::Socket<'static>> {
308 type Error = TcpSocketError;
309}
310
311impl Read for Socket<tcp::Socket<'static>> {
312 fn poll_read(
313 mut self: Pin<&mut Self>,
314 cx: &mut task::Context<'_>,
315 buf: &mut [u8],
316 ) -> Poll<Result<usize, Self::Error>> {
317 self.with_mut(|socket| match socket.recv_slice(buf) {
318 Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
319 Ok(0) => {
320 socket.register_recv_waker(cx.waker());
321 Poll::Pending
322 }
323 Ok(n) => Poll::Ready(Ok(n)),
324 Err(tcp::RecvError::Finished) => Poll::Ready(Ok(0)),
325 Err(err) => Poll::Ready(Err(TcpSocketError::RecvError(err))),
326 })
327 }
328}
329
330impl Write for Socket<tcp::Socket<'static>> {
331 fn poll_write(
332 mut self: Pin<&mut Self>,
333 cx: &mut task::Context<'_>,
334 buf: &[u8],
335 ) -> Poll<Result<usize, Self::Error>> {
336 self.with_mut(|socket| match socket.send_slice(buf) {
337 Ok(0) if buf.is_empty() => Poll::Ready(Ok(0)),
338 Ok(0) => {
339 socket.register_send_waker(cx.waker());
340 Poll::Pending
341 }
342 Ok(n) => Poll::Ready(Ok(n)),
343 Err(err) => Poll::Ready(Err(TcpSocketError::SendError(err))),
344 })
345 }
346
347 fn poll_flush(
348 mut self: Pin<&mut Self>,
349 cx: &mut task::Context<'_>,
350 ) -> Poll<Result<(), Self::Error>> {
351 self.with_mut(|socket| {
352 let waiting_close =
353 socket.state() == tcp::State::Closed && socket.remote_endpoint().is_some();
354 if socket.send_queue() > 0 || waiting_close {
355 socket.register_send_waker(cx.waker());
356 Poll::Pending
357 } else {
358 Poll::Ready(Ok(()))
359 }
360 })
361 }
362}
363
364impl ManagedInterfaceShared {
365 fn dhcp_socket_mut(&mut self) -> &mut dhcpv4::Socket<'static> {
366 self.socket_set.get_mut(self.dhcp_socket_handle)
367 }
368
369 fn dns_socket_mut(&mut self) -> &mut dns::Socket<'static> {
370 self.socket_set.get_mut(self.dns_socket_handle)
371 }
372
373 fn poll_at(&mut self, timestamp: Instant) -> Option<Instant> {
374 self.iface.poll_at(timestamp, &self.socket_set)
375 }
376
377 fn poll_delay(&mut self, timestamp: Instant) -> Option<Duration> {
378 self.iface.poll_delay(timestamp, &self.socket_set)
379 }
380
381 fn poll<D: Device + ?Sized>(&mut self, timestamp: Instant, device: &mut D) -> bool {
382 let activity = self.iface.poll(timestamp, device, &mut self.socket_set)
383 == PollResult::SocketStateChanged;
384 if activity {
385 self.poll_dhcp();
386 }
387 activity
388 }
389
390 fn poll_dhcp(&mut self) {
392 if let Some(event) = self.dhcp_socket_mut().poll() {
393 let event = free_dhcp_event(event);
394 match event {
395 dhcpv4::Event::Configured(config) => {
396 info!("DHCP config acquired");
397 if self.dhcp_overrides.address.is_none() {
398 self.set_address(config.address);
399 }
400 if self.dhcp_overrides.router.is_none() {
401 self.set_router(config.router);
402 }
403 if self.dhcp_overrides.dns_servers.is_none() {
404 self.set_dns_servers(&convert_dns_servers(&config.dns_servers));
405 }
406 }
407 dhcpv4::Event::Deconfigured => {
408 info!("DHCP config lost");
409 if self.dhcp_overrides.address.is_none() {
410 self.clear_address();
411 }
412 if self.dhcp_overrides.router.is_none() {
413 self.clear_router();
414 }
415 if self.dhcp_overrides.dns_servers.is_none() {
416 self.clear_dns_servers();
417 }
418 }
419 }
420 }
421 }
422
423 fn set_address(&mut self, address: Ipv4Cidr) {
424 let address = IpCidr::Ipv4(address);
425 info!("IP address: {address}");
426 self.iface.update_ip_addrs(|addrs| {
427 if let Some(dest) = addrs.iter_mut().next() {
428 *dest = address;
429 } else {
430 addrs.push(address).unwrap();
431 }
432 });
433 }
434
435 fn clear_address(&mut self) {
436 let cidr = Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0);
437 self.iface.update_ip_addrs(|addrs| {
438 if let Some(dest) = addrs.iter_mut().next() {
439 *dest = IpCidr::Ipv4(cidr);
440 }
441 });
442 }
443
444 fn set_router(&mut self, router: Option<Ipv4Address>) {
445 if let Some(router) = router {
446 info!("Default gateway: {router}");
447 self.iface
448 .routes_mut()
449 .add_default_ipv4_route(router)
450 .unwrap();
451 } else {
452 info!("Default gateway: (none)");
453 self.iface.routes_mut().remove_default_ipv4_route();
454 }
455 }
456
457 fn clear_router(&mut self) {
458 self.iface.routes_mut().remove_default_ipv4_route();
459 }
460
461 fn set_dns_servers(&mut self, dns_servers: &[IpAddress]) {
462 for (i, s) in dns_servers.iter().enumerate() {
463 info!("DNS server {i}: {s}");
464 }
465 self.dns_socket_mut().update_servers(dns_servers);
466 }
467
468 fn clear_dns_servers(&mut self) {
469 self.dns_socket_mut().update_servers(&[]);
470 }
471
472 fn apply_dhcp_overrides(&mut self) {
473 if let Some(address) = self.dhcp_overrides.address {
474 self.set_address(address);
475 }
476 if let Some(router) = self.dhcp_overrides.router {
477 self.set_router(router);
478 }
479 if let Some(dns_servers) = self
480 .dhcp_overrides
481 .dns_servers
482 .as_deref()
483 .map(convert_dns_servers)
484 {
485 self.set_dns_servers(&dns_servers);
486 }
487 }
488}
489
490fn free_dhcp_event(event: dhcpv4::Event) -> dhcpv4::Event<'static> {
491 match event {
492 dhcpv4::Event::Deconfigured => dhcpv4::Event::Deconfigured,
493 dhcpv4::Event::Configured(config) => dhcpv4::Event::Configured(free_dhcp_config(config)),
494 }
495}
496
497fn free_dhcp_config(config: dhcpv4::Config) -> dhcpv4::Config<'static> {
498 dhcpv4::Config {
499 server: config.server,
500 address: config.address,
501 router: config.router,
502 dns_servers: config.dns_servers,
503 packet: None,
504 }
505}
506
507fn convert_dns_servers(dns_servers: &[Ipv4Address]) -> Vec<IpAddress> {
508 dns_servers
509 .iter()
510 .copied()
511 .map(From::from)
512 .collect::<Vec<_>>()
513}