1pub mod bus;
4
5use self::bus::{
6 ConfigurationAccess, DeviceFunction, DeviceFunctionInfo, PCI_CAP_ID_VNDR, PciError, PciRoot,
7};
8use super::{DeviceStatus, DeviceType, Transport};
9use crate::{
10 Error,
11 hal::{Hal, PhysAddr},
12 transport::InterruptStatus,
13};
14use core::{
15 mem::{align_of, size_of},
16 ops::Deref,
17 ptr::NonNull,
18};
19use safe_mmio::{
20 UniqueMmioPointer, field, field_shared,
21 fields::{ReadOnly, ReadPure, ReadPureWrite, WriteOnly},
22};
23use zerocopy::{FromBytes, Immutable, IntoBytes};
24
/// The PCI vendor ID shared by all VirtIO devices.
pub const VIRTIO_VENDOR_ID: u16 = 0x1AF4;

/// Modern (non-transitional) VirtIO PCI device IDs are this base plus the VirtIO device type.
const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;

// Fixed PCI device IDs used by transitional VirtIO devices; see `device_type`.
const TRANSITIONAL_NETWORK: u16 = 0x1000;
const TRANSITIONAL_BLOCK: u16 = 0x1001;
const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
const TRANSITIONAL_CONSOLE: u16 = 0x1003;
const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;

/// Byte offset of the `bar` field within a VirtIO vendor-specific PCI capability.
pub(crate) const CAP_BAR_OFFSET: u8 = 4;
/// Byte offset of the `offset` field (the structure's offset within its BAR).
pub(crate) const CAP_BAR_OFFSET_OFFSET: u8 = 8;
/// Byte offset of the `length` field (the structure's length in bytes).
pub(crate) const CAP_LENGTH_OFFSET: u8 = 12;
/// Byte offset of the `notify_off_multiplier` field, present only in the notify capability.
pub(crate) const CAP_NOTIFY_OFF_MULTIPLIER_OFFSET: u8 = 16;

