virtio_drivers/device/socket/
vsock.rs

1//! Driver for VirtIO socket devices.
2#![deny(unsafe_op_in_unsafe_fn)]
3
4use super::error::SocketError;
5use super::protocol::{
6    Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7};
8use super::DEFAULT_RX_BUFFER_SIZE;
9use crate::hal::Hal;
10use crate::queue::{owning::OwningQueue, VirtQueue};
11use crate::transport::Transport;
12use crate::volatile::volread;
13use crate::Result;
14use core::mem::size_of;
15use log::debug;
16use zerocopy::{AsBytes, FromBytes};
17
18pub(crate) const RX_QUEUE_IDX: u16 = 0;
19pub(crate) const TX_QUEUE_IDX: u16 = 1;
20const EVENT_QUEUE_IDX: u16 = 2;
21
22pub(crate) const QUEUE_SIZE: usize = 8;
23const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC);
24
/// Information about a particular vsock connection.
///
/// Besides the addressing of the connection, this holds the credit-based flow control state: the
/// buffer allocation and byte counters we advertise in the headers we send, and the last values
/// the peer advertised to us.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// The address of the peer.
    pub dst: VsockAddr,
    /// The local port number associated with the connection.
    pub src_port: u32,
    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
    /// bytes it has allocated for packet bodies.
    peer_buf_alloc: u32,
    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
    /// has finished processing.
    peer_fwd_cnt: u32,
    /// The number of bytes of packet bodies which we have sent to the peer.
    ///
    /// Compared against `peer_fwd_cnt` to work out how many bytes are still in flight to the peer
    /// (see `peer_free`).
    tx_cnt: u32,
    /// The number of bytes of buffer space we have allocated to receive packet bodies from the
    /// peer.
    pub buf_alloc: u32,
    /// The number of bytes of packet bodies which we have received from the peer and handled.
    fwd_cnt: u32,
    /// Whether we have recently requested credit from the peer.
    ///
    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
    has_pending_credit_request: bool,
}
51
52impl ConnectionInfo {
53    /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
54    /// for everything else.
55    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
56        Self {
57            dst: destination,
58            src_port,
59            ..Default::default()
60        }
61    }
62
63    /// Updates this connection info with the peer buffer allocation and forwarded count from the
64    /// given event.
65    pub fn update_for_event(&mut self, event: &VsockEvent) {
66        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
67        self.peer_fwd_cnt = event.buffer_status.forward_count;
68
69        if let VsockEventType::CreditUpdate = event.event_type {
70            self.has_pending_credit_request = false;
71        }
72    }
73
74    /// Increases the forwarded count recorded for this connection by the given number of bytes.
75    ///
76    /// This should be called once received data has been passed to the client, so there is buffer
77    /// space available for more.
78    pub fn done_forwarding(&mut self, length: usize) {
79        self.fwd_cnt += length as u32;
80    }
81
82    /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
83    /// data from us.
84    fn peer_free(&self) -> u32 {
85        self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
86    }
87
88    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
89        VirtioVsockHdr {
90            src_cid: src_cid.into(),
91            dst_cid: self.dst.cid.into(),
92            src_port: self.src_port.into(),
93            dst_port: self.dst.port.into(),
94            buf_alloc: self.buf_alloc.into(),
95            fwd_cnt: self.fwd_cnt.into(),
96            ..Default::default()
97        }
98    }
99}
100
/// An event received from a VirtIO socket device.
///
/// Each event is decoded from the header of a packet read from the RX virtqueue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
    /// The source of the event, i.e. the peer who sent it.
    pub source: VsockAddr,
    /// The destination of the event, i.e. the CID and port on our side.
    pub destination: VsockAddr,
    /// The peer's buffer status for the connection.
    pub buffer_status: VsockBufferStatus,
    /// The type of event.
    pub event_type: VsockEventType,
}
113
114impl VsockEvent {
115    /// Returns whether the event matches the given connection.
116    pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
117        self.source == connection_info.dst
118            && self.destination.cid == guest_cid
119            && self.destination.port == connection_info.src_port
120    }
121
122    fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
123        let op = header.op()?;
124        let buffer_status = VsockBufferStatus {
125            buffer_allocation: header.buf_alloc.into(),
126            forward_count: header.fwd_cnt.into(),
127        };
128        let source = header.source();
129        let destination = header.destination();
130
131        let event_type = match op {
132            VirtioVsockOp::Request => {
133                header.check_data_is_empty()?;
134                VsockEventType::ConnectionRequest
135            }
136            VirtioVsockOp::Response => {
137                header.check_data_is_empty()?;
138                VsockEventType::Connected
139            }
140            VirtioVsockOp::CreditUpdate => {
141                header.check_data_is_empty()?;
142                VsockEventType::CreditUpdate
143            }
144            VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
145                header.check_data_is_empty()?;
146                debug!("Disconnected from the peer");
147                let reason = if op == VirtioVsockOp::Rst {
148                    DisconnectReason::Reset
149                } else {
150                    DisconnectReason::Shutdown
151                };
152                VsockEventType::Disconnected { reason }
153            }
154            VirtioVsockOp::Rw => VsockEventType::Received {
155                length: header.len() as usize,
156            },
157            VirtioVsockOp::CreditRequest => {
158                header.check_data_is_empty()?;
159                VsockEventType::CreditRequest
160            }
161            VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
162        };
163
164        Ok(VsockEvent {
165            source,
166            destination,
167            buffer_status,
168            event_type,
169        })
170    }
171}
172
/// The peer's receive buffer state for a connection, as carried in the header of a packet it sent.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The header's `buf_alloc` value: how many bytes of receive buffer space the peer has
    /// allocated for packet bodies.
    pub buffer_allocation: u32,
    /// The header's `fwd_cnt` value: how many bytes of packet bodies the peer has finished
    /// processing.
    pub forward_count: u32,
}
178
/// The reason why a vsock connection was closed.
///
/// `Reset` is reported for a received `VirtioVsockOp::Rst` packet and `Shutdown` for a received
/// `VirtioVsockOp::Shutdown` packet.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer has either closed the connection in response to our shutdown request, or forcibly
    /// closed it of its own accord.
    Reset,
    /// The peer asked to shut down the connection.
    Shutdown,
}
188
/// Details of the type of an event received from a VirtIO socket.
///
/// Each variant corresponds to the operation in the received packet's header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
    /// The peer requests to establish a connection with us (a `Request` packet).
    ConnectionRequest,
    /// The connection was successfully established (a `Response` packet).
    Connected,
    /// The connection was closed (an `Rst` or `Shutdown` packet).
    Disconnected {
        /// The reason for the disconnection.
        reason: DisconnectReason,
    },
    /// Data was received on the connection (an `Rw` packet).
    ///
    /// The body bytes themselves are passed to the `poll` handler alongside the event.
    Received {
        /// The length of the data in bytes, as given by the header's `len` field.
        length: usize,
    },
    /// The peer requests us to send a credit update.
    CreditRequest,
    /// The peer just sent us a credit update with nothing else.
    CreditUpdate,
}
211
/// Low-level driver for a VirtIO socket device.
///
/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
/// using this directly.
///
/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
/// bigger than `size_of::<VirtioVsockHdr>()`; `new` asserts this.
pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
{
    /// The transport through which the device is configured and notified.
    transport: T,
    /// Virtqueue to receive packets.
    ///
    /// NOTE(review): as an `OwningQueue` this presumably allocates and recycles its own
    /// `RX_BUFFER_SIZE`-byte receive buffers — confirm against `OwningQueue`.
    rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
    /// Virtqueue to send packets to the device.
    tx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Virtqueue to receive events from the device.
    event: VirtQueue<H, { QUEUE_SIZE }>,
    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
    guest_cid: u64,
}
231
232impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
233    for VirtIOSocket<H, T, RX_BUFFER_SIZE>
234{
235    fn drop(&mut self) {
236        // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
237        // after they have been freed.
238        self.transport.queue_unset(RX_QUEUE_IDX);
239        self.transport.queue_unset(TX_QUEUE_IDX);
240        self.transport.queue_unset(EVENT_QUEUE_IDX);
241    }
242}
243
244impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
245    /// Create a new VirtIO Vsock driver.
246    pub fn new(mut transport: T) -> Result<Self> {
247        assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
248
249        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
250
251        let config = transport.config_space::<VirtioVsockConfig>()?;
252        debug!("config: {:?}", config);
253        // Safe because config is a valid pointer to the device configuration space.
254        let guest_cid = unsafe {
255            volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
256        };
257        debug!("guest cid: {guest_cid:?}");
258
259        let rx = VirtQueue::new(
260            &mut transport,
261            RX_QUEUE_IDX,
262            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
263            negotiated_features.contains(Feature::RING_EVENT_IDX),
264        )?;
265        let tx = VirtQueue::new(
266            &mut transport,
267            TX_QUEUE_IDX,
268            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
269            negotiated_features.contains(Feature::RING_EVENT_IDX),
270        )?;
271        let event = VirtQueue::new(
272            &mut transport,
273            EVENT_QUEUE_IDX,
274            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
275            negotiated_features.contains(Feature::RING_EVENT_IDX),
276        )?;
277
278        let rx = OwningQueue::new(rx)?;
279
280        transport.finish_init();
281        if rx.should_notify() {
282            transport.notify(RX_QUEUE_IDX);
283        }
284
285        Ok(Self {
286            transport,
287            rx,
288            tx,
289            event,
290            guest_cid,
291        })
292    }
293
294    /// Returns the CID which has been assigned to this guest.
295    pub fn guest_cid(&self) -> u64 {
296        self.guest_cid
297    }
298
299    /// Sends a request to connect to the given destination.
300    ///
301    /// This returns as soon as the request is sent; you should wait until `poll` returns a
302    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
303    /// before sending data.
304    pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
305        let header = VirtioVsockHdr {
306            op: VirtioVsockOp::Request.into(),
307            ..connection_info.new_header(self.guest_cid)
308        };
309        // Sends a header only packet to the TX queue to connect the device to the listening socket
310        // at the given destination.
311        self.send_packet_to_tx_queue(&header, &[])
312    }
313
314    /// Accepts the given connection from a peer.
315    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
316        let header = VirtioVsockHdr {
317            op: VirtioVsockOp::Response.into(),
318            ..connection_info.new_header(self.guest_cid)
319        };
320        self.send_packet_to_tx_queue(&header, &[])
321    }
322
323    /// Requests the peer to send us a credit update for the given connection.
324    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
325        let header = VirtioVsockHdr {
326            op: VirtioVsockOp::CreditRequest.into(),
327            ..connection_info.new_header(self.guest_cid)
328        };
329        self.send_packet_to_tx_queue(&header, &[])
330    }
331
332    /// Sends the buffer to the destination.
333    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
334        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
335
336        let len = buffer.len() as u32;
337        let header = VirtioVsockHdr {
338            op: VirtioVsockOp::Rw.into(),
339            len: len.into(),
340            ..connection_info.new_header(self.guest_cid)
341        };
342        connection_info.tx_cnt += len;
343        self.send_packet_to_tx_queue(&header, buffer)
344    }
345
346    fn check_peer_buffer_is_sufficient(
347        &mut self,
348        connection_info: &mut ConnectionInfo,
349        buffer_len: usize,
350    ) -> Result {
351        if connection_info.peer_free() as usize >= buffer_len {
352            Ok(())
353        } else {
354            // Request an update of the cached peer credit, if we haven't already done so, and tell
355            // the caller to try again later.
356            if !connection_info.has_pending_credit_request {
357                self.request_credit(connection_info)?;
358                connection_info.has_pending_credit_request = true;
359            }
360            Err(SocketError::InsufficientBufferSpaceInPeer.into())
361        }
362    }
363
364    /// Tells the peer how much buffer space we have to receive data.
365    pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
366        let header = VirtioVsockHdr {
367            op: VirtioVsockOp::CreditUpdate.into(),
368            ..connection_info.new_header(self.guest_cid)
369        };
370        self.send_packet_to_tx_queue(&header, &[])
371    }
372
373    /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
374    /// it.
375    pub fn poll(
376        &mut self,
377        handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
378    ) -> Result<Option<VsockEvent>> {
379        self.rx.poll(&mut self.transport, |buffer| {
380            let (header, body) = read_header_and_body(buffer)?;
381            VsockEvent::from_header(&header).and_then(|event| handler(event, body))
382        })
383    }
384
385    /// Requests to shut down the connection cleanly, sending hints about whether we will send or
386    /// receive more data.
387    ///
388    /// This returns as soon as the request is sent; you should wait until `poll` returns a
389    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
390    /// shutdown.
391    pub fn shutdown_with_hints(
392        &mut self,
393        connection_info: &ConnectionInfo,
394        hints: StreamShutdown,
395    ) -> Result {
396        let header = VirtioVsockHdr {
397            op: VirtioVsockOp::Shutdown.into(),
398            flags: hints.into(),
399            ..connection_info.new_header(self.guest_cid)
400        };
401        self.send_packet_to_tx_queue(&header, &[])
402    }
403
404    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
405    /// any more data.
406    ///
407    /// This returns as soon as the request is sent; you should wait until `poll` returns a
408    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
409    /// shutdown.
410    pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
411        self.shutdown_with_hints(
412            connection_info,
413            StreamShutdown::SEND | StreamShutdown::RECEIVE,
414        )
415    }
416
417    /// Forcibly closes the connection without waiting for the peer.
418    pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
419        let header = VirtioVsockHdr {
420            op: VirtioVsockOp::Rst.into(),
421            ..connection_info.new_header(self.guest_cid)
422        };
423        self.send_packet_to_tx_queue(&header, &[])?;
424        Ok(())
425    }
426
427    fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
428        let _len = if buffer.is_empty() {
429            self.tx
430                .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
431        } else {
432            self.tx.add_notify_wait_pop(
433                &[header.as_bytes(), buffer],
434                &mut [],
435                &mut self.transport,
436            )?
437        };
438        Ok(())
439    }
440}
441
442fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
443    // This could fail if the device returns a buffer used length shorter than the header size.
444    let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
445    let body_length = header.len() as usize;
446
447    // This could fail if the device returns an unreasonably long body length.
448    let data_end = size_of::<VirtioVsockHdr>()
449        .checked_add(body_length)
450        .ok_or(SocketError::InvalidNumber)?;
451    // This could fail if the device returns a body length longer than buffer used length it
452    // returned.
453    let data = buffer
454        .get(size_of::<VirtioVsockHdr>()..data_end)
455        .ok_or(SocketError::BufferTooShort)?;
456    Ok((header, data))
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::ptr::NonNull;
    use std::sync::Mutex;

    /// Checks that the driver reads the guest CID from the device configuration space.
    #[test]
    fn config() {
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // One fake queue each for RX, TX and events.
        let queues = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let state = Arc::new(Mutex::new(State {
            queues,
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: Arc::clone(&state),
        };

        let driver =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
        assert_eq!(driver.guest_cid(), 66);
    }
}