1#![deny(unsafe_op_in_unsafe_fn)]
3
4use super::error::SocketError;
5use super::protocol::{
6 Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7};
8use super::DEFAULT_RX_BUFFER_SIZE;
9use crate::config::read_config;
10use crate::hal::Hal;
11use crate::queue::{owning::OwningQueue, VirtQueue};
12use crate::transport::Transport;
13use crate::Result;
14use core::mem::size_of;
15use log::debug;
16use zerocopy::{FromBytes, IntoBytes};
17
18pub(crate) const RX_QUEUE_IDX: u16 = 0;
19pub(crate) const TX_QUEUE_IDX: u16 = 1;
20const EVENT_QUEUE_IDX: u16 = 2;
21
22pub(crate) const QUEUE_SIZE: usize = 8;
23const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
24 .union(Feature::RING_INDIRECT_DESC)
25 .union(Feature::VERSION_1);
26
27#[derive(Clone, Debug, Default, PartialEq, Eq)]
29pub struct ConnectionInfo {
30 pub dst: VsockAddr,
32 pub src_port: u32,
34 peer_buf_alloc: u32,
37 peer_fwd_cnt: u32,
40 tx_cnt: u32,
42 pub buf_alloc: u32,
45 fwd_cnt: u32,
47 has_pending_credit_request: bool,
52}
53
54impl ConnectionInfo {
55 pub fn new(destination: VsockAddr, src_port: u32) -> Self {
58 Self {
59 dst: destination,
60 src_port,
61 ..Default::default()
62 }
63 }
64
65 pub fn update_for_event(&mut self, event: &VsockEvent) {
68 self.peer_buf_alloc = event.buffer_status.buffer_allocation;
69 self.peer_fwd_cnt = event.buffer_status.forward_count;
70
71 if let VsockEventType::CreditUpdate = event.event_type {
72 self.has_pending_credit_request = false;
73 }
74 }
75
76 pub fn done_forwarding(&mut self, length: usize) {
81 self.fwd_cnt += length as u32;
82 }
83
84 fn peer_free(&self) -> u32 {
87 self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
88 }
89
90 fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
91 VirtioVsockHdr {
92 src_cid: src_cid.into(),
93 dst_cid: self.dst.cid.into(),
94 src_port: self.src_port.into(),
95 dst_port: self.dst.port.into(),
96 buf_alloc: self.buf_alloc.into(),
97 fwd_cnt: self.fwd_cnt.into(),
98 ..Default::default()
99 }
100 }
101}
102
103#[derive(Clone, Debug, Eq, PartialEq)]
105pub struct VsockEvent {
106 pub source: VsockAddr,
108 pub destination: VsockAddr,
110 pub buffer_status: VsockBufferStatus,
112 pub event_type: VsockEventType,
114}
115
116impl VsockEvent {
117 pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
119 self.source == connection_info.dst
120 && self.destination.cid == guest_cid
121 && self.destination.port == connection_info.src_port
122 }
123
124 fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
125 let op = header.op()?;
126 let buffer_status = VsockBufferStatus {
127 buffer_allocation: header.buf_alloc.into(),
128 forward_count: header.fwd_cnt.into(),
129 };
130 let source = header.source();
131 let destination = header.destination();
132
133 let event_type = match op {
134 VirtioVsockOp::Request => {
135 header.check_data_is_empty()?;
136 VsockEventType::ConnectionRequest
137 }
138 VirtioVsockOp::Response => {
139 header.check_data_is_empty()?;
140 VsockEventType::Connected
141 }
142 VirtioVsockOp::CreditUpdate => {
143 header.check_data_is_empty()?;
144 VsockEventType::CreditUpdate
145 }
146 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
147 header.check_data_is_empty()?;
148 debug!("Disconnected from the peer");
149 let reason = if op == VirtioVsockOp::Rst {
150 DisconnectReason::Reset
151 } else {
152 DisconnectReason::Shutdown
153 };
154 VsockEventType::Disconnected { reason }
155 }
156 VirtioVsockOp::Rw => VsockEventType::Received {
157 length: header.len() as usize,
158 },
159 VirtioVsockOp::CreditRequest => {
160 header.check_data_is_empty()?;
161 VsockEventType::CreditRequest
162 }
163 VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
164 };
165
166 Ok(VsockEvent {
167 source,
168 destination,
169 buffer_status,
170 event_type,
171 })
172 }
173}
174
/// Buffer status a peer reports in every packet header, used for credit-based flow control.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VsockBufferStatus {
    /// The sender's `buf_alloc` header field.
    pub buffer_allocation: u32,
    /// The sender's `fwd_cnt` header field.
    pub forward_count: u32,
}
180
/// Why the peer ended a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The peer sent a reset (`Rst`) packet.
    Reset,
    /// The peer sent a `Shutdown` packet.
    Shutdown,
}
190
191#[derive(Clone, Debug, Eq, PartialEq)]
193pub enum VsockEventType {
194 ConnectionRequest,
196 Connected,
198 Disconnected {
200 reason: DisconnectReason,
202 },
203 Received {
205 length: usize,
207 },
208 CreditRequest,
210 CreditUpdate,
212}
213
214pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
222{
223 transport: T,
224 rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
226 tx: VirtQueue<H, { QUEUE_SIZE }>,
227 event: VirtQueue<H, { QUEUE_SIZE }>,
229 guest_cid: u64,
232}
233
234impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
235 for VirtIOSocket<H, T, RX_BUFFER_SIZE>
236{
237 fn drop(&mut self) {
238 self.transport.queue_unset(RX_QUEUE_IDX);
241 self.transport.queue_unset(TX_QUEUE_IDX);
242 self.transport.queue_unset(EVENT_QUEUE_IDX);
243 }
244}
245
246impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
247 pub fn new(mut transport: T) -> Result<Self> {
249 assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
250
251 let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
252
253 let guest_cid = transport.read_consistent(|| {
254 Ok(
255 (read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64)
256 | ((read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32),
257 )
258 })?;
259 debug!("guest cid: {guest_cid:?}");
260
261 let rx = VirtQueue::new(
262 &mut transport,
263 RX_QUEUE_IDX,
264 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
265 negotiated_features.contains(Feature::RING_EVENT_IDX),
266 )?;
267 let tx = VirtQueue::new(
268 &mut transport,
269 TX_QUEUE_IDX,
270 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
271 negotiated_features.contains(Feature::RING_EVENT_IDX),
272 )?;
273 let event = VirtQueue::new(
274 &mut transport,
275 EVENT_QUEUE_IDX,
276 negotiated_features.contains(Feature::RING_INDIRECT_DESC),
277 negotiated_features.contains(Feature::RING_EVENT_IDX),
278 )?;
279
280 let rx = OwningQueue::new(rx)?;
281
282 transport.finish_init();
283 if rx.should_notify() {
284 transport.notify(RX_QUEUE_IDX);
285 }
286
287 Ok(Self {
288 transport,
289 rx,
290 tx,
291 event,
292 guest_cid,
293 })
294 }
295
296 pub fn guest_cid(&self) -> u64 {
298 self.guest_cid
299 }
300
301 pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
307 let header = VirtioVsockHdr {
308 op: VirtioVsockOp::Request.into(),
309 ..connection_info.new_header(self.guest_cid)
310 };
311 self.send_packet_to_tx_queue(&header, &[])
314 }
315
316 pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
318 let header = VirtioVsockHdr {
319 op: VirtioVsockOp::Response.into(),
320 ..connection_info.new_header(self.guest_cid)
321 };
322 self.send_packet_to_tx_queue(&header, &[])
323 }
324
325 fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
327 let header = VirtioVsockHdr {
328 op: VirtioVsockOp::CreditRequest.into(),
329 ..connection_info.new_header(self.guest_cid)
330 };
331 self.send_packet_to_tx_queue(&header, &[])
332 }
333
334 pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
336 self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
337
338 let len = buffer.len() as u32;
339 let header = VirtioVsockHdr {
340 op: VirtioVsockOp::Rw.into(),
341 len: len.into(),
342 ..connection_info.new_header(self.guest_cid)
343 };
344 connection_info.tx_cnt += len;
345 self.send_packet_to_tx_queue(&header, buffer)
346 }
347
348 fn check_peer_buffer_is_sufficient(
349 &mut self,
350 connection_info: &mut ConnectionInfo,
351 buffer_len: usize,
352 ) -> Result {
353 if connection_info.peer_free() as usize >= buffer_len {
354 Ok(())
355 } else {
356 if !connection_info.has_pending_credit_request {
359 self.request_credit(connection_info)?;
360 connection_info.has_pending_credit_request = true;
361 }
362 Err(SocketError::InsufficientBufferSpaceInPeer.into())
363 }
364 }
365
366 pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
368 let header = VirtioVsockHdr {
369 op: VirtioVsockOp::CreditUpdate.into(),
370 ..connection_info.new_header(self.guest_cid)
371 };
372 self.send_packet_to_tx_queue(&header, &[])
373 }
374
375 pub fn poll(
378 &mut self,
379 handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
380 ) -> Result<Option<VsockEvent>> {
381 self.rx.poll(&mut self.transport, |buffer| {
382 let (header, body) = read_header_and_body(buffer)?;
383 VsockEvent::from_header(&header).and_then(|event| handler(event, body))
384 })
385 }
386
387 pub fn shutdown_with_hints(
394 &mut self,
395 connection_info: &ConnectionInfo,
396 hints: StreamShutdown,
397 ) -> Result {
398 let header = VirtioVsockHdr {
399 op: VirtioVsockOp::Shutdown.into(),
400 flags: hints.into(),
401 ..connection_info.new_header(self.guest_cid)
402 };
403 self.send_packet_to_tx_queue(&header, &[])
404 }
405
406 pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
413 self.shutdown_with_hints(
414 connection_info,
415 StreamShutdown::SEND | StreamShutdown::RECEIVE,
416 )
417 }
418
419 pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
421 let header = VirtioVsockHdr {
422 op: VirtioVsockOp::Rst.into(),
423 ..connection_info.new_header(self.guest_cid)
424 };
425 self.send_packet_to_tx_queue(&header, &[])?;
426 Ok(())
427 }
428
429 fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
430 let _len = if buffer.is_empty() {
431 self.tx
432 .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
433 } else {
434 self.tx.add_notify_wait_pop(
435 &[header.as_bytes(), buffer],
436 &mut [],
437 &mut self.transport,
438 )?
439 };
440 Ok(())
441 }
442}
443
444fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
445 let header = VirtioVsockHdr::read_from_prefix(buffer)
447 .map_err(|_| SocketError::BufferTooShort)?
448 .0;
449 let body_length = header.len() as usize;
450
451 let data_end = size_of::<VirtioVsockHdr>()
453 .checked_add(body_length)
454 .ok_or(SocketError::InvalidNumber)?;
455 let data = buffer
458 .get(size_of::<VirtioVsockHdr>()..data_end)
459 .ok_or(SocketError::BufferTooShort)?;
460 Ok((header, data))
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// `new` assembles the guest CID from the low and high config-space words.
    #[test]
    fn config() {
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One fake queue status for each of the RX, TX and event queues.
        let queue_statuses = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let shared_state = Arc::new(Mutex::new(State::new(queue_statuses, config_space)));
        let fake_transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state: Arc::clone(&shared_state),
        };

        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(fake_transport)
                .unwrap();
        assert_eq!(socket.guest_cid(), 66);
    }
}