1use super::{
2 protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3 VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{transport::Transport, Hal, Result};
6use alloc::{boxed::Box, vec::Vec};
7use core::cmp::min;
8use core::convert::TryInto;
9use core::hint::spin_loop;
10use log::debug;
11use zerocopy::FromZeros;
12
/// Receive buffer size, in bytes (1 KiB), given to each connection when the manager is created
/// with `VsockConnectionManager::new` rather than an explicit capacity.
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1 << 10;
14
15pub struct VsockConnectionManager<
47 H: Hal,
48 T: Transport,
49 const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
50> {
51 driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
52 per_connection_buffer_capacity: u32,
53 connections: Vec<Connection>,
54 listening_ports: Vec<u32>,
55}
56
/// Per-connection state kept by `VsockConnectionManager`.
#[derive(Debug)]
struct Connection {
    /// Addressing (`dst`, `src_port`) and buffer-accounting state handed to the driver for every
    /// operation on this connection.
    info: ConnectionInfo,
    /// Data received from the peer which the client has not read yet.
    buffer: RingBuffer,
    /// Set when the peer disconnected while unread data was still in `buffer`. The connection is
    /// closed and forgotten once `recv` has drained the buffer.
    peer_requested_shutdown: bool,
}
65
66impl Connection {
67 fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
68 let mut info = ConnectionInfo::new(peer, local_port);
69 info.buf_alloc = buffer_capacity;
70 Self {
71 info,
72 buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
73 peer_requested_shutdown: false,
74 }
75 }
76}
77
78impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
79 VsockConnectionManager<H, T, RX_BUFFER_SIZE>
80{
81 pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
83 Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
84 }
85
86 pub fn new_with_capacity(
89 driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
90 per_connection_buffer_capacity: u32,
91 ) -> Self {
92 Self {
93 driver,
94 connections: Vec::new(),
95 listening_ports: Vec::new(),
96 per_connection_buffer_capacity,
97 }
98 }
99
100 pub fn guest_cid(&self) -> u64 {
102 self.driver.guest_cid()
103 }
104
105 pub fn listen(&mut self, port: u32) {
107 if !self.listening_ports.contains(&port) {
108 self.listening_ports.push(port);
109 }
110 }
111
112 pub fn unlisten(&mut self, port: u32) {
114 self.listening_ports.retain(|p| *p != port);
115 }
116
117 pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
123 if self.connections.iter().any(|connection| {
124 connection.info.dst == destination && connection.info.src_port == src_port
125 }) {
126 return Err(SocketError::ConnectionExists.into());
127 }
128
129 let new_connection =
130 Connection::new(destination, src_port, self.per_connection_buffer_capacity);
131
132 self.driver.connect(&new_connection.info)?;
133 debug!("Connection requested: {:?}", new_connection.info);
134 self.connections.push(new_connection);
135 Ok(())
136 }
137
138 pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
140 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
141
142 self.driver.send(buffer, &mut connection.info)
143 }
144
145 pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
147 let guest_cid = self.driver.guest_cid();
148 let connections = &mut self.connections;
149 let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
150
151 let result = self.driver.poll(|event, body| {
152 let connection = get_connection_for_event(connections, &event, guest_cid);
153
154 let connection = if let Some((_, connection)) = connection {
157 connection
158 } else if let VsockEventType::ConnectionRequest = event.event_type {
159 if connection.is_some() || event.destination.cid != guest_cid {
161 return Ok(None);
162 }
163 connections.push(Connection::new(
166 event.source,
167 event.destination.port,
168 per_connection_buffer_capacity,
169 ));
170 connections.last_mut().unwrap()
171 } else {
172 return Ok(None);
173 };
174
175 connection.info.update_for_event(&event);
177
178 if let VsockEventType::Received { length } = event.event_type {
179 if !connection.buffer.add(body) {
181 return Err(SocketError::OutputBufferTooShort(length).into());
182 }
183 }
184
185 Ok(Some(event))
186 })?;
187
188 let Some(event) = result else {
189 return Ok(None);
190 };
191
192 let (connection_index, connection) =
194 get_connection_for_event(connections, &event, guest_cid).unwrap();
195
196 match event.event_type {
197 VsockEventType::ConnectionRequest => {
198 if self.listening_ports.contains(&event.destination.port) {
199 self.driver.accept(&connection.info)?;
200 } else {
201 self.driver.force_close(&connection.info)?;
203 self.connections.swap_remove(connection_index);
204
205 return Ok(None);
207 }
208 }
209 VsockEventType::Connected => {}
210 VsockEventType::Disconnected { reason } => {
211 if connection.buffer.is_empty() {
213 if reason == DisconnectReason::Shutdown {
214 self.driver.force_close(&connection.info)?;
215 }
216 self.connections.swap_remove(connection_index);
217 } else {
218 connection.peer_requested_shutdown = true;
219 }
220 }
221 VsockEventType::Received { .. } => {
222 }
224 VsockEventType::CreditRequest => {
225 self.driver.credit_update(&connection.info)?;
227 return Ok(None);
229 }
230 VsockEventType::CreditUpdate => {}
231 }
232
233 Ok(Some(event))
234 }
235
236 pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
238 let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
239
240 let bytes_read = connection.buffer.drain(buffer);
242
243 connection.info.done_forwarding(bytes_read);
244
245 if connection.peer_requested_shutdown && connection.buffer.is_empty() {
248 self.driver.force_close(&connection.info)?;
249 self.connections.swap_remove(connection_index);
250 }
251
252 Ok(bytes_read)
253 }
254
255 pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
260 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
261 Ok(connection.buffer.used())
262 }
263
264 pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
266 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
267 self.driver.credit_update(&connection.info)
268 }
269
270 pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
272 loop {
273 if let Some(event) = self.poll()? {
274 return Ok(event);
275 } else {
276 spin_loop();
277 }
278 }
279 }
280
281 pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
288 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
289
290 self.driver.shutdown(&connection.info)
291 }
292
293 pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
295 let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
296
297 self.driver.force_close(&connection.info)?;
298
299 self.connections.swap_remove(index);
300 Ok(())
301 }
302}
303
304fn get_connection(
309 connections: &mut [Connection],
310 peer: VsockAddr,
311 local_port: u32,
312) -> core::result::Result<(usize, &mut Connection), SocketError> {
313 connections
314 .iter_mut()
315 .enumerate()
316 .find(|(_, connection)| {
317 connection.info.dst == peer && connection.info.src_port == local_port
318 })
319 .ok_or(SocketError::NotConnected)
320}
321
322fn get_connection_for_event<'a>(
324 connections: &'a mut [Connection],
325 event: &VsockEvent,
326 local_cid: u64,
327) -> Option<(usize, &'a mut Connection)> {
328 connections
329 .iter_mut()
330 .enumerate()
331 .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
332}
333
/// A fixed-capacity FIFO byte queue whose contents wrap around the end of the backing storage.
#[derive(Debug)]
struct RingBuffer {
    /// Backing storage; its length is the buffer's capacity.
    buffer: Box<[u8]>,
    /// The number of bytes currently stored.
    used: usize,
    /// The index in `buffer` of the oldest stored byte.
    start: usize,
}
342
343impl RingBuffer {
344 pub fn new(capacity: usize) -> Self {
345 Self {
346 buffer: FromZeros::new_box_zeroed_with_elems(capacity).unwrap(),
347 used: 0,
348 start: 0,
349 }
350 }
351
352 pub fn used(&self) -> usize {
354 self.used
355 }
356
357 pub fn is_empty(&self) -> bool {
359 self.used == 0
360 }
361
362 pub fn free(&self) -> usize {
364 self.buffer.len() - self.used
365 }
366
367 pub fn add(&mut self, bytes: &[u8]) -> bool {
371 if bytes.len() > self.free() {
372 return false;
373 }
374
375 let first_available = (self.start + self.used) % self.buffer.len();
377 let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
380 self.buffer[first_available..first_available + copy_length_before_wraparound]
381 .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
382 if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
383 self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
384 }
385 self.used += bytes.len();
386
387 true
388 }
389
390 pub fn drain(&mut self, out: &mut [u8]) -> usize {
393 let bytes_read = min(self.used, out.len());
394
395 let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
397 let read_after_wraparound = bytes_read
399 .checked_sub(read_before_wraparound)
400 .unwrap_or_default();
401
402 out[0..read_before_wraparound]
403 .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
404 out[read_before_wraparound..bytes_read]
405 .copy_from_slice(&self.buffer[0..read_after_wraparound]);
406
407 self.used -= bytes_read;
408 self.start = (self.start + bytes_read) % self.buffer.len();
409
410 bytes_read
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use core::mem::size_of;
    use std::{sync::Mutex, thread};
    use zerocopy::{FromBytes, IntoBytes};

    /// The guest connects out to the host, sends one message, receives one reply, then shuts
    /// the connection down.
    ///
    /// A spawned thread plays the host side through the fake transport's shared queue state.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Host side.
        let handle = thread::spawn(move || {
            // Wait for the guest's connection request and check its header.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept the connection, advertising a 50-byte buffer to the guest.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect the guest's message: a header followed by the payload.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice())
                    .unwrap()
                    .0,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Reply with the host's message. The forward count acknowledges the guest's bytes.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice())
            .unwrap();
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a full shutdown from the guest. Its forward count shows the reply was read.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    /// The host requests connections to the guest.
    ///
    /// The request to a port the guest is not listening on must be rejected with a reset and
    /// not reported. The request to the listened-on port must be accepted and reported as a
    /// `ConnectionRequest` event.
    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(
            vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            config_space,
        )));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        socket.listen(guest_port);

        // Host side.
        let handle = thread::spawn(move || {
            // Request a connection to a port the guest is not listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest must answer with a reset from the wrong port.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Now request a connection to the port the guest is listening on.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest must accept it with a response.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from_bytes(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Guest side: the rejected request is handled inside `poll` and never reported, so the
        // first event seen is the accepted request on the listened-on port.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }
}