virtio_drivers/queue.rs

#![deny(unsafe_op_in_unsafe_fn)]

#[cfg(feature = "alloc")]
pub mod owning;

use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
use crate::transport::Transport;
use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
#[cfg(feature = "alloc")]
use alloc::boxed::Box;
use bitflags::bitflags;
#[cfg(test)]
use core::cmp::min;
use core::convert::TryInto;
use core::hint::spin_loop;
use core::mem::{size_of, take};
#[cfg(test)]
use core::ptr;
use core::ptr::NonNull;
use core::sync::atomic::{fence, AtomicU16, Ordering};
use zerocopy::{AsBytes, FromBytes, FromZeroes};

/// The mechanism for bulk data transport on virtio devices.
///
/// Each device can have zero or more virtqueues.
///
/// * `SIZE`: The size of the queue. This is both the number of descriptors, and the number of slots
///   in the available and used rings. It must be a power of 2 and fit in a [`u16`].
#[derive(Debug)]
pub struct VirtQueue<H: Hal, const SIZE: usize> {
    /// DMA guard
    layout: VirtQueueLayout<H>,
    /// Descriptor table
    ///
    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
    /// trust values read back from it. Use `desc_shadow` instead to keep track of what we wrote to
    /// it.
    desc: NonNull<[Descriptor]>,
    /// Available ring
    ///
    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
    /// trust values read back from it. The only field we need to read currently is `idx`, so we
    /// have `avail_idx` below to use instead.
    avail: NonNull<AvailRing<SIZE>>,
    /// Used ring
    used: NonNull<UsedRing<SIZE>>,

    /// The index of the queue
    queue_idx: u16,
    /// The number of descriptors currently in use.
    num_used: u16,
    /// The head desc index of the free list.
    free_head: u16,
    /// Our trusted copy of `desc` that the device can't access.
    desc_shadow: [Descriptor; SIZE],
    /// Our trusted copy of `avail.idx`.
    avail_idx: u16,
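    /// The index of the next entry that we expect to pop from the used ring.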
    last_used_idx: u16,
    /// Whether the `VIRTIO_F_EVENT_IDX` feature has been negotiated.
    event_idx: bool,
    #[cfg(feature = "alloc")]
    indirect: bool,
    #[cfg(feature = "alloc")]
    indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE],
}

impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
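    // Referenced from `new` below so that the assertion is evaluated (and an invalid `SIZE`
    // rejected) at compile time rather than at run time.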
    const SIZE_OK: () = assert!(SIZE.is_power_of_two() && SIZE <= u16::MAX as usize);

    /// Creates a new VirtQueue.
    ///
    /// * `indirect`: Whether to use indirect descriptors. This should be set if the
    ///   `VIRTIO_F_INDIRECT_DESC` feature has been negotiated with the device.
    /// * `event_idx`: Whether to use the `used_event` and `avail_event` fields for notification
    ///   suppression. This should be set if the `VIRTIO_F_EVENT_IDX` feature has been negotiated
    ///   with the device.
    pub fn new<T: Transport>(
        transport: &mut T,
        idx: u16,
        indirect: bool,
        event_idx: bool,
    ) -> Result<Self> {
        #[allow(clippy::let_unit_value)]
        let _ = Self::SIZE_OK;

        if transport.queue_used(idx) {
            return Err(Error::AlreadyUsed);
        }
        if transport.max_queue_size(idx) < SIZE as u32 {
            return Err(Error::InvalidParam);
        }
        let size = SIZE as u16;

        let layout = if transport.requires_legacy_layout() {
            VirtQueueLayout::allocate_legacy(size)?
        } else {
            VirtQueueLayout::allocate_flexible(size)?
        };

        transport.queue_set(
            idx,
            size.into(),
            layout.descriptors_paddr(),
            layout.driver_area_paddr(),
            layout.device_area_paddr(),
        );

        let desc =
            nonnull_slice_from_raw_parts(layout.descriptors_vaddr().cast::<Descriptor>(), SIZE);
        let avail = layout.avail_vaddr().cast();
        let used = layout.used_vaddr().cast();

        let mut desc_shadow: [Descriptor; SIZE] = FromZeroes::new_zeroed();
        // Link descriptors together.
        for i in 0..(size - 1) {
            desc_shadow[i as usize].next = i + 1;
            // Safe because `desc` is properly aligned, dereferenceable, initialised, and the device
            // won't access the descriptors for the duration of this unsafe block.
            unsafe {
                (*desc.as_ptr())[i as usize].next = i + 1;
            }
        }

        #[cfg(feature = "alloc")]
        const NONE: Option<NonNull<[Descriptor]>> = None;
        Ok(VirtQueue {
            layout,
            desc,
            avail,
            used,
            queue_idx: idx,
            num_used: 0,
            free_head: 0,
            desc_shadow,
            avail_idx: 0,
            last_used_idx: 0,
            event_idx,
            #[cfg(feature = "alloc")]
            indirect,
            #[cfg(feature = "alloc")]
            indirect_lists: [NONE; SIZE],
        })
    }

    /// Adds buffers to the virtqueue, returning a token.
    ///
    /// The buffers must not be empty.
    ///
    /// Ref: linux virtio_ring.c virtqueue_add
    ///
    /// # Safety
    ///
    /// The input and output buffers must remain valid and not be accessed until a call to
    /// `pop_used` with the returned token succeeds.
    pub unsafe fn add<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> Result<u16> {
        if inputs.is_empty() && outputs.is_empty() {
            return Err(Error::InvalidParam);
        }
        let descriptors_needed = inputs.len() + outputs.len();
        // Only consider indirect descriptors if the alloc feature is enabled, as they require
        // allocation.
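        // An indirect chain occupies only a single slot in the descriptor table, so in that case
        // one free descriptor is enough, as long as the chain itself fits within the queue size.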
        #[cfg(feature = "alloc")]
        if self.num_used as usize + 1 > SIZE
            || descriptors_needed > SIZE
            || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE)
        {
            return Err(Error::QueueFull);
        }
        #[cfg(not(feature = "alloc"))]
        if self.num_used as usize + descriptors_needed > SIZE {
            return Err(Error::QueueFull);
        }

        #[cfg(feature = "alloc")]
        let head = if self.indirect && descriptors_needed > 1 {
            self.add_indirect(inputs, outputs)
        } else {
            self.add_direct(inputs, outputs)
        };
        #[cfg(not(feature = "alloc"))]
        let head = self.add_direct(inputs, outputs);

        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
        // Safe because self.avail is properly aligned, dereferenceable and initialised.
        unsafe {
            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
        }

        // Write barrier so that device sees changes to descriptor table and available ring before
        // change to available index.
        fence(Ordering::SeqCst);

        // increase head of avail ring
        self.avail_idx = self.avail_idx.wrapping_add(1);
        // Safe because self.avail is properly aligned, dereferenceable and initialised.
        unsafe {
            (*self.avail.as_ptr())
                .idx
                .store(self.avail_idx, Ordering::Release);
        }

        Ok(head)
    }

    fn add_direct<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> u16 {
        // allocate descriptors from free list
        let head = self.free_head;
        let mut last = self.free_head;

        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
            assert_ne!(buffer.len(), 0);

            // Write to desc_shadow then copy.
            let desc = &mut self.desc_shadow[usize::from(self.free_head)];
            // Safe because our caller promises that the buffers live at least until `pop_used`
            // returns them.
            unsafe {
                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
            }
            last = self.free_head;
            self.free_head = desc.next;

            self.write_desc(last);
        }

        // set last_elem.next = NULL
        self.desc_shadow[usize::from(last)]
            .flags
            .remove(DescFlags::NEXT);
        self.write_desc(last);

        self.num_used += (inputs.len() + outputs.len()) as u16;

        head
    }

    #[cfg(feature = "alloc")]
    fn add_indirect<'a, 'b>(
        &mut self,
        inputs: &'a [&'b [u8]],
        outputs: &'a mut [&'b mut [u8]],
    ) -> u16 {
        let head = self.free_head;

        // Allocate and fill in indirect descriptor list.
        let mut indirect_list = Descriptor::new_box_slice_zeroed(inputs.len() + outputs.len());
        for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
            let desc = &mut indirect_list[i];
            // Safe because our caller promises that the buffers live at least until `pop_used`
            // returns them.
            unsafe {
                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
            }
            desc.next = (i + 1) as u16;
        }
        indirect_list
            .last_mut()
            .unwrap()
            .flags
            .remove(DescFlags::NEXT);

        // Need to store pointer to indirect_list too, because direct_desc.set_buf will only store
        // the physical DMA address which might be different.
        assert!(self.indirect_lists[usize::from(head)].is_none());
        self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into());

        // Write a descriptor pointing to indirect descriptor list. We use Box::leak to prevent the
        // indirect list from being freed when this function returns; recycle_descriptors is instead
        // responsible for freeing the memory after the buffer chain is popped.
        let direct_desc = &mut self.desc_shadow[usize::from(head)];
        self.free_head = direct_desc.next;
        unsafe {
            direct_desc.set_buf::<H>(
                Box::leak(indirect_list).as_bytes().into(),
                BufferDirection::DriverToDevice,
                DescFlags::INDIRECT,
            );
        }
        self.write_desc(head);
        self.num_used += 1;

        head
    }

    /// Adds the given buffers to the virtqueue, notifies the device, blocks until the device uses
    /// them, then pops them.
    ///
    /// This assumes that the device isn't processing any other buffers at the same time.
    ///
    /// The buffers must not be empty.
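    ///
    /// # Example
    ///
    /// A minimal sketch (not compiled as a doctest) of sending a request and reading the response,
    /// assuming the `FakeHal` and fake MMIO transport used by this crate's own tests:
    ///
    /// ```ignore
    /// let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
    /// let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
    /// let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
    ///
    /// let request = [1u8, 2];
    /// let mut response = [0u8; 4];
    /// // Blocks (spinning) until the device has used the buffers, then returns the used length.
    /// let used_len = queue
    ///     .add_notify_wait_pop(&[&request], &mut [&mut response], &mut transport)
    ///     .unwrap();
    /// ```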
    pub fn add_notify_wait_pop<'a>(
        &mut self,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
        transport: &mut impl Transport,
    ) -> Result<u32> {
        // Safe because we don't return until the same token has been popped, so the buffers remain
        // valid and are not otherwise accessed until then.
        let token = unsafe { self.add(inputs, outputs) }?;

        // Notify the queue.
        if self.should_notify() {
            transport.notify(self.queue_idx);
        }

        // Wait until there is at least one element in the used ring.
        while !self.can_pop() {
            spin_loop();
        }

        // Safe because these are the same buffers as we passed to `add` above and they are still
        // valid.
        unsafe { self.pop_used(token, inputs, outputs) }
    }

    /// Advise the device whether used buffer notifications are needed.
    ///
    /// See Virtio v1.1 2.6.7 Used Buffer Notification Suppression
    pub fn set_dev_notify(&mut self, enable: bool) {
        let avail_ring_flags = if enable { 0x0000 } else { 0x0001 };
        if !self.event_idx {
            // Safe because self.avail points to a valid, aligned, initialised, dereferenceable, readable
            // instance of AvailRing.
            unsafe {
                (*self.avail.as_ptr())
                    .flags
                    .store(avail_ring_flags, Ordering::Release)
            }
        }
    }

    /// Returns whether the driver should notify the device after adding a new buffer to the
    /// virtqueue.
    ///
    /// This will be false if the device has suppressed notifications.
    pub fn should_notify(&self) -> bool {
        if self.event_idx {
            // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
            // instance of UsedRing.
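            // `avail_event` is the available ring index after which the device asked to be
            // notified, so notify once our available index has moved past it.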
            let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
            self.avail_idx >= avail_event.wrapping_add(1)
        } else {
            // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
            // instance of UsedRing.
            unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
        }
    }

    /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by
    /// the device.
    fn write_desc(&mut self, index: u16) {
        let index = usize::from(index);
        // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
        // else reads or writes the descriptor during this block.
        unsafe {
            (*self.desc.as_ptr())[index] = self.desc_shadow[index].clone();
        }
    }

    /// Returns whether there is a used element that can be popped.
    pub fn can_pop(&self) -> bool {
        // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
        // instance of UsedRing.
        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
    }

    /// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
    /// `None` if the used ring is empty.
    pub fn peek_used(&self) -> Option<u16> {
        if self.can_pop() {
            let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
            // Safe because self.used points to a valid, aligned, initialised, dereferenceable,
            // readable instance of UsedRing.
            Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
        } else {
            None
        }
    }

    /// Returns the number of free descriptors.
    pub fn available_desc(&self) -> usize {
        #[cfg(feature = "alloc")]
        if self.indirect {
            return if usize::from(self.num_used) == SIZE {
                0
            } else {
                SIZE
            };
        }

        SIZE - usize::from(self.num_used)
    }

    /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
    /// list. Unsharing may involve copying data back to the original buffers, so they must be
    /// passed in too.
    ///
    /// This will push all linked descriptors onto the front of the free list.
    ///
    /// # Safety
    ///
    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
    /// queue by `add`.
    unsafe fn recycle_descriptors<'a>(
        &mut self,
        head: u16,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
    ) {
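        // The recycled chain goes onto the front of the free list: `head` becomes the new free
        // head, and the last descriptor recycled below is linked back to the old free head.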
        let original_free_head = self.free_head;
        self.free_head = head;

        let head_desc = &mut self.desc_shadow[usize::from(head)];
        if head_desc.flags.contains(DescFlags::INDIRECT) {
            #[cfg(feature = "alloc")]
            {
                // Find the indirect descriptor list, unshare it and move its descriptor to the free
                // list.
                let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap();
                // SAFETY: We allocated the indirect list in `add_indirect`, and the device has
                // finished accessing it by this point.
                let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) };
                let paddr = head_desc.addr;
                head_desc.unset_buf();
                self.num_used -= 1;
                head_desc.next = original_free_head;

                unsafe {
                    H::unshare(
                        paddr as usize,
                        indirect_list.as_bytes_mut().into(),
                        BufferDirection::DriverToDevice,
                    );
                }

                // Unshare the buffers in the indirect descriptor list, and free it.
                assert_eq!(indirect_list.len(), inputs.len() + outputs.len());
                for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
                    assert_ne!(buffer.len(), 0);

                    // SAFETY: The caller ensures that the buffer is valid and matches the
                    // descriptor from which we got `paddr`.
                    unsafe {
                        // Unshare the buffer (and perhaps copy its contents back to the original
                        // buffer).
                        H::unshare(indirect_list[i].addr as usize, buffer, direction);
                    }
                }
                drop(indirect_list);
            }
        } else {
            let mut next = Some(head);

            for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
                assert_ne!(buffer.len(), 0);

                let desc_index = next.expect("Descriptor chain was shorter than expected.");
                let desc = &mut self.desc_shadow[usize::from(desc_index)];

                let paddr = desc.addr;
                desc.unset_buf();
                self.num_used -= 1;
                next = desc.next();
                if next.is_none() {
                    desc.next = original_free_head;
                }

                self.write_desc(desc_index);

                // SAFETY: The caller ensures that the buffer is valid and matches the descriptor
                // from which we got `paddr`.
                unsafe {
                    // Unshare the buffer (and perhaps copy its contents back to the original buffer).
                    H::unshare(paddr as usize, buffer, direction);
                }
            }

            if next.is_some() {
                panic!("Descriptor chain was longer than expected.");
            }
        }
    }

    /// If the given token is next on the device used queue, pops it and returns the total buffer
    /// length which was used (written) by the device.
    ///
    /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
    ///
    /// # Safety
    ///
    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
    /// queue by `add` when it returned the token being passed in here.
    pub unsafe fn pop_used<'a>(
        &mut self,
        token: u16,
        inputs: &'a [&'a [u8]],
        outputs: &'a mut [&'a mut [u8]],
    ) -> Result<u32> {
        if !self.can_pop() {
            return Err(Error::NotReady);
        }

        // Get the index of the start of the descriptor chain for the next element in the used ring.
        let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
        let index;
        let len;
        // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
        // instance of UsedRing.
        unsafe {
            index = (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16;
            len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
        }

        if index != token {
            // The device used a different descriptor chain to the one we were expecting.
            return Err(Error::WrongToken);
        }

        // Safe because the caller ensures the buffers are valid and match the descriptor.
        unsafe {
            self.recycle_descriptors(index, inputs, outputs);
        }
        self.last_used_idx = self.last_used_idx.wrapping_add(1);

        if self.event_idx {
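            // Safe because self.avail is properly aligned, dereferenceable and initialised.
            // Writing `used_event` tells the device that it need not send another used buffer
            // notification until it has used the entry at this index.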
            unsafe {
                (*self.avail.as_ptr())
                    .used_event
                    .store(self.last_used_idx, Ordering::Release);
            }
        }

        Ok(len)
    }
}

// SAFETY: None of the virt queue resources are tied to a particular thread.
unsafe impl<H: Hal, const SIZE: usize> Send for VirtQueue<H, SIZE> {}

// SAFETY: A `&VirtQueue` only allows reading from the various pointers it contains, so there is no
// data race.
unsafe impl<H: Hal, const SIZE: usize> Sync for VirtQueue<H, SIZE> {}

/// The inner layout of a VirtQueue.
///
/// Ref: 2.6 Split Virtqueues
#[derive(Debug)]
enum VirtQueueLayout<H: Hal> {
    Legacy {
        dma: Dma<H>,
        avail_offset: usize,
        used_offset: usize,
    },
    Modern {
        /// The region used for the descriptor area and driver area.
        driver_to_device_dma: Dma<H>,
        /// The region used for the device area.
        device_to_driver_dma: Dma<H>,
        /// The offset from the start of the `driver_to_device_dma` region to the driver area
        /// (available ring).
        avail_offset: usize,
    },
}

impl<H: Hal> VirtQueueLayout<H> {
    /// Allocates a single DMA region containing all parts of the virtqueue, following the layout
    /// required by legacy interfaces.
    ///
    /// Ref: 2.6.2 Legacy Interfaces: A Note on Virtqueue Layout
    fn allocate_legacy(queue_size: u16) -> Result<Self> {
        let (desc, avail, used) = queue_part_sizes(queue_size);
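        // The legacy layout requires the used ring to start at the next page boundary after the
        // descriptor table and available ring, hence the rounding up.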
        let size = align_up(desc + avail) + align_up(used);
        // Allocate contiguous pages.
        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
        Ok(Self::Legacy {
            dma,
            avail_offset: desc,
            used_offset: align_up(desc + avail),
        })
    }

    /// Allocates separate DMA regions for the different parts of the virtqueue, as supported by
    /// non-legacy interfaces.
    ///
    /// This is preferred over `allocate_legacy` where possible as it reduces memory fragmentation
    /// and allows the HAL to know which DMA regions are used in which direction.
    fn allocate_flexible(queue_size: u16) -> Result<Self> {
        let (desc, avail, used) = queue_part_sizes(queue_size);
        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
        Ok(Self::Modern {
            driver_to_device_dma,
            device_to_driver_dma,
            avail_offset: desc,
        })
    }

    /// Returns the physical address of the descriptor area.
    fn descriptors_paddr(&self) -> PhysAddr {
        match self {
            Self::Legacy { dma, .. } => dma.paddr(),
            Self::Modern {
                driver_to_device_dma,
                ..
            } => driver_to_device_dma.paddr(),
        }
    }

    /// Returns a pointer to the descriptor table (in the descriptor area).
    fn descriptors_vaddr(&self) -> NonNull<u8> {
        match self {
            Self::Legacy { dma, .. } => dma.vaddr(0),
            Self::Modern {
                driver_to_device_dma,
                ..
            } => driver_to_device_dma.vaddr(0),
        }
    }

    /// Returns the physical address of the driver area.
    fn driver_area_paddr(&self) -> PhysAddr {
        match self {
            Self::Legacy {
                dma, avail_offset, ..
            } => dma.paddr() + avail_offset,
            Self::Modern {
                driver_to_device_dma,
                avail_offset,
                ..
            } => driver_to_device_dma.paddr() + avail_offset,
        }
    }

    /// Returns a pointer to the available ring (in the driver area).
    fn avail_vaddr(&self) -> NonNull<u8> {
        match self {
            Self::Legacy {
                dma, avail_offset, ..
            } => dma.vaddr(*avail_offset),
            Self::Modern {
                driver_to_device_dma,
                avail_offset,
                ..
            } => driver_to_device_dma.vaddr(*avail_offset),
        }
    }

    /// Returns the physical address of the device area.
    fn device_area_paddr(&self) -> PhysAddr {
        match self {
            Self::Legacy {
                used_offset, dma, ..
            } => dma.paddr() + used_offset,
            Self::Modern {
                device_to_driver_dma,
                ..
            } => device_to_driver_dma.paddr(),
        }
    }

    /// Returns a pointer to the used ring (in the device area).
    fn used_vaddr(&self) -> NonNull<u8> {
        match self {
            Self::Legacy {
                dma, used_offset, ..
            } => dma.vaddr(*used_offset),
            Self::Modern {
                device_to_driver_dma,
                ..
            } => device_to_driver_dma.vaddr(0),
        }
    }
}

/// Returns the size in bytes of the descriptor table, available ring and used ring for a given
/// queue size.
///
/// Ref: 2.6 Split Virtqueues
fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
    assert!(
        queue_size.is_power_of_two(),
        "queue size should be a power of 2"
    );
    let queue_size = queue_size as usize;
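    // Per 2.6 Split Virtqueues: the descriptor table holds `queue_size` 16-byte descriptors; the
    // available ring is `flags`, `idx`, `queue_size` ring entries and `used_event` (all 16-bit);
    // the used ring is `flags`, `idx` and `avail_event` (16-bit each) plus `queue_size` 8-byte
    // used elements.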
    let desc = size_of::<Descriptor>() * queue_size;
    let avail = size_of::<u16>() * (3 + queue_size);
    let used = size_of::<u16>() * 3 + size_of::<UsedElem>() * queue_size;
    (desc, avail, used)
}

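/// A descriptor in the descriptor table of a split virtqueue (`struct virtq_desc` in
/// 2.6 Split Virtqueues).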
#[repr(C, align(16))]
#[derive(AsBytes, Clone, Debug, FromBytes, FromZeroes)]
pub(crate) struct Descriptor {
    addr: u64,
    len: u32,
    flags: DescFlags,
    next: u16,
}

impl Descriptor {
    /// Sets the buffer address, length and flags, and shares it with the device.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
    unsafe fn set_buf<H: Hal>(
        &mut self,
        buf: NonNull<[u8]>,
        direction: BufferDirection,
        extra_flags: DescFlags,
    ) {
        // Safe because our caller promises that the buffer is valid.
        unsafe {
            self.addr = H::share(buf, direction) as u64;
        }
        self.len = buf.len().try_into().unwrap();
        self.flags = extra_flags
            | match direction {
                BufferDirection::DeviceToDriver => DescFlags::WRITE,
                BufferDirection::DriverToDevice => DescFlags::empty(),
                BufferDirection::Both => {
                    panic!("Buffer passed to device should never use BufferDirection::Both.")
                }
            };
    }

    /// Sets the buffer address and length to 0.
    ///
    /// This must only be called once the device has finished using the descriptor.
    fn unset_buf(&mut self) {
        self.addr = 0;
        self.len = 0;
    }

    /// Returns the index of the next descriptor in the chain if the `NEXT` flag is set, or `None`
    /// if it is not (and thus this descriptor is the end of the chain).
    fn next(&self) -> Option<u16> {
        if self.flags.contains(DescFlags::NEXT) {
            Some(self.next)
        } else {
            None
        }
    }
}

/// Descriptor flags
#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, FromZeroes, PartialEq)]
#[repr(transparent)]
struct DescFlags(u16);

bitflags! {
    impl DescFlags: u16 {
        const NEXT = 1;
        const WRITE = 2;
        const INDIRECT = 4;
    }
}

/// The driver uses the available ring to offer buffers to the device:
/// each ring entry refers to the head of a descriptor chain.
/// It is only written by the driver and read by the device.
#[repr(C)]
#[derive(Debug)]
struct AvailRing<const SIZE: usize> {
    flags: AtomicU16,
    /// A driver MUST NOT decrement the idx.
    idx: AtomicU16,
    ring: [u16; SIZE],
    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
    used_event: AtomicU16,
}

/// The used ring is where the device returns buffers once it is done with them:
/// it is only written to by the device, and read by the driver.
#[repr(C)]
#[derive(Debug)]
struct UsedRing<const SIZE: usize> {
    flags: AtomicU16,
    idx: AtomicU16,
    ring: [UsedElem; SIZE],
    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
    avail_event: AtomicU16,
}

#[repr(C)]
#[derive(Debug)]
struct UsedElem {
    id: u32,
    len: u32,
}

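/// An iterator over a set of input (device-readable) buffers followed by a set of output
/// (device-writable) buffers, yielding each buffer together with its [`BufferDirection`].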
struct InputOutputIter<'a, 'b> {
    inputs: &'a [&'b [u8]],
    outputs: &'a mut [&'b mut [u8]],
}

impl<'a, 'b> InputOutputIter<'a, 'b> {
    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
        Self { inputs, outputs }
    }
}

impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> {
    type Item = (NonNull<[u8]>, BufferDirection);

    fn next(&mut self) -> Option<Self::Item> {
        if let Some(input) = take_first(&mut self.inputs) {
            Some(((*input).into(), BufferDirection::DriverToDevice))
        } else {
            let output = take_first_mut(&mut self.outputs)?;
            Some(((*output).into(), BufferDirection::DeviceToDriver))
        }
    }
}

// TODO: Use `slice::take_first` once it is stable
// (https://github.com/rust-lang/rust/issues/62280).
fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
    let (first, rem) = slice.split_first()?;
    *slice = rem;
    Some(first)
}

// TODO: Use `slice::take_first_mut` once it is stable
// (https://github.com/rust-lang/rust/issues/62280).
fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
    let (first, rem) = take(slice).split_first_mut()?;
    *slice = rem;
    Some(first)
}

/// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
///
/// The fake device always uses descriptors in order.
///
/// Returns true if a descriptor chain was available and processed, or false if no descriptors were
/// available.
#[cfg(test)]
pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
    descriptors: *const [Descriptor; QUEUE_SIZE],
    queue_driver_area: *const u8,
    queue_device_area: *mut u8,
    handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
) -> bool {
    use core::{ops::Deref, slice};

    let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
    let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>;

    // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
    // nothing else accesses them during this block.
    unsafe {
        // Make sure there is actually at least one descriptor available to read from.
        if (*available_ring).idx.load(Ordering::Acquire) == (*used_ring).idx.load(Ordering::Acquire)
        {
            return false;
        }
        // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
        // `used_ring.idx` marks the next descriptor we should take from the available ring.
        let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];

        let input_length;
        let output;
        if descriptor.flags.contains(DescFlags::INDIRECT) {
            // The descriptor shouldn't have any other flags if it is indirect.
            assert_eq!(descriptor.flags, DescFlags::INDIRECT);

            // Loop through all input descriptors in the indirect descriptor list, reading data from
            // them.
            let indirect_descriptor_list: &[Descriptor] = zerocopy::Ref::new_slice(
                slice::from_raw_parts(descriptor.addr as *const u8, descriptor.len as usize),
            )
            .unwrap()
            .into_slice();
            let mut input = Vec::new();
            let mut indirect_descriptor_index = 0;
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                if indirect_descriptor.flags.contains(DescFlags::WRITE) {
                    break;
                }

                input.extend_from_slice(slice::from_raw_parts(
                    indirect_descriptor.addr as *const u8,
                    indirect_descriptor.len as usize,
                ));

                indirect_descriptor_index += 1;
            }
            input_length = input.len();

            // Let the test handle the request.
            output = handler(input);

            // Write the response to the remaining descriptors.
            let mut remaining_output = output.deref();
            while indirect_descriptor_index < indirect_descriptor_list.len() {
                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
                assert!(indirect_descriptor.flags.contains(DescFlags::WRITE));

                let length_to_write = min(remaining_output.len(), indirect_descriptor.len as usize);
                ptr::copy(
                    remaining_output.as_ptr(),
                    indirect_descriptor.addr as *mut u8,
                    length_to_write,
                );
                remaining_output = &remaining_output[length_to_write..];

                indirect_descriptor_index += 1;
            }
            assert_eq!(remaining_output.len(), 0);
        } else {
            // Loop through all input descriptors in the chain, reading data from them.
            let mut input = Vec::new();
            while !descriptor.flags.contains(DescFlags::WRITE) {
                input.extend_from_slice(slice::from_raw_parts(
                    descriptor.addr as *const u8,
                    descriptor.len as usize,
                ));

                if let Some(next) = descriptor.next() {
                    descriptor = &(*descriptors)[next as usize];
                } else {
                    break;
                }
            }
            input_length = input.len();

            // Let the test handle the request.
            output = handler(input);

            // Write the response to the remaining descriptors.
            let mut remaining_output = output.deref();
            if descriptor.flags.contains(DescFlags::WRITE) {
                loop {
                    assert!(descriptor.flags.contains(DescFlags::WRITE));

                    let length_to_write = min(remaining_output.len(), descriptor.len as usize);
                    ptr::copy(
                        remaining_output.as_ptr(),
                        descriptor.addr as *mut u8,
                        length_to_write,
                    );
                    remaining_output = &remaining_output[length_to_write..];

                    if let Some(next) = descriptor.next() {
                        descriptor = &(*descriptors)[next as usize];
                    } else {
                        break;
                    }
                }
            }
            assert_eq!(remaining_output.len(), 0);
        }

        // Mark the buffer as used.
        (*used_ring).ring[next_slot as usize].id = head_descriptor_index.into();
        (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
        (*used_ring).idx.fetch_add(1, Ordering::AcqRel);

        true
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::common::Feature,
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION},
            DeviceType,
        },
    };
    use core::ptr::NonNull;
    use std::sync::{Arc, Mutex};

    #[test]
    fn queue_too_big() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        assert_eq!(
            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
            Error::InvalidParam
        );
    }

    #[test]
    fn queue_already_used() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
        assert_eq!(
            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
            Error::AlreadyUsed
        );
    }

    #[test]
    fn add_empty() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
        assert_eq!(
            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
            Error::InvalidParam
        );
    }

    #[test]
    fn add_too_many() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
        assert_eq!(queue.available_desc(), 4);
        assert_eq!(
            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
            Error::QueueFull
        );
    }

    #[test]
    fn add_buffers() {
        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
        assert_eq!(queue.available_desc(), 4);

        // Add a buffer chain consisting of two device-readable parts followed by two
        // device-writable parts.
        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();

        assert_eq!(queue.available_desc(), 0);
        assert!(!queue.can_pop());

        // Safe because the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            let first_descriptor_index = (*queue.avail.as_ptr()).ring[0];
            assert_eq!(first_descriptor_index, token);
            assert_eq!(
                (*queue.desc.as_ptr())[first_descriptor_index as usize].len,
                2
            );
            assert_eq!(
                (*queue.desc.as_ptr())[first_descriptor_index as usize].flags,
                DescFlags::NEXT
            );
            let second_descriptor_index =
                (*queue.desc.as_ptr())[first_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[second_descriptor_index as usize].len,
                1
            );
            assert_eq!(
                (*queue.desc.as_ptr())[second_descriptor_index as usize].flags,
                DescFlags::NEXT
            );
            let third_descriptor_index =
                (*queue.desc.as_ptr())[second_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[third_descriptor_index as usize].len,
                2
            );
            assert_eq!(
                (*queue.desc.as_ptr())[third_descriptor_index as usize].flags,
                DescFlags::NEXT | DescFlags::WRITE
            );
            let fourth_descriptor_index =
                (*queue.desc.as_ptr())[third_descriptor_index as usize].next;
            assert_eq!(
                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].len,
                1
            );
            assert_eq!(
                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].flags,
                DescFlags::WRITE
            );
        }
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn add_buffers_indirect() {
        use core::ptr::slice_from_raw_parts;

        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
        assert_eq!(queue.available_desc(), 4);

        // Add a buffer chain consisting of two device-readable parts followed by two
        // device-writable parts.
        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();

        assert_eq!(queue.available_desc(), 4);
        assert!(!queue.can_pop());

        // Safe because the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0];
            assert_eq!(indirect_descriptor_index, token);
            assert_eq!(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize,
                4 * size_of::<Descriptor>()
            );
            assert_eq!(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags,
                DescFlags::INDIRECT
            );

            let indirect_descriptors = slice_from_raw_parts(
                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr
                    as *const Descriptor,
                4,
            );
            assert_eq!((*indirect_descriptors)[0].len, 2);
            assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT);
            assert_eq!((*indirect_descriptors)[0].next, 1);
            assert_eq!((*indirect_descriptors)[1].len, 1);
            assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT);
            assert_eq!((*indirect_descriptors)[1].next, 2);
            assert_eq!((*indirect_descriptors)[2].len, 2);
            assert_eq!(
                (*indirect_descriptors)[2].flags,
                DescFlags::NEXT | DescFlags::WRITE
            );
            assert_eq!((*indirect_descriptors)[2].next, 3);
            assert_eq!((*indirect_descriptors)[3].len, 1);
            assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
        }
    }

    /// Tests that the queue advises the device that notifications are needed.
    #[test]
    fn set_dev_notify() {
        let mut config_space = ();
        let state = Arc::new(Mutex::new(State {
            queues: vec![QueueStatus::default()],
            ..Default::default()
        }));
        let mut transport = FakeTransport {
            device_type: DeviceType::Block,
            max_queue_size: 4,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();

        // Check that the avail ring's flag is zero by default.
        assert_eq!(
            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
            0x0
        );

        queue.set_dev_notify(false);

        // Check that the avail ring's flag is 1 after `set_dev_notify(false)`.
        assert_eq!(
            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
            0x1
        );

        queue.set_dev_notify(true);

        // Check that the avail ring's flag is 0 after `set_dev_notify(true)`.
        assert_eq!(
            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
            0x0
        );
    }

    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
    /// notifications.
    #[test]
    fn add_notify() {
        let mut config_space = ();
        let state = Arc::new(Mutex::new(State {
            queues: vec![QueueStatus::default()],
            ..Default::default()
        }));
        let mut transport = FakeTransport {
            device_type: DeviceType::Block,
            max_queue_size: 4,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();

        // Add a buffer chain with a single device-readable part.
        unsafe { queue.add(&[&[42]], &mut []) }.unwrap();

        // Check that the transport would be notified.
        assert_eq!(queue.should_notify(), true);

        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            // Suppress notifications.
            (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
        }

        // Check that the transport would not be notified.
        assert_eq!(queue.should_notify(), false);
    }

    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
    /// notifications with the `avail_event` index.
    #[test]
    fn add_notify_event_idx() {
        let mut config_space = ();
        let state = Arc::new(Mutex::new(State {
            queues: vec![QueueStatus::default()],
            ..Default::default()
        }));
        let mut transport = FakeTransport {
            device_type: DeviceType::Block,
            max_queue_size: 4,
            device_features: Feature::RING_EVENT_IDX.bits(),
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap();

        // Add a buffer chain with a single device-readable part.
        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0);

        // Check that the transport would be notified.
        assert_eq!(queue.should_notify(), true);

        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
        // initialised, and nothing else is accessing them at the same time.
        unsafe {
            // Suppress notifications.
            (*queue.used.as_ptr())
                .avail_event
                .store(1, Ordering::Release);
        }

        // Check that the transport would not be notified.
        assert_eq!(queue.should_notify(), false);

        // Add another buffer chain.
        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1);

        // Check that the transport should be notified again now.
        assert_eq!(queue.should_notify(), true);
    }
}