/// Capability `cfg_type` identifying the common configuration structure.
pub const VIRTIO_PCI_CAP_COMMON_CFG: u8 = 1;
/// Capability `cfg_type` identifying the queue notification region.
pub const VIRTIO_PCI_CAP_NOTIFY_CFG: u8 = 2;
/// Capability `cfg_type` identifying the ISR status register.
pub const VIRTIO_PCI_CAP_ISR_CFG: u8 = 3;
/// Capability `cfg_type` identifying the device-specific configuration space.
pub const VIRTIO_PCI_CAP_DEVICE_CFG: u8 = 4;
56
57pub(crate) fn device_type(pci_device_id: u16) -> Option<DeviceType> {
58 match pci_device_id {
59 TRANSITIONAL_NETWORK => Some(DeviceType::Network),
60 TRANSITIONAL_BLOCK => Some(DeviceType::Block),
61 TRANSITIONAL_MEMORY_BALLOONING => Some(DeviceType::MemoryBalloon),
62 TRANSITIONAL_CONSOLE => Some(DeviceType::Console),
63 TRANSITIONAL_SCSI_HOST => Some(DeviceType::ScsiHost),
64 TRANSITIONAL_ENTROPY_SOURCE => Some(DeviceType::EntropySource),
65 TRANSITIONAL_9P_TRANSPORT => Some(DeviceType::_9P),
66 id if id >= PCI_DEVICE_ID_OFFSET => DeviceType::try_from(id - PCI_DEVICE_ID_OFFSET).ok(),
67 _ => None,
68 }
69}
70
71pub fn virtio_device_type(device_function_info: &DeviceFunctionInfo) -> Option<DeviceType> {
74 if device_function_info.vendor_id == VIRTIO_VENDOR_ID {
75 device_type(device_function_info.device_id)
76 } else {
77 None
78 }
79}
80
/// PCI transport for VirtIO.
///
/// Holds MMIO pointers to the VirtIO structures located through the device's vendor-specific
/// PCI capabilities. Constructed by [`PciTransport::new`].
#[derive(Debug)]
pub struct PciTransport {
    /// Device type derived from the PCI device ID.
    device_type: DeviceType,
    /// PCI location of the VirtIO device function, as passed to `new`.
    device_function: DeviceFunction,
    /// The common configuration structure within some BAR.
    common_cfg: UniqueMmioPointer<'static, CommonCfg>,
    /// The queue notification region within some BAR, indexed in `u16` units by `notify`.
    notify_region: UniqueMmioPointer<'static, [WriteOnly<u16>]>,
    /// Multiplier read from the notify capability. `notify` multiplies a queue's
    /// `queue_notify_off` by this to get a byte offset into `notify_region`.
    notify_off_multiplier: u32,
    /// The ISR status register within some BAR, read by `ack_interrupt`.
    isr_status: UniqueMmioPointer<'static, ReadOnly<u8>>,
    /// The device-specific configuration space within some BAR. `None` if the device has no
    /// `VIRTIO_PCI_CAP_DEVICE_CFG` capability.
    config_space: Option<UniqueMmioPointer<'static, [u32]>>,
}
99
100impl PciTransport {
101 pub fn new<H: Hal, C: ConfigurationAccess>(
106 root: &mut PciRoot<C>,
107 device_function: DeviceFunction,
108 ) -> Result<Self, VirtioPciError> {
109 let device_vendor = root.configuration_access.read_word(device_function, 0);
110 let device_id = (device_vendor >> 16) as u16;
111 let vendor_id = device_vendor as u16;
112 if vendor_id != VIRTIO_VENDOR_ID {
113 return Err(VirtioPciError::InvalidVendorId(vendor_id));
114 }
115 let device_type =
116 device_type(device_id).ok_or(VirtioPciError::InvalidDeviceId(device_id))?;
117
118 let mut common_cfg = None;
120 let mut notify_cfg = None;
121 let mut notify_off_multiplier = 0;
122 let mut isr_cfg = None;
123 let mut device_cfg = None;
124 for capability in root.capabilities(device_function) {
125 if capability.id != PCI_CAP_ID_VNDR {
126 continue;
127 }
128 let cap_len = capability.private_header as u8;
129 let cfg_type = (capability.private_header >> 8) as u8;
130 if cap_len < 16 {
131 continue;
132 }
133 let struct_info = VirtioCapabilityInfo {
134 bar: root
135 .configuration_access
136 .read_word(device_function, capability.offset + CAP_BAR_OFFSET)
137 as u8,
138 offset: root
139 .configuration_access
140 .read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
141 length: root
142 .configuration_access
143 .read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
144 };
145
146 match cfg_type {
147 VIRTIO_PCI_CAP_COMMON_CFG if common_cfg.is_none() => {
148 common_cfg = Some(struct_info);
149 }
150 VIRTIO_PCI_CAP_NOTIFY_CFG if cap_len >= 20 && notify_cfg.is_none() => {
151 notify_cfg = Some(struct_info);
152 notify_off_multiplier = root.configuration_access.read_word(
153 device_function,
154 capability.offset + CAP_NOTIFY_OFF_MULTIPLIER_OFFSET,
155 );
156 }
157 VIRTIO_PCI_CAP_ISR_CFG if isr_cfg.is_none() => {
158 isr_cfg = Some(struct_info);
159 }
160 VIRTIO_PCI_CAP_DEVICE_CFG if device_cfg.is_none() => {
161 device_cfg = Some(struct_info);
162 }
163 _ => {}
164 }
165 }
166
167 let common_cfg = get_bar_region::<H, _, _>(
168 root,
169 device_function,
170 &common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
171 )?;
172 let common_cfg = unsafe { UniqueMmioPointer::new(common_cfg) };
175
176 let notify_cfg = notify_cfg.ok_or(VirtioPciError::MissingNotifyConfig)?;
177 if notify_off_multiplier % 2 != 0 {
178 return Err(VirtioPciError::InvalidNotifyOffMultiplier(
179 notify_off_multiplier,
180 ));
181 }
182 let notify_region = get_bar_region_slice::<H, _, _>(root, device_function, ¬ify_cfg)?;
183 let notify_region = unsafe { UniqueMmioPointer::new(notify_region) };
186
187 let isr_status = get_bar_region::<H, _, _>(
188 root,
189 device_function,
190 &isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
191 )?;
192 let isr_status = unsafe { UniqueMmioPointer::new(isr_status) };
195
196 let config_space = if let Some(device_cfg) = device_cfg {
197 Some(unsafe {
200 UniqueMmioPointer::new(get_bar_region_slice::<H, _, _>(
201 root,
202 device_function,
203 &device_cfg,
204 )?)
205 })
206 } else {
207 None
208 };
209
210 Ok(Self {
211 device_type,
212 device_function,
213 common_cfg,
214 notify_region,
215 notify_off_multiplier,
216 isr_status,
217 config_space,
218 })
219 }
220}
221
222impl Transport for PciTransport {
223 fn device_type(&self) -> DeviceType {
224 self.device_type
225 }
226
227 fn read_device_features(&mut self) -> u64 {
228 field!(self.common_cfg, device_feature_select).write(0);
229 let mut device_features_bits = field_shared!(self.common_cfg, device_feature).read() as u64;
230 field!(self.common_cfg, device_feature_select).write(1);
231 device_features_bits |=
232 (field_shared!(self.common_cfg, device_feature).read() as u64) << 32;
233 device_features_bits
234 }
235
236 fn write_driver_features(&mut self, driver_features: u64) {
237 field!(self.common_cfg, driver_feature_select).write(0);
238 field!(self.common_cfg, driver_feature).write(driver_features as u32);
239 field!(self.common_cfg, driver_feature_select).write(1);
240 field!(self.common_cfg, driver_feature).write((driver_features >> 32) as u32);
241 }
242
243 fn max_queue_size(&mut self, queue: u16) -> u32 {
244 field!(self.common_cfg, queue_select).write(queue);
245 field_shared!(self.common_cfg, queue_size).read().into()
246 }
247
248 fn notify(&mut self, queue: u16) {
249 field!(self.common_cfg, queue_select).write(queue);
250 let queue_notify_off = field_shared!(self.common_cfg, queue_notify_off).read();
252
253 let offset_bytes = usize::from(queue_notify_off) * self.notify_off_multiplier as usize;
254 let index = offset_bytes / size_of::<u16>();
255 self.notify_region.get(index).unwrap().write(queue);
256 }
257
258 fn get_status(&self) -> DeviceStatus {
259 let status = field_shared!(self.common_cfg, device_status).read();
260 DeviceStatus::from_bits_truncate(status.into())
261 }
262
263 fn set_status(&mut self, status: DeviceStatus) {
264 field!(self.common_cfg, device_status).write(status.bits() as u8);
265 }
266
267 fn set_guest_page_size(&mut self, _guest_page_size: u32) {
268 }
270
271 fn requires_legacy_layout(&self) -> bool {
272 false
273 }
274
275 fn queue_set(
276 &mut self,
277 queue: u16,
278 size: u32,
279 descriptors: PhysAddr,
280 driver_area: PhysAddr,
281 device_area: PhysAddr,
282 ) {
283 field!(self.common_cfg, queue_select).write(queue);
284 field!(self.common_cfg, queue_size).write(size as u16);
285 field!(self.common_cfg, queue_desc).write(descriptors);
286 field!(self.common_cfg, queue_driver).write(driver_area);
287 field!(self.common_cfg, queue_device).write(device_area);
288 field!(self.common_cfg, queue_enable).write(1);
289 }
290
291 fn queue_unset(&mut self, _queue: u16) {
292 }
295
296 fn queue_used(&mut self, queue: u16) -> bool {
297 field!(self.common_cfg, queue_select).write(queue);
298 field_shared!(self.common_cfg, queue_enable).read() == 1
299 }
300
301 fn ack_interrupt(&mut self) -> InterruptStatus {
302 let isr_status = self.isr_status.read();
304 InterruptStatus::from_bits_retain(isr_status.into())
305 }
306
307 fn read_config_generation(&self) -> u32 {
308 field_shared!(self.common_cfg, config_generation)
309 .read()
310 .into()
311 }
312
313 fn read_config_space<T: FromBytes + IntoBytes>(&self, offset: usize) -> Result<T, Error> {
314 assert!(
315 align_of::<T>() <= 4,
316 "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
317 align_of::<T>()
318 );
319 assert_eq!(offset % align_of::<T>(), 0);
320
321 let config_space = self
322 .config_space
323 .as_ref()
324 .ok_or(Error::ConfigSpaceMissing)?;
325 if config_space.len() * size_of::<u32>() < offset + size_of::<T>() {
326 Err(Error::ConfigSpaceTooSmall)
327 } else {
328 unsafe {
332 let ptr = config_space.ptr().cast::<T>().byte_add(offset);
333 Ok(config_space
334 .deref()
335 .child(NonNull::new(ptr.cast_mut()).unwrap())
336 .read_unsafe())
337 }
338 }
339 }
340
341 fn write_config_space<T: IntoBytes + Immutable>(
342 &mut self,
343 offset: usize,
344 value: T,
345 ) -> Result<(), Error> {
346 assert!(
347 align_of::<T>() <= 4,
348 "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
349 align_of::<T>()
350 );
351 assert_eq!(offset % align_of::<T>(), 0);
352
353 let config_space = self
354 .config_space
355 .as_mut()
356 .ok_or(Error::ConfigSpaceMissing)?;
357 if config_space.len() * size_of::<u32>() < offset + size_of::<T>() {
358 Err(Error::ConfigSpaceTooSmall)
359 } else {
360 unsafe {
364 let ptr = config_space.ptr_nonnull().cast::<T>().byte_add(offset);
365 config_space.child(ptr).write_unsafe(value);
366 }
367 Ok(())
368 }
369 }
370}
371
// SAFETY: NOTE(review): `PciTransport` holds only plain data and MMIO pointers into the mapped
// BAR regions. This impl assumes that MMIO to those regions is valid from any thread or CPU
// core. Confirm this holds for the `Hal` and platform in use.
unsafe impl Send for PciTransport {}

// SAFETY: In this file's `Transport` impl, the methods that take `&self` (`device_type`,
// `get_status`, `requires_legacy_layout`, `read_config_generation`, `read_config_space`) only
// return plain data or perform MMIO reads. Every MMIO write goes through a `&mut self` method.
// So sharing `&PciTransport` between threads can only produce concurrent reads.
unsafe impl Sync for PciTransport {}
378
379impl Drop for PciTransport {
380 fn drop(&mut self) {
381 self.set_status(DeviceStatus::empty());
383 while self.get_status() != DeviceStatus::empty() {}
384 }
385}
386
/// The VirtIO PCI common configuration structure, accessed via MMIO.
///
/// `#[repr(C)]` fixes the field order and offsets. The layout is intended to match
/// `struct virtio_pci_common_cfg` from the VirtIO 1.x specification, so fields must not be
/// reordered. NOTE(review): the wrapper types encode driver access rights.
#[repr(C)]
pub(crate) struct CommonCfg {
    /// Selects which 32-bit half of the device features `device_feature` shows (0 = low, 1 = high).
    pub device_feature_select: ReadPureWrite<u32>,
    /// The selected 32 bits of the device's feature set.
    pub device_feature: ReadPure<u32>,
    /// Selects which 32-bit half of the driver features `driver_feature` sets (0 = low, 1 = high).
    pub driver_feature_select: ReadPureWrite<u32>,
    /// The selected 32 bits of the driver's accepted features.
    pub driver_feature: ReadPureWrite<u32>,
    /// Not accessed in this file; presumably the MSI-X vector for configuration changes.
    pub msix_config: ReadPureWrite<u16>,
    /// Not accessed in this file; presumably the number of queues the device supports.
    pub num_queues: ReadPure<u16>,
    /// Device status byte, read and written as `DeviceStatus` bits.
    pub device_status: ReadPureWrite<u8>,
    /// Configuration generation counter, returned by `read_config_generation`.
    pub config_generation: ReadPure<u8>,
    /// Selects the queue that all `queue_*` fields below refer to.
    pub queue_select: ReadPureWrite<u16>,
    /// Size of the selected queue: read as its maximum, written to set the size used.
    pub queue_size: ReadPureWrite<u16>,
    /// Not accessed in this file; presumably the MSI-X vector for the selected queue.
    pub queue_msix_vector: ReadPureWrite<u16>,
    /// Written as 1 to enable the selected queue; read back to check whether it is in use.
    pub queue_enable: ReadPureWrite<u16>,
    /// Offset, in units of `notify_off_multiplier` bytes, of the selected queue's notify address.
    pub queue_notify_off: ReadPureWrite<u16>,
    /// Physical address of the selected queue's descriptor table.
    pub queue_desc: ReadPureWrite<u64>,
    /// Physical address of the selected queue's driver area.
    pub queue_driver: ReadPureWrite<u64>,
    /// Physical address of the selected queue's device area.
    pub queue_device: ReadPureWrite<u64>,
}
407
/// Location of a VirtIO structure within a BAR, as read from a vendor-specific PCI capability.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct VirtioCapabilityInfo {
    /// Index of the BAR containing the structure (read at capability offset `CAP_BAR_OFFSET`).
    pub bar: u8,
    /// Byte offset of the structure within the BAR.
    pub offset: u32,
    /// Length of the structure in bytes.
    pub length: u32,
}
418
419fn get_bar_region<H: Hal, T, C: ConfigurationAccess>(
420 root: &mut PciRoot<C>,
421 device_function: DeviceFunction,
422 struct_info: &VirtioCapabilityInfo,
423) -> Result<NonNull<T>, VirtioPciError> {
424 let bar_info = root
425 .bar_info(device_function, struct_info.bar)?
426 .ok_or(VirtioPciError::BarNotAllocated(struct_info.bar))?;
427 let (bar_address, bar_size) = bar_info
428 .memory_address_size()
429 .ok_or(VirtioPciError::UnexpectedIoBar)?;
430 if bar_address == 0 {
431 return Err(VirtioPciError::BarNotAllocated(struct_info.bar));
432 }
433 if u64::from(struct_info.offset + struct_info.length) > bar_size
434 || size_of::<T>() > struct_info.length as usize
435 {
436 return Err(VirtioPciError::BarOffsetOutOfRange);
437 }
438 let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
439 let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
441 if !(vaddr.as_ptr() as usize).is_multiple_of(align_of::<T>()) {
442 return Err(VirtioPciError::Misaligned {
443 address: vaddr.as_ptr() as usize,
444 alignment: align_of::<T>(),
445 });
446 }
447 Ok(vaddr.cast())
448}
449
450fn get_bar_region_slice<H: Hal, T, C: ConfigurationAccess>(
451 root: &mut PciRoot<C>,
452 device_function: DeviceFunction,
453 struct_info: &VirtioCapabilityInfo,
454) -> Result<NonNull<[T]>, VirtioPciError> {
455 let ptr = get_bar_region::<H, T, C>(root, device_function, struct_info)?;
456 Ok(NonNull::slice_from_raw_parts(
457 ptr,
458 struct_info.length as usize / size_of::<T>(),
459 ))
460}
461
462#[derive(Clone, Debug, Eq, Error, PartialEq)]
464pub enum VirtioPciError {
465 #[error("PCI device ID {0:#06x} was not a valid VirtIO device ID.")]
467 InvalidDeviceId(u16),
468 #[error("PCI device vender ID {0:#06x} was not the VirtIO vendor ID {VIRTIO_VENDOR_ID:#06x}.")]
470 InvalidVendorId(u16),
471 #[error("No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found.")]
473 MissingCommonConfig,
474 #[error("No valid `VIRTIO_PCI_CAP_NOTIFY_CFG` capability was found.")]
476 MissingNotifyConfig,
477 #[error(
480 "`VIRTIO_PCI_CAP_NOTIFY_CFG` capability has a `notify_off_multiplier` that is not a multiple of 2: {0}"
481 )]
482 InvalidNotifyOffMultiplier(u32),
483 #[error("No valid `VIRTIO_PCI_CAP_ISR_CFG` capability was found.")]
485 MissingIsrConfig,
486 #[error("Unexpected IO BAR (expected memory BAR).")]
488 UnexpectedIoBar,
489 #[error("Bar {0} not allocated.")]
491 BarNotAllocated(u8),
492 #[error("Capability offset greater than BAR length.")]
494 BarOffsetOutOfRange,
495 #[error("Address {address:#018} was not aligned to a {alignment} byte boundary as expected.")]
497 Misaligned {
498 address: usize,
500 alignment: usize,
502 },
503 #[error(transparent)]
505 Pci(PciError),
506}
507
508impl From<PciError> for VirtioPciError {
509 fn from(error: PciError) -> Self {
510 Self::Pci(error)
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DeviceFunctionInfo` with the given IDs and zeroed class information.
    fn info(vendor_id: u16, device_id: u16) -> DeviceFunctionInfo {
        DeviceFunctionInfo {
            vendor_id,
            device_id,
            class: 0,
            subclass: 0,
            prog_if: 0,
            revision: 0,
            header_type: bus::HeaderType::Standard,
        }
    }

    #[test]
    fn transitional_device_ids() {
        assert_eq!(device_type(0x1000), Some(DeviceType::Network));
        assert_eq!(device_type(0x1001), Some(DeviceType::Block));
        assert_eq!(device_type(0x1002), Some(DeviceType::MemoryBalloon));
        assert_eq!(device_type(0x1003), Some(DeviceType::Console));
        assert_eq!(device_type(0x1004), Some(DeviceType::ScsiHost));
        assert_eq!(device_type(0x1005), Some(DeviceType::EntropySource));
        assert_eq!(device_type(0x1009), Some(DeviceType::_9P));
    }

    #[test]
    fn unassigned_ids_below_modern_offset() {
        // Only the fixed transitional IDs are valid below the modern offset.
        assert_eq!(device_type(0x0000), None);
        assert_eq!(device_type(0x0fff), None);
        assert_eq!(device_type(0x1006), None);
        assert_eq!(device_type(0x1008), None);
        assert_eq!(device_type(0x103f), None);
    }

    #[test]
    fn offset_device_ids() {
        assert_eq!(device_type(0x1040), None);
        assert_eq!(device_type(0x1041), Some(DeviceType::Network));
        assert_eq!(device_type(0x1042), Some(DeviceType::Block));
        assert_eq!(device_type(0x1045), Some(DeviceType::MemoryBalloon));
        assert_eq!(device_type(0x1049), Some(DeviceType::_9P));
        assert_eq!(device_type(0x1058), Some(DeviceType::Memory));
        assert_eq!(device_type(0x1059), Some(DeviceType::Sound));
        assert_eq!(device_type(0x1060), None);
        assert_eq!(device_type(0xffff), None);
    }

    #[test]
    fn virtio_device_type_valid() {
        assert_eq!(
            virtio_device_type(&info(VIRTIO_VENDOR_ID, TRANSITIONAL_BLOCK)),
            Some(DeviceType::Block)
        );
        assert_eq!(
            virtio_device_type(&info(VIRTIO_VENDOR_ID, 0x1042)),
            Some(DeviceType::Block)
        );
    }

    #[test]
    fn virtio_device_type_invalid() {
        // Non-VirtIO vendor ID.
        assert_eq!(virtio_device_type(&info(0x1234, TRANSITIONAL_BLOCK)), None);

        // VirtIO vendor ID but an invalid device ID.
        assert_eq!(virtio_device_type(&info(VIRTIO_VENDOR_ID, 0x1040)), None);
    }
}