1use super::{
2 protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3 VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4};
5use crate::{transport::Transport, Hal, Result};
6use alloc::{boxed::Box, vec::Vec};
7use core::cmp::min;
8use core::convert::TryInto;
9use core::hint::spin_loop;
10use log::debug;
11use zerocopy::FromZeroes;
12
/// Receive buffer capacity, in bytes, given to each connection when the manager is built with
/// `VsockConnectionManager::new`.
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024;
14
/// A higher-level interface over a vsock driver that tracks multiple connections, accepts
/// incoming connection requests on listened ports, and buffers received data per connection
/// until the client reads it.
pub struct VsockConnectionManager<
    H: Hal,
    T: Transport,
    const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
> {
    /// The underlying low-level vsock driver.
    driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
    /// Receive buffer capacity, in bytes, allocated for every new connection.
    per_connection_buffer_capacity: u32,
    /// Currently tracked connections. Order is not meaningful: entries are removed with
    /// `swap_remove`.
    connections: Vec<Connection>,
    /// Local ports on which incoming connection requests are accepted.
    listening_ports: Vec<u32>,
}
56
/// Bookkeeping for a single connection tracked by the connection manager.
#[derive(Debug)]
struct Connection {
    /// Addressing and flow-control state that is passed to the low-level driver.
    info: ConnectionInfo,
    /// Data received from the peer that the client has not read yet.
    buffer: RingBuffer,
    /// Set when the peer disconnected while unread data was still buffered. `recv` closes the
    /// connection once that data has been drained.
    peer_requested_shutdown: bool,
}
65
66impl Connection {
67 fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
68 let mut info = ConnectionInfo::new(peer, local_port);
69 info.buf_alloc = buffer_capacity;
70 Self {
71 info,
72 buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
73 peer_requested_shutdown: false,
74 }
75 }
76}
77
78impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
79 VsockConnectionManager<H, T, RX_BUFFER_SIZE>
80{
81 pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
83 Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
84 }
85
86 pub fn new_with_capacity(
89 driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
90 per_connection_buffer_capacity: u32,
91 ) -> Self {
92 Self {
93 driver,
94 connections: Vec::new(),
95 listening_ports: Vec::new(),
96 per_connection_buffer_capacity,
97 }
98 }
99
100 pub fn guest_cid(&self) -> u64 {
102 self.driver.guest_cid()
103 }
104
105 pub fn listen(&mut self, port: u32) {
107 if !self.listening_ports.contains(&port) {
108 self.listening_ports.push(port);
109 }
110 }
111
112 pub fn unlisten(&mut self, port: u32) {
114 self.listening_ports.retain(|p| *p != port);
115 }
116
117 pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
123 if self.connections.iter().any(|connection| {
124 connection.info.dst == destination && connection.info.src_port == src_port
125 }) {
126 return Err(SocketError::ConnectionExists.into());
127 }
128
129 let new_connection =
130 Connection::new(destination, src_port, self.per_connection_buffer_capacity);
131
132 self.driver.connect(&new_connection.info)?;
133 debug!("Connection requested: {:?}", new_connection.info);
134 self.connections.push(new_connection);
135 Ok(())
136 }
137
138 pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
140 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
141
142 self.driver.send(buffer, &mut connection.info)
143 }
144
145 pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
147 let guest_cid = self.driver.guest_cid();
148 let connections = &mut self.connections;
149 let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
150
151 let result = self.driver.poll(|event, body| {
152 let connection = get_connection_for_event(connections, &event, guest_cid);
153
154 let connection = if let Some((_, connection)) = connection {
157 connection
158 } else if let VsockEventType::ConnectionRequest = event.event_type {
159 if connection.is_some() || event.destination.cid != guest_cid {
161 return Ok(None);
162 }
163 connections.push(Connection::new(
166 event.source,
167 event.destination.port,
168 per_connection_buffer_capacity,
169 ));
170 connections.last_mut().unwrap()
171 } else {
172 return Ok(None);
173 };
174
175 connection.info.update_for_event(&event);
177
178 if let VsockEventType::Received { length } = event.event_type {
179 if !connection.buffer.add(body) {
181 return Err(SocketError::OutputBufferTooShort(length).into());
182 }
183 }
184
185 Ok(Some(event))
186 })?;
187
188 let Some(event) = result else {
189 return Ok(None);
190 };
191
192 let (connection_index, connection) =
194 get_connection_for_event(connections, &event, guest_cid).unwrap();
195
196 match event.event_type {
197 VsockEventType::ConnectionRequest => {
198 if self.listening_ports.contains(&event.destination.port) {
199 self.driver.accept(&connection.info)?;
200 } else {
201 self.driver.force_close(&connection.info)?;
203 self.connections.swap_remove(connection_index);
204
205 return Ok(None);
207 }
208 }
209 VsockEventType::Connected => {}
210 VsockEventType::Disconnected { reason } => {
211 if connection.buffer.is_empty() {
213 if reason == DisconnectReason::Shutdown {
214 self.driver.force_close(&connection.info)?;
215 }
216 self.connections.swap_remove(connection_index);
217 } else {
218 connection.peer_requested_shutdown = true;
219 }
220 }
221 VsockEventType::Received { .. } => {
222 }
224 VsockEventType::CreditRequest => {
225 self.driver.credit_update(&connection.info)?;
227 return Ok(None);
229 }
230 VsockEventType::CreditUpdate => {}
231 }
232
233 Ok(Some(event))
234 }
235
236 pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
238 let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
239
240 let bytes_read = connection.buffer.drain(buffer);
242
243 connection.info.done_forwarding(bytes_read);
244
245 if connection.peer_requested_shutdown && connection.buffer.is_empty() {
248 self.driver.force_close(&connection.info)?;
249 self.connections.swap_remove(connection_index);
250 }
251
252 Ok(bytes_read)
253 }
254
255 pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
260 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
261 Ok(connection.buffer.used())
262 }
263
264 pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
266 let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
267 self.driver.credit_update(&connection.info)
268 }
269
270 pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
272 loop {
273 if let Some(event) = self.poll()? {
274 return Ok(event);
275 } else {
276 spin_loop();
277 }
278 }
279 }
280
281 pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
288 let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
289
290 self.driver.shutdown(&connection.info)
291 }
292
293 pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
295 let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
296
297 self.driver.force_close(&connection.info)?;
298
299 self.connections.swap_remove(index);
300 Ok(())
301 }
302}
303
304fn get_connection(
309 connections: &mut [Connection],
310 peer: VsockAddr,
311 local_port: u32,
312) -> core::result::Result<(usize, &mut Connection), SocketError> {
313 connections
314 .iter_mut()
315 .enumerate()
316 .find(|(_, connection)| {
317 connection.info.dst == peer && connection.info.src_port == local_port
318 })
319 .ok_or(SocketError::NotConnected)
320}
321
322fn get_connection_for_event<'a>(
324 connections: &'a mut [Connection],
325 event: &VsockEvent,
326 local_cid: u64,
327) -> Option<(usize, &'a mut Connection)> {
328 connections
329 .iter_mut()
330 .enumerate()
331 .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
332}
333
/// A fixed-capacity circular byte buffer.
#[derive(Debug)]
struct RingBuffer {
    /// Backing storage. Its length is the buffer's capacity.
    buffer: Box<[u8]>,
    /// Number of bytes currently stored.
    used: usize,
    /// Index in `buffer` of the oldest stored byte.
    start: usize,
}
342
343impl RingBuffer {
344 pub fn new(capacity: usize) -> Self {
345 Self {
346 buffer: FromZeroes::new_box_slice_zeroed(capacity),
347 used: 0,
348 start: 0,
349 }
350 }
351
352 pub fn used(&self) -> usize {
354 self.used
355 }
356
357 pub fn is_empty(&self) -> bool {
359 self.used == 0
360 }
361
362 pub fn free(&self) -> usize {
364 self.buffer.len() - self.used
365 }
366
367 pub fn add(&mut self, bytes: &[u8]) -> bool {
371 if bytes.len() > self.free() {
372 return false;
373 }
374
375 let first_available = (self.start + self.used) % self.buffer.len();
377 let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
380 self.buffer[first_available..first_available + copy_length_before_wraparound]
381 .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
382 if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
383 self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
384 }
385 self.used += bytes.len();
386
387 true
388 }
389
390 pub fn drain(&mut self, out: &mut [u8]) -> usize {
393 let bytes_read = min(self.used, out.len());
394
395 let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
397 let read_after_wraparound = bytes_read
399 .checked_sub(read_before_wraparound)
400 .unwrap_or_default();
401
402 out[0..read_before_wraparound]
403 .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
404 out[read_before_wraparound..bytes_read]
405 .copy_from_slice(&self.buffer[0..read_after_wraparound]);
406
407 self.used -= bytes_read;
408 self.start = (self.start + bytes_read) % self.buffer.len();
409
410 bytes_read
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::{mem::size_of, ptr::NonNull};
    use std::{sync::Mutex, thread};
    use zerocopy::{AsBytes, FromBytes};

    /// The guest connects to the host, sends a message, receives a reply and then shuts the
    /// connection down. A spawned thread plays the host side through the fake transport.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        // Fake config space reporting guest CID 66. It must outlive `transport`, which holds a
        // pointer to it.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // Shared fake device state with three queue slots.
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Host side of the conversation.
        let handle = thread::spawn(move || {
            // Wait for the guest's connection request and check its header.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept the connection, advertising a 50-byte buffer on the host side.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Wait for the guest's message: a header followed by the payload.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Reply with a message of our own. Its fwd_cnt acknowledges the guest's bytes.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice());
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a full shutdown from the guest. Its fwd_cnt shows it has read our reply.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side: connect and wait for the host's acceptance.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        // The received payload was buffered by `poll`; read it back out.
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    /// The host first requests a connection to a port the guest is not listening on, which
    /// must be rejected with a reset. It then requests the listened port, which must be
    /// accepted and surfaced as a `ConnectionRequest` event.
    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        // Fake config space reporting guest CID 66. It must outlive `transport`, which holds a
        // pointer to it.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        socket.listen(guest_port);

        // Host side of the conversation.
        let handle = thread::spawn(move || {
            // Request a connection to a port nobody is listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest must answer with a reset from the wrong port.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Now request a connection to the listened port.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // The guest must accept with a response from the listened port.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Only the accepted request is surfaced; the rejected one is handled inside `poll`.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }
}