virtio_drivers/device/socket/
vsock.rs

1//! Driver for VirtIO socket devices.
2#![deny(unsafe_op_in_unsafe_fn)]
3
4use super::error::SocketError;
5use super::protocol::{
6    Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
7};
8use super::DEFAULT_RX_BUFFER_SIZE;
9use crate::config::read_config;
10use crate::hal::Hal;
11use crate::queue::{owning::OwningQueue, VirtQueue};
12use crate::transport::Transport;
13use crate::Result;
14use core::mem::size_of;
15use log::debug;
16use zerocopy::{FromBytes, IntoBytes};
17
18pub(crate) const RX_QUEUE_IDX: u16 = 0;
19pub(crate) const TX_QUEUE_IDX: u16 = 1;
20const EVENT_QUEUE_IDX: u16 = 2;
21
22pub(crate) const QUEUE_SIZE: usize = 8;
23const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX
24    .union(Feature::RING_INDIRECT_DESC)
25    .union(Feature::VERSION_1);
26
27/// Information about a particular vsock connection.
28#[derive(Clone, Debug, Default, PartialEq, Eq)]
29pub struct ConnectionInfo {
30    /// The address of the peer.
31    pub dst: VsockAddr,
32    /// The local port number associated with the connection.
33    pub src_port: u32,
34    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
35    /// bytes it has allocated for packet bodies.
36    peer_buf_alloc: u32,
37    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
38    /// has finished processing.
39    peer_fwd_cnt: u32,
40    /// The number of bytes of packet bodies which we have sent to the peer.
41    tx_cnt: u32,
42    /// The number of bytes of buffer space we have allocated to receive packet bodies from the
43    /// peer.
44    pub buf_alloc: u32,
45    /// The number of bytes of packet bodies which we have received from the peer and handled.
46    fwd_cnt: u32,
47    /// Whether we have recently requested credit from the peer.
48    ///
49    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
50    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
51    has_pending_credit_request: bool,
52}
53
54impl ConnectionInfo {
55    /// Creates a new `ConnectionInfo` for the given peer address and local port, and default values
56    /// for everything else.
57    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
58        Self {
59            dst: destination,
60            src_port,
61            ..Default::default()
62        }
63    }
64
65    /// Updates this connection info with the peer buffer allocation and forwarded count from the
66    /// given event.
67    pub fn update_for_event(&mut self, event: &VsockEvent) {
68        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
69        self.peer_fwd_cnt = event.buffer_status.forward_count;
70
71        if let VsockEventType::CreditUpdate = event.event_type {
72            self.has_pending_credit_request = false;
73        }
74    }
75
76    /// Increases the forwarded count recorded for this connection by the given number of bytes.
77    ///
78    /// This should be called once received data has been passed to the client, so there is buffer
79    /// space available for more.
80    pub fn done_forwarding(&mut self, length: usize) {
81        self.fwd_cnt += length as u32;
82    }
83
84    /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
85    /// data from us.
86    fn peer_free(&self) -> u32 {
87        self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
88    }
89
90    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
91        VirtioVsockHdr {
92            src_cid: src_cid.into(),
93            dst_cid: self.dst.cid.into(),
94            src_port: self.src_port.into(),
95            dst_port: self.dst.port.into(),
96            buf_alloc: self.buf_alloc.into(),
97            fwd_cnt: self.fwd_cnt.into(),
98            ..Default::default()
99        }
100    }
101}
102
103/// An event received from a VirtIO socket device.
104#[derive(Clone, Debug, Eq, PartialEq)]
105pub struct VsockEvent {
106    /// The source of the event, i.e. the peer who sent it.
107    pub source: VsockAddr,
108    /// The destination of the event, i.e. the CID and port on our side.
109    pub destination: VsockAddr,
110    /// The peer's buffer status for the connection.
111    pub buffer_status: VsockBufferStatus,
112    /// The type of event.
113    pub event_type: VsockEventType,
114}
115
116impl VsockEvent {
117    /// Returns whether the event matches the given connection.
118    pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
119        self.source == connection_info.dst
120            && self.destination.cid == guest_cid
121            && self.destination.port == connection_info.src_port
122    }
123
124    fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
125        let op = header.op()?;
126        let buffer_status = VsockBufferStatus {
127            buffer_allocation: header.buf_alloc.into(),
128            forward_count: header.fwd_cnt.into(),
129        };
130        let source = header.source();
131        let destination = header.destination();
132
133        let event_type = match op {
134            VirtioVsockOp::Request => {
135                header.check_data_is_empty()?;
136                VsockEventType::ConnectionRequest
137            }
138            VirtioVsockOp::Response => {
139                header.check_data_is_empty()?;
140                VsockEventType::Connected
141            }
142            VirtioVsockOp::CreditUpdate => {
143                header.check_data_is_empty()?;
144                VsockEventType::CreditUpdate
145            }
146            VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
147                header.check_data_is_empty()?;
148                debug!("Disconnected from the peer");
149                let reason = if op == VirtioVsockOp::Rst {
150                    DisconnectReason::Reset
151                } else {
152                    DisconnectReason::Shutdown
153                };
154                VsockEventType::Disconnected { reason }
155            }
156            VirtioVsockOp::Rw => VsockEventType::Received {
157                length: header.len() as usize,
158            },
159            VirtioVsockOp::CreditRequest => {
160                header.check_data_is_empty()?;
161                VsockEventType::CreditRequest
162            }
163            VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
164        };
165
166        Ok(VsockEvent {
167            source,
168            destination,
169            buffer_status,
170            event_type,
171        })
172    }
173}
174
/// The receive buffer status which a peer reported in a packet header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VsockBufferStatus {
    /// The `buf_alloc` value from the packet header.
    pub buffer_allocation: u32,
    /// The `fwd_cnt` value from the packet header.
    pub forward_count: u32,
}
180
/// Why a vsock connection was closed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DisconnectReason {
    /// The peer reset the connection, either to acknowledge a shutdown which we requested or
    /// because it forcibly closed the connection of its own accord.
    Reset,
    /// The peer requested that the connection be shut down.
    Shutdown,
}
190
191/// Details of the type of an event received from a VirtIO socket.
192#[derive(Clone, Debug, Eq, PartialEq)]
193pub enum VsockEventType {
194    /// The peer requests to establish a connection with us.
195    ConnectionRequest,
196    /// The connection was successfully established.
197    Connected,
198    /// The connection was closed.
199    Disconnected {
200        /// The reason for the disconnection.
201        reason: DisconnectReason,
202    },
203    /// Data was received on the connection.
204    Received {
205        /// The length of the data in bytes.
206        length: usize,
207    },
208    /// The peer requests us to send a credit update.
209    CreditRequest,
210    /// The peer just sent us a credit update with nothing else.
211    CreditUpdate,
212}
213
214/// Low-level driver for a VirtIO socket device.
215///
216/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
217/// using this directly.
218///
219/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
220/// bigger than `size_of::<VirtioVsockHdr>()`.
221pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
222{
223    transport: T,
224    /// Virtqueue to receive packets.
225    rx: OwningQueue<H, QUEUE_SIZE, RX_BUFFER_SIZE>,
226    tx: VirtQueue<H, { QUEUE_SIZE }>,
227    /// Virtqueue to receive events from the device.
228    event: VirtQueue<H, { QUEUE_SIZE }>,
229    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
230    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
231    guest_cid: u64,
232}
233
234impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
235    for VirtIOSocket<H, T, RX_BUFFER_SIZE>
236{
237    fn drop(&mut self) {
238        // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
239        // after they have been freed.
240        self.transport.queue_unset(RX_QUEUE_IDX);
241        self.transport.queue_unset(TX_QUEUE_IDX);
242        self.transport.queue_unset(EVENT_QUEUE_IDX);
243    }
244}
245
246impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
247    /// Create a new VirtIO Vsock driver.
248    pub fn new(mut transport: T) -> Result<Self> {
249        assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
250
251        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
252
253        let guest_cid = transport.read_consistent(|| {
254            Ok(
255                (read_config!(transport, VirtioVsockConfig, guest_cid_low)? as u64)
256                    | ((read_config!(transport, VirtioVsockConfig, guest_cid_high)? as u64) << 32),
257            )
258        })?;
259        debug!("guest cid: {guest_cid:?}");
260
261        let rx = VirtQueue::new(
262            &mut transport,
263            RX_QUEUE_IDX,
264            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
265            negotiated_features.contains(Feature::RING_EVENT_IDX),
266        )?;
267        let tx = VirtQueue::new(
268            &mut transport,
269            TX_QUEUE_IDX,
270            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
271            negotiated_features.contains(Feature::RING_EVENT_IDX),
272        )?;
273        let event = VirtQueue::new(
274            &mut transport,
275            EVENT_QUEUE_IDX,
276            negotiated_features.contains(Feature::RING_INDIRECT_DESC),
277            negotiated_features.contains(Feature::RING_EVENT_IDX),
278        )?;
279
280        let rx = OwningQueue::new(rx)?;
281
282        transport.finish_init();
283        if rx.should_notify() {
284            transport.notify(RX_QUEUE_IDX);
285        }
286
287        Ok(Self {
288            transport,
289            rx,
290            tx,
291            event,
292            guest_cid,
293        })
294    }
295
296    /// Returns the CID which has been assigned to this guest.
297    pub fn guest_cid(&self) -> u64 {
298        self.guest_cid
299    }
300
301    /// Sends a request to connect to the given destination.
302    ///
303    /// This returns as soon as the request is sent; you should wait until `poll` returns a
304    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
305    /// before sending data.
306    pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
307        let header = VirtioVsockHdr {
308            op: VirtioVsockOp::Request.into(),
309            ..connection_info.new_header(self.guest_cid)
310        };
311        // Sends a header only packet to the TX queue to connect the device to the listening socket
312        // at the given destination.
313        self.send_packet_to_tx_queue(&header, &[])
314    }
315
316    /// Accepts the given connection from a peer.
317    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
318        let header = VirtioVsockHdr {
319            op: VirtioVsockOp::Response.into(),
320            ..connection_info.new_header(self.guest_cid)
321        };
322        self.send_packet_to_tx_queue(&header, &[])
323    }
324
325    /// Requests the peer to send us a credit update for the given connection.
326    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
327        let header = VirtioVsockHdr {
328            op: VirtioVsockOp::CreditRequest.into(),
329            ..connection_info.new_header(self.guest_cid)
330        };
331        self.send_packet_to_tx_queue(&header, &[])
332    }
333
334    /// Sends the buffer to the destination.
335    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
336        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
337
338        let len = buffer.len() as u32;
339        let header = VirtioVsockHdr {
340            op: VirtioVsockOp::Rw.into(),
341            len: len.into(),
342            ..connection_info.new_header(self.guest_cid)
343        };
344        connection_info.tx_cnt += len;
345        self.send_packet_to_tx_queue(&header, buffer)
346    }
347
348    fn check_peer_buffer_is_sufficient(
349        &mut self,
350        connection_info: &mut ConnectionInfo,
351        buffer_len: usize,
352    ) -> Result {
353        if connection_info.peer_free() as usize >= buffer_len {
354            Ok(())
355        } else {
356            // Request an update of the cached peer credit, if we haven't already done so, and tell
357            // the caller to try again later.
358            if !connection_info.has_pending_credit_request {
359                self.request_credit(connection_info)?;
360                connection_info.has_pending_credit_request = true;
361            }
362            Err(SocketError::InsufficientBufferSpaceInPeer.into())
363        }
364    }
365
366    /// Tells the peer how much buffer space we have to receive data.
367    pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
368        let header = VirtioVsockHdr {
369            op: VirtioVsockOp::CreditUpdate.into(),
370            ..connection_info.new_header(self.guest_cid)
371        };
372        self.send_packet_to_tx_queue(&header, &[])
373    }
374
375    /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
376    /// it.
377    pub fn poll(
378        &mut self,
379        handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
380    ) -> Result<Option<VsockEvent>> {
381        self.rx.poll(&mut self.transport, |buffer| {
382            let (header, body) = read_header_and_body(buffer)?;
383            VsockEvent::from_header(&header).and_then(|event| handler(event, body))
384        })
385    }
386
387    /// Requests to shut down the connection cleanly, sending hints about whether we will send or
388    /// receive more data.
389    ///
390    /// This returns as soon as the request is sent; you should wait until `poll` returns a
391    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
392    /// shutdown.
393    pub fn shutdown_with_hints(
394        &mut self,
395        connection_info: &ConnectionInfo,
396        hints: StreamShutdown,
397    ) -> Result {
398        let header = VirtioVsockHdr {
399            op: VirtioVsockOp::Shutdown.into(),
400            flags: hints.into(),
401            ..connection_info.new_header(self.guest_cid)
402        };
403        self.send_packet_to_tx_queue(&header, &[])
404    }
405
406    /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
407    /// any more data.
408    ///
409    /// This returns as soon as the request is sent; you should wait until `poll` returns a
410    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
411    /// shutdown.
412    pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
413        self.shutdown_with_hints(
414            connection_info,
415            StreamShutdown::SEND | StreamShutdown::RECEIVE,
416        )
417    }
418
419    /// Forcibly closes the connection without waiting for the peer.
420    pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
421        let header = VirtioVsockHdr {
422            op: VirtioVsockOp::Rst.into(),
423            ..connection_info.new_header(self.guest_cid)
424        };
425        self.send_packet_to_tx_queue(&header, &[])?;
426        Ok(())
427    }
428
429    fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
430        let _len = if buffer.is_empty() {
431            self.tx
432                .add_notify_wait_pop(&[header.as_bytes()], &mut [], &mut self.transport)?
433        } else {
434            self.tx.add_notify_wait_pop(
435                &[header.as_bytes(), buffer],
436                &mut [],
437                &mut self.transport,
438            )?
439        };
440        Ok(())
441    }
442}
443
444fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
445    // This could fail if the device returns a buffer used length shorter than the header size.
446    let header = VirtioVsockHdr::read_from_prefix(buffer)
447        .map_err(|_| SocketError::BufferTooShort)?
448        .0;
449    let body_length = header.len() as usize;
450
451    // This could fail if the device returns an unreasonably long body length.
452    let data_end = size_of::<VirtioVsockHdr>()
453        .checked_add(body_length)
454        .ok_or(SocketError::InvalidNumber)?;
455    // This could fail if the device returns a body length longer than buffer used length it
456    // returned.
457    let data = buffer
458        .get(size_of::<VirtioVsockHdr>()..data_end)
459        .ok_or(SocketError::BufferTooShort)?;
460    Ok((header, data))
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::ReadOnly,
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
    };
    use alloc::{sync::Arc, vec};
    use std::sync::Mutex;

    /// The driver should report the guest CID which it read from the device's config space.
    #[test]
    fn config() {
        // One queue each for RX, TX and events.
        let queues = vec![
            QueueStatus::default(),
            QueueStatus::default(),
            QueueStatus::default(),
        ];
        let config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State::new(queues, config_space)));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            state,
        };

        let socket =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();

        assert_eq!(socket.guest_cid(), 66);
    }
}