1pub mod bus;
4
5use self::bus::{
6 ConfigurationAccess, DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR,
7};
8use super::{DeviceStatus, DeviceType, Transport};
9use crate::{
10 hal::{Hal, PhysAddr},
11 transport::InterruptStatus,
12 Error,
13};
14use core::{
15 mem::{align_of, size_of},
16 ops::Deref,
17 ptr::NonNull,
18};
19use safe_mmio::{
20 field, field_shared,
21 fields::{ReadOnly, ReadPure, ReadPureWrite, WriteOnly},
22 UniqueMmioPointer,
23};
24use zerocopy::{FromBytes, Immutable, IntoBytes};
25
/// PCI vendor ID that a device function must report to be treated as a VirtIO device.
pub const VIRTIO_VENDOR_ID: u16 = 0x1af4;

/// Value subtracted from a (non-transitional) PCI device ID to obtain the number which is handed
/// to `DeviceType::try_from`.
const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;

// Fixed PCI device IDs which `device_type` maps directly onto a `DeviceType` variant.
const TRANSITIONAL_NETWORK: u16 = 0x1000;
const TRANSITIONAL_BLOCK: u16 = 0x1001;
const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
const TRANSITIONAL_CONSOLE: u16 = 0x1003;
const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;

// Byte offsets, relative to the start of a vendor-specific capability in PCI configuration
// space, of the words read while parsing the capability.

/// Offset of the word whose low byte is the BAR index.
pub(crate) const CAP_BAR_OFFSET: u8 = 4;
/// Offset of the word holding the structure's offset within the BAR.
pub(crate) const CAP_BAR_OFFSET_OFFSET: u8 = 8;
/// Offset of the word holding the structure's length in bytes.
pub(crate) const CAP_LENGTH_OFFSET: u8 = 12;
/// Offset of the `notify_off_multiplier` word (read only for notify capabilities).
pub(crate) const CAP_NOTIFY_OFF_MULTIPLIER_OFFSET: u8 = 16;

// Values of the capability's configuration type byte.

/// Capability type identifying the common configuration structure.
pub const VIRTIO_PCI_CAP_COMMON_CFG: u8 = 1;
/// Capability type identifying the notification structure.
pub const VIRTIO_PCI_CAP_NOTIFY_CFG: u8 = 2;
/// Capability type identifying the ISR status structure.
pub const VIRTIO_PCI_CAP_ISR_CFG: u8 = 3;
/// Capability type identifying the device-specific configuration structure.
pub const VIRTIO_PCI_CAP_DEVICE_CFG: u8 = 4;
57
58pub(crate) fn device_type(pci_device_id: u16) -> Option<DeviceType> {
59 match pci_device_id {
60 TRANSITIONAL_NETWORK => Some(DeviceType::Network),
61 TRANSITIONAL_BLOCK => Some(DeviceType::Block),
62 TRANSITIONAL_MEMORY_BALLOONING => Some(DeviceType::MemoryBalloon),
63 TRANSITIONAL_CONSOLE => Some(DeviceType::Console),
64 TRANSITIONAL_SCSI_HOST => Some(DeviceType::ScsiHost),
65 TRANSITIONAL_ENTROPY_SOURCE => Some(DeviceType::EntropySource),
66 TRANSITIONAL_9P_TRANSPORT => Some(DeviceType::_9P),
67 id if id >= PCI_DEVICE_ID_OFFSET => DeviceType::try_from(id - PCI_DEVICE_ID_OFFSET).ok(),
68 _ => None,
69 }
70}
71
72pub fn virtio_device_type(device_function_info: &DeviceFunctionInfo) -> Option<DeviceType> {
75 if device_function_info.vendor_id == VIRTIO_VENDOR_ID {
76 device_type(device_function_info.device_id)
77 } else {
78 None
79 }
80}
81
82#[derive(Debug)]
86pub struct PciTransport {
87 device_type: DeviceType,
88 device_function: DeviceFunction,
90 common_cfg: UniqueMmioPointer<'static, CommonCfg>,
92 notify_region: UniqueMmioPointer<'static, [WriteOnly<u16>]>,
94 notify_off_multiplier: u32,
95 isr_status: UniqueMmioPointer<'static, ReadOnly<u8>>,
97 config_space: Option<UniqueMmioPointer<'static, [u32]>>,
99}
100
101impl PciTransport {
102 pub fn new<H: Hal, C: ConfigurationAccess>(
107 root: &mut PciRoot<C>,
108 device_function: DeviceFunction,
109 ) -> Result<Self, VirtioPciError> {
110 let device_vendor = root.configuration_access.read_word(device_function, 0);
111 let device_id = (device_vendor >> 16) as u16;
112 let vendor_id = device_vendor as u16;
113 if vendor_id != VIRTIO_VENDOR_ID {
114 return Err(VirtioPciError::InvalidVendorId(vendor_id));
115 }
116 let device_type =
117 device_type(device_id).ok_or(VirtioPciError::InvalidDeviceId(device_id))?;
118
119 let mut common_cfg = None;
121 let mut notify_cfg = None;
122 let mut notify_off_multiplier = 0;
123 let mut isr_cfg = None;
124 let mut device_cfg = None;
125 for capability in root.capabilities(device_function) {
126 if capability.id != PCI_CAP_ID_VNDR {
127 continue;
128 }
129 let cap_len = capability.private_header as u8;
130 let cfg_type = (capability.private_header >> 8) as u8;
131 if cap_len < 16 {
132 continue;
133 }
134 let struct_info = VirtioCapabilityInfo {
135 bar: root
136 .configuration_access
137 .read_word(device_function, capability.offset + CAP_BAR_OFFSET)
138 as u8,
139 offset: root
140 .configuration_access
141 .read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
142 length: root
143 .configuration_access
144 .read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
145 };
146
147 match cfg_type {
148 VIRTIO_PCI_CAP_COMMON_CFG if common_cfg.is_none() => {
149 common_cfg = Some(struct_info);
150 }
151 VIRTIO_PCI_CAP_NOTIFY_CFG if cap_len >= 20 && notify_cfg.is_none() => {
152 notify_cfg = Some(struct_info);
153 notify_off_multiplier = root.configuration_access.read_word(
154 device_function,
155 capability.offset + CAP_NOTIFY_OFF_MULTIPLIER_OFFSET,
156 );
157 }
158 VIRTIO_PCI_CAP_ISR_CFG if isr_cfg.is_none() => {
159 isr_cfg = Some(struct_info);
160 }
161 VIRTIO_PCI_CAP_DEVICE_CFG if device_cfg.is_none() => {
162 device_cfg = Some(struct_info);
163 }
164 _ => {}
165 }
166 }
167
168 let common_cfg = get_bar_region::<H, _, _>(
169 root,
170 device_function,
171 &common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
172 )?;
173 let common_cfg = unsafe { UniqueMmioPointer::new(common_cfg) };
176
177 let notify_cfg = notify_cfg.ok_or(VirtioPciError::MissingNotifyConfig)?;
178 if notify_off_multiplier % 2 != 0 {
179 return Err(VirtioPciError::InvalidNotifyOffMultiplier(
180 notify_off_multiplier,
181 ));
182 }
183 let notify_region = get_bar_region_slice::<H, _, _>(root, device_function, ¬ify_cfg)?;
184 let notify_region = unsafe { UniqueMmioPointer::new(notify_region) };
187
188 let isr_status = get_bar_region::<H, _, _>(
189 root,
190 device_function,
191 &isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
192 )?;
193 let isr_status = unsafe { UniqueMmioPointer::new(isr_status) };
196
197 let config_space = if let Some(device_cfg) = device_cfg {
198 Some(unsafe {
201 UniqueMmioPointer::new(get_bar_region_slice::<H, _, _>(
202 root,
203 device_function,
204 &device_cfg,
205 )?)
206 })
207 } else {
208 None
209 };
210
211 Ok(Self {
212 device_type,
213 device_function,
214 common_cfg,
215 notify_region,
216 notify_off_multiplier,
217 isr_status,
218 config_space,
219 })
220 }
221}
222
223impl Transport for PciTransport {
224 fn device_type(&self) -> DeviceType {
225 self.device_type
226 }
227
228 fn read_device_features(&mut self) -> u64 {
229 field!(self.common_cfg, device_feature_select).write(0);
230 let mut device_features_bits = field_shared!(self.common_cfg, device_feature).read() as u64;
231 field!(self.common_cfg, device_feature_select).write(1);
232 device_features_bits |=
233 (field_shared!(self.common_cfg, device_feature).read() as u64) << 32;
234 device_features_bits
235 }
236
237 fn write_driver_features(&mut self, driver_features: u64) {
238 field!(self.common_cfg, driver_feature_select).write(0);
239 field!(self.common_cfg, driver_feature).write(driver_features as u32);
240 field!(self.common_cfg, driver_feature_select).write(1);
241 field!(self.common_cfg, driver_feature).write((driver_features >> 32) as u32);
242 }
243
244 fn max_queue_size(&mut self, queue: u16) -> u32 {
245 field!(self.common_cfg, queue_select).write(queue);
246 field_shared!(self.common_cfg, queue_size).read().into()
247 }
248
249 fn notify(&mut self, queue: u16) {
250 field!(self.common_cfg, queue_select).write(queue);
251 let queue_notify_off = field_shared!(self.common_cfg, queue_notify_off).read();
253
254 let offset_bytes = usize::from(queue_notify_off) * self.notify_off_multiplier as usize;
255 let index = offset_bytes / size_of::<u16>();
256 self.notify_region.get(index).unwrap().write(queue);
257 }
258
259 fn get_status(&self) -> DeviceStatus {
260 let status = field_shared!(self.common_cfg, device_status).read();
261 DeviceStatus::from_bits_truncate(status.into())
262 }
263
264 fn set_status(&mut self, status: DeviceStatus) {
265 field!(self.common_cfg, device_status).write(status.bits() as u8);
266 }
267
268 fn set_guest_page_size(&mut self, _guest_page_size: u32) {
269 }
271
272 fn requires_legacy_layout(&self) -> bool {
273 false
274 }
275
276 fn queue_set(
277 &mut self,
278 queue: u16,
279 size: u32,
280 descriptors: PhysAddr,
281 driver_area: PhysAddr,
282 device_area: PhysAddr,
283 ) {
284 field!(self.common_cfg, queue_select).write(queue);
285 field!(self.common_cfg, queue_size).write(size as u16);
286 field!(self.common_cfg, queue_desc).write(descriptors);
287 field!(self.common_cfg, queue_driver).write(driver_area);
288 field!(self.common_cfg, queue_device).write(device_area);
289 field!(self.common_cfg, queue_enable).write(1);
290 }
291
292 fn queue_unset(&mut self, _queue: u16) {
293 }
296
297 fn queue_used(&mut self, queue: u16) -> bool {
298 field!(self.common_cfg, queue_select).write(queue);
299 field_shared!(self.common_cfg, queue_enable).read() == 1
300 }
301
302 fn ack_interrupt(&mut self) -> InterruptStatus {
303 let isr_status = self.isr_status.read();
305 InterruptStatus::from_bits_retain(isr_status.into())
306 }
307
308 fn read_config_generation(&self) -> u32 {
309 field_shared!(self.common_cfg, config_generation)
310 .read()
311 .into()
312 }
313
314 fn read_config_space<T: FromBytes + IntoBytes>(&self, offset: usize) -> Result<T, Error> {
315 assert!(align_of::<T>() <= 4,
316 "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
317 align_of::<T>());
318 assert_eq!(offset % align_of::<T>(), 0);
319
320 let config_space = self
321 .config_space
322 .as_ref()
323 .ok_or(Error::ConfigSpaceMissing)?;
324 if config_space.len() * size_of::<u32>() < offset + size_of::<T>() {
325 Err(Error::ConfigSpaceTooSmall)
326 } else {
327 unsafe {
331 let ptr = config_space.ptr().cast::<T>().byte_add(offset);
332 Ok(config_space
333 .deref()
334 .child(NonNull::new(ptr.cast_mut()).unwrap())
335 .read_unsafe())
336 }
337 }
338 }
339
340 fn write_config_space<T: IntoBytes + Immutable>(
341 &mut self,
342 offset: usize,
343 value: T,
344 ) -> Result<(), Error> {
345 assert!(align_of::<T>() <= 4,
346 "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
347 align_of::<T>());
348 assert_eq!(offset % align_of::<T>(), 0);
349
350 let config_space = self
351 .config_space
352 .as_mut()
353 .ok_or(Error::ConfigSpaceMissing)?;
354 if config_space.len() * size_of::<u32>() < offset + size_of::<T>() {
355 Err(Error::ConfigSpaceTooSmall)
356 } else {
357 unsafe {
361 let ptr = config_space.ptr_nonnull().cast::<T>().byte_add(offset);
362 config_space.child(ptr).write_unsafe(value);
363 }
364 Ok(())
365 }
366 }
367}
368
369unsafe impl Send for PciTransport {}
371
372unsafe impl Sync for PciTransport {}
375
376impl Drop for PciTransport {
377 fn drop(&mut self) {
378 self.set_status(DeviceStatus::empty());
380 while self.get_status() != DeviceStatus::empty() {}
381 }
382}
383
384#[repr(C)]
386pub(crate) struct CommonCfg {
387 pub device_feature_select: ReadPureWrite<u32>,
388 pub device_feature: ReadPure<u32>,
389 pub driver_feature_select: ReadPureWrite<u32>,
390 pub driver_feature: ReadPureWrite<u32>,
391 pub msix_config: ReadPureWrite<u16>,
392 pub num_queues: ReadPure<u16>,
393 pub device_status: ReadPureWrite<u8>,
394 pub config_generation: ReadPure<u8>,
395 pub queue_select: ReadPureWrite<u16>,
396 pub queue_size: ReadPureWrite<u16>,
397 pub queue_msix_vector: ReadPureWrite<u16>,
398 pub queue_enable: ReadPureWrite<u16>,
399 pub queue_notify_off: ReadPureWrite<u16>,
400 pub queue_desc: ReadPureWrite<u64>,
401 pub queue_driver: ReadPureWrite<u64>,
402 pub queue_device: ReadPureWrite<u64>,
403}
404
/// Location of one VirtIO configuration structure, as described by a vendor-specific PCI
/// capability.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct VirtioCapabilityInfo {
    /// Index of the BAR that contains the structure.
    pub bar: u8,
    /// Byte offset of the structure from the start of that BAR.
    pub offset: u32,
    /// Length of the structure in bytes.
    pub length: u32,
}
415
416fn get_bar_region<H: Hal, T, C: ConfigurationAccess>(
417 root: &mut PciRoot<C>,
418 device_function: DeviceFunction,
419 struct_info: &VirtioCapabilityInfo,
420) -> Result<NonNull<T>, VirtioPciError> {
421 let bar_info = root
422 .bar_info(device_function, struct_info.bar)?
423 .ok_or(VirtioPciError::BarNotAllocated(struct_info.bar))?;
424 let (bar_address, bar_size) = bar_info
425 .memory_address_size()
426 .ok_or(VirtioPciError::UnexpectedIoBar)?;
427 if bar_address == 0 {
428 return Err(VirtioPciError::BarNotAllocated(struct_info.bar));
429 }
430 if u64::from(struct_info.offset + struct_info.length) > bar_size
431 || size_of::<T>() > struct_info.length as usize
432 {
433 return Err(VirtioPciError::BarOffsetOutOfRange);
434 }
435 let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
436 let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
438 if vaddr.as_ptr() as usize % align_of::<T>() != 0 {
439 return Err(VirtioPciError::Misaligned {
440 address: vaddr.as_ptr() as usize,
441 alignment: align_of::<T>(),
442 });
443 }
444 Ok(vaddr.cast())
445}
446
447fn get_bar_region_slice<H: Hal, T, C: ConfigurationAccess>(
448 root: &mut PciRoot<C>,
449 device_function: DeviceFunction,
450 struct_info: &VirtioCapabilityInfo,
451) -> Result<NonNull<[T]>, VirtioPciError> {
452 let ptr = get_bar_region::<H, T, C>(root, device_function, struct_info)?;
453 Ok(NonNull::slice_from_raw_parts(
454 ptr,
455 struct_info.length as usize / size_of::<T>(),
456 ))
457}
458
459#[derive(Clone, Debug, Eq, Error, PartialEq)]
461pub enum VirtioPciError {
462 #[error("PCI device ID {0:#06x} was not a valid VirtIO device ID.")]
464 InvalidDeviceId(u16),
465 #[error("PCI device vender ID {0:#06x} was not the VirtIO vendor ID {VIRTIO_VENDOR_ID:#06x}.")]
467 InvalidVendorId(u16),
468 #[error("No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found.")]
470 MissingCommonConfig,
471 #[error("No valid `VIRTIO_PCI_CAP_NOTIFY_CFG` capability was found.")]
473 MissingNotifyConfig,
474 #[error("`VIRTIO_PCI_CAP_NOTIFY_CFG` capability has a `notify_off_multiplier` that is not a multiple of 2: {0}")]
477 InvalidNotifyOffMultiplier(u32),
478 #[error("No valid `VIRTIO_PCI_CAP_ISR_CFG` capability was found.")]
480 MissingIsrConfig,
481 #[error("Unexpected IO BAR (expected memory BAR).")]
483 UnexpectedIoBar,
484 #[error("Bar {0} not allocated.")]
486 BarNotAllocated(u8),
487 #[error("Capability offset greater than BAR length.")]
489 BarOffsetOutOfRange,
490 #[error("Address {address:#018} was not aligned to a {alignment} byte boundary as expected.")]
492 Misaligned {
493 address: usize,
495 alignment: usize,
497 },
498 #[error(transparent)]
500 Pci(PciError),
501}
502
503impl From<PciError> for VirtioPciError {
504 fn from(error: PciError) -> Self {
505 Self::Pci(error)
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DeviceFunctionInfo` with the given IDs; every other field is zero / standard.
    fn function_info(vendor_id: u16, device_id: u16) -> DeviceFunctionInfo {
        DeviceFunctionInfo {
            vendor_id,
            device_id,
            class: 0,
            subclass: 0,
            prog_if: 0,
            revision: 0,
            header_type: bus::HeaderType::Standard,
        }
    }

    #[test]
    fn transitional_device_ids() {
        let cases = [
            (0x1000, DeviceType::Network),
            (0x1002, DeviceType::MemoryBalloon),
            (0x1009, DeviceType::_9P),
        ];
        for &(id, expected) in &cases {
            assert_eq!(device_type(id), Some(expected));
        }
    }

    #[test]
    fn offset_device_ids() {
        let cases = [
            (0x1040, None),
            (0x1045, Some(DeviceType::MemoryBalloon)),
            (0x1049, Some(DeviceType::_9P)),
            (0x1058, Some(DeviceType::Memory)),
            (0x1059, Some(DeviceType::Sound)),
            (0x1060, None),
        ];
        for &(id, expected) in &cases {
            assert_eq!(device_type(id), expected);
        }
    }

    #[test]
    fn virtio_device_type_valid() {
        let info = function_info(VIRTIO_VENDOR_ID, TRANSITIONAL_BLOCK);
        assert_eq!(virtio_device_type(&info), Some(DeviceType::Block));
    }

    #[test]
    fn virtio_device_type_invalid() {
        // A vendor ID other than the VirtIO one is rejected even with a valid device ID.
        let wrong_vendor = function_info(0x1234, TRANSITIONAL_BLOCK);
        assert_eq!(virtio_device_type(&wrong_vendor), None);

        // The right vendor ID with a device ID that maps to no device type is rejected too.
        let wrong_device = function_info(VIRTIO_VENDOR_ID, 0x1040);
        assert_eq!(virtio_device_type(&wrong_device), None);
    }
}