aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndrew Walbran <qwandor@google.com>2022-11-23 14:58:01 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2022-11-23 14:58:01 +0000
commitf0c868a70f9a8d4127d35533e089e7dbf5268d4b (patch)
treeea067972eb67a29016a0c953646a18a3c5772cc4
parent840b89128d981c51bc517ca666ce64ece2862cbe (diff)
parent8e6deac7ed2c056b014d425c4df3b0a88d7204db (diff)
downloadvirtio-drivers-f0c868a70f9a8d4127d35533e089e7dbf5268d4b.tar.gz
Pull in patches to be submitted upstream. am: f1cc9a6993 am: 8e6deac7ed
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/virtio-drivers/+/2310906 Change-Id: I158c5d0e5d63ba23fda834e08de5afe52eef9b53 Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--.github/workflows/main.yml60
-rw-r--r--.gitignore2
-rw-r--r--Android.bp4
-rw-r--r--Cargo.toml4
-rw-r--r--Cargo.toml.orig4
-rw-r--r--README.md11
-rw-r--r--cargo2android.json1
-rw-r--r--src/blk.rs49
-rw-r--r--src/console.rs55
-rw-r--r--src/gpu.rs45
-rw-r--r--src/input.rs33
-rw-r--r--src/lib.rs8
-rw-r--r--src/net.rs49
-rw-r--r--src/queue.rs67
-rw-r--r--src/transport/fake.rs36
-rw-r--r--src/transport/mmio.rs63
-rw-r--r--src/transport/mod.rs70
-rw-r--r--src/transport/pci.rs551
-rw-r--r--src/transport/pci/bus.rs599
19 files changed, 1552 insertions, 159 deletions
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 0c79e84..b2ab149 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -1,6 +1,9 @@
name: CI
-on: [push, pull_request]
+on:
+ push:
+ branches: [master]
+ pull_request:
jobs:
check:
@@ -18,9 +21,9 @@ jobs:
command: fmt
args: --all -- --check
- name: Clippy
- uses: actions-rs/cargo@v1
+ uses: actions-rs/clippy-check@v1
with:
- command: clippy
+ token: ${{ secrets.GITHUB_TOKEN }}
build:
runs-on: ubuntu-latest
@@ -30,7 +33,12 @@ jobs:
with:
profile: minimal
toolchain: stable
- - name: Build
+ - name: Build with no features
+ uses: actions-rs/cargo@v1
+ with:
+ command: build
+ args: --no-default-features
+ - name: Build with all features
uses: actions-rs/cargo@v1
with:
command: build
@@ -39,8 +47,50 @@ jobs:
uses: actions-rs/cargo@v1
with:
command: doc
- - name: Test
+ - name: Test with no features
+ uses: actions-rs/cargo@v1
+ with:
+ command: test
+ args: --no-default-features
+ - name: Test with all features
uses: actions-rs/cargo@v1
with:
command: test
args: --all-features
+
+ examples:
+ runs-on: ubuntu-22.04
+ strategy:
+ fail-fast: false
+ matrix:
+ example:
+ - aarch64
+ - riscv
+ include:
+ - example: aarch64
+ toolchain: stable
+ target: aarch64-unknown-none
+ packages: qemu-system-arm gcc-aarch64-linux-gnu
+ - example: riscv
+ toolchain: nightly-2022-11-03
+ target: riscv64imac-unknown-none-elf
+ packages: qemu-system-misc
+ steps:
+ - uses: actions/checkout@v2
+ - name: Install QEMU
+ run: sudo apt update && sudo apt install ${{ matrix.packages }}
+ - uses: actions-rs/toolchain@v1
+ with:
+ profile: minimal
+ toolchain: ${{ matrix.toolchain }}
+ target: ${{ matrix.target }}
+ components: llvm-tools-preview, rustfmt
+ - name: Check code format
+ working-directory: examples/${{ matrix.example }}
+ run: cargo fmt --all -- --check
+ - name: Build
+ working-directory: examples/${{ matrix.example }}
+ run: make kernel
+ - name: Run
+ working-directory: examples/${{ matrix.example }}
+ run: QEMU_ARGS="-display none" make qemu
diff --git a/.gitignore b/.gitignore
index 97e6d83..9b16963 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
-/target
+target/
Cargo.lock
.vscode/
diff --git a/Android.bp b/Android.bp
index cc50bbe..ff6d45a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1,8 +1,6 @@
// This file is generated by cargo2android.py --config cargo2android.json.
// Do not modify this file as changes will be overridden on upgrade.
-
-
package {
default_applicable_licenses: [
"external_rust_crates_virtio-drivers_license",
@@ -24,6 +22,7 @@ license {
rust_library {
name: "libvirtio_drivers",
+ // has rustc warnings
host_supported: true,
crate_name: "virtio_drivers",
cargo_env_compat: true,
@@ -42,6 +41,7 @@ rust_library {
rust_test {
name: "virtio-drivers_test_src_lib",
+ // has rustc warnings
host_supported: true,
crate_name: "virtio_drivers",
cargo_env_compat: true,
diff --git a/Cargo.toml b/Cargo.toml
index 9ab634e..2767964 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -34,3 +34,7 @@ version = "1.3"
[dependencies.log]
version = "0.4"
+
+[features]
+default = ["alloc"]
+alloc = []
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 9aafd94..cc52f7c 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -17,3 +17,7 @@ categories = ["hardware-support", "no-std"]
[dependencies]
log = "0.4"
bitflags = "1.3"
+
+[features]
+default = ["alloc"]
+alloc = []
diff --git a/README.md b/README.md
index e6089f8..e3016bb 100644
--- a/README.md
+++ b/README.md
@@ -19,11 +19,11 @@ VirtIO guest drivers in Rust. For **no_std** environment.
### Transports
-| Transport | Supported | |
-| ----------- | --------- | --------- |
-| Legacy MMIO | ✅ | version 1 |
-| MMIO | ✅ | version 2 |
-| PCI | ❌ | |
+| Transport | Supported | |
+| ----------- | --------- | ------------------------------------------------- |
+| Legacy MMIO | ✅ | version 1 |
+| MMIO | ✅ | version 2 |
+| PCI | ✅ | Memory-mapped CAM only, e.g. aarch64 or PCIe ECAM |
### Device-independent features
@@ -43,4 +43,5 @@ VirtIO guest drivers in Rust. For **no_std** environment.
- x86_64 (TODO)
+- [aarch64](./examples/aarch64)
- [RISCV](./examples/riscv)
diff --git a/cargo2android.json b/cargo2android.json
index cf6ca9c..b893c29 100644
--- a/cargo2android.json
+++ b/cargo2android.json
@@ -1,6 +1,7 @@
{
"dependencies": true,
"device": true,
+ "features": "",
"run": true,
"tests": true
} \ No newline at end of file
diff --git a/src/blk.rs b/src/blk.rs
index 8c0512a..9022e57 100644
--- a/src/blk.rs
+++ b/src/blk.rs
@@ -3,9 +3,10 @@ use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::{volread, Volatile};
use bitflags::*;
-use core::hint::spin_loop;
use log::*;
+const QUEUE: u16 = 0;
+
/// The virtio block device is a simple virtual block device (ie. disk).
///
/// Read and write requests (and other exotic requests) are placed in the queue,
@@ -28,13 +29,15 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
});
// read configuration space
- let config = transport.config_space().cast::<BlkConfig>();
+ let config = transport.config_space::<BlkConfig>()?;
info!("config: {:?}", config);
// Safe because config is a valid pointer to the device configuration space.
- let capacity = unsafe { volread!(config, capacity) };
+ let capacity = unsafe {
+ volread!(config, capacity_low) as u64 | (volread!(config, capacity_high) as u64) << 32
+ };
info!("found a block device of size {}KB", capacity / 2);
- let queue = VirtQueue::new(&mut transport, 0, 16)?;
+ let queue = VirtQueue::new(&mut transport, QUEUE, 16)?;
transport.finish_init();
Ok(VirtIOBlk {
@@ -58,12 +61,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
sector: block_id as u64,
};
let mut resp = BlkResp::default();
- self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
- self.transport.notify(0);
- while !self.queue.can_pop() {
- spin_loop();
- }
- self.queue.pop_used()?;
+ self.queue.add_notify_wait_pop(
+ &[req.as_buf()],
+ &[buf, resp.as_buf_mut()],
+ &mut self.transport,
+ )?;
match resp.status {
RespStatus::Ok => Ok(()),
_ => Err(Error::IoError),
@@ -111,7 +113,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
sector: block_id as u64,
};
let token = self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
- self.transport.notify(0);
+ self.transport.notify(QUEUE);
Ok(token)
}
@@ -124,12 +126,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
sector: block_id as u64,
};
let mut resp = BlkResp::default();
- self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
- self.transport.notify(0);
- while !self.queue.can_pop() {
- spin_loop();
- }
- self.queue.pop_used()?;
+ self.queue.add_notify_wait_pop(
+ &[req.as_buf(), buf],
+ &[resp.as_buf_mut()],
+ &mut self.transport,
+ )?;
match resp.status {
RespStatus::Ok => Ok(()),
_ => Err(Error::IoError),
@@ -166,7 +167,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
sector: block_id as u64,
};
let token = self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
- self.transport.notify(0);
+ self.transport.notify(QUEUE);
Ok(token)
}
@@ -184,11 +185,19 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
}
}
+impl<H: Hal, T: Transport> Drop for VirtIOBlk<H, T> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE);
+ }
+}
+
#[repr(C)]
-#[derive(Debug)]
struct BlkConfig {
/// Number of 512 Bytes sectors
- capacity: Volatile<u64>,
+ capacity_low: Volatile<u32>,
+ capacity_high: Volatile<u32>,
size_max: Volatile<u32>,
seg_max: Volatile<u32>,
cylinders: Volatile<u16>,
diff --git a/src/console.rs b/src/console.rs
index 3ba4280..50743ff 100644
--- a/src/console.rs
+++ b/src/console.rs
@@ -1,13 +1,12 @@
use super::*;
use crate::queue::VirtQueue;
use crate::transport::Transport;
-use crate::volatile::{ReadOnly, WriteOnly};
+use crate::volatile::{volread, ReadOnly, WriteOnly};
use bitflags::*;
-use core::{fmt, hint::spin_loop};
use log::*;
-const QUEUE_RECEIVEQ_PORT_0: usize = 0;
-const QUEUE_TRANSMITQ_PORT_0: usize = 1;
+const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
+const QUEUE_TRANSMITQ_PORT_0: u16 = 1;
const QUEUE_SIZE: u16 = 2;
/// Virtio console. Only one single port is allowed since ``alloc'' is disabled.
@@ -31,9 +30,16 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
let supported_features = Features::empty();
(features & supported_features).bits()
});
- let config_space = transport.config_space().cast::<Config>();
- let config = unsafe { config_space.as_ref() };
- info!("Config: {:?}", config);
+ let config_space = transport.config_space::<Config>()?;
+ unsafe {
+ let columns = volread!(config_space, cols);
+ let rows = volread!(config_space, rows);
+ let max_ports = volread!(config_space, max_nr_ports);
+ info!(
+ "Columns: {} Rows: {} Max ports: {}",
+ columns, rows, max_ports,
+ );
+ }
let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, QUEUE_SIZE)?;
let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, QUEUE_SIZE)?;
let queue_buf_dma = DMA::new(1)?;
@@ -53,7 +59,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
}
fn poll_retrieve(&mut self) -> Result<()> {
- self.receiveq.add(&[], &[self.queue_buf_rx])?;
+ // Safe because the buffer lasts at least as long as the queue.
+ unsafe { self.receiveq.add(&[], &[self.queue_buf_rx])? };
Ok(())
}
@@ -92,16 +99,22 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
/// Put a char onto the device.
pub fn send(&mut self, chr: u8) -> Result<()> {
let buf: [u8; 1] = [chr];
- self.transmitq.add(&[&buf], &[])?;
- self.transport.notify(QUEUE_TRANSMITQ_PORT_0 as u32);
- while !self.transmitq.can_pop() {
- spin_loop();
- }
- self.transmitq.pop_used()?;
+ // Safe because the buffer is valid until we pop_used below.
+ self.transmitq
+ .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
Ok(())
}
}
+impl<H: Hal, T: Transport> Drop for VirtIOConsole<'_, H, T> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE_RECEIVEQ_PORT_0);
+ self.transport.queue_unset(QUEUE_TRANSMITQ_PORT_0);
+ }
+}
+
#[repr(C)]
struct Config {
cols: ReadOnly<u16>,
@@ -110,16 +123,6 @@ struct Config {
emerg_wr: WriteOnly<u32>,
}
-impl fmt::Debug for Config {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- f.debug_struct("Config")
- .field("cols", &self.cols)
- .field("rows", &self.rows)
- .field("max_nr_ports", &self.max_nr_ports)
- .finish()
- }
-}
-
bitflags! {
struct Features: u64 {
const SIZE = 1 << 0;
@@ -174,10 +177,10 @@ mod tests {
device_type: DeviceType::Console,
max_queue_size: 2,
device_features: 0,
- config_space: NonNull::from(&mut config_space).cast(),
+ config_space: NonNull::from(&mut config_space),
state: state.clone(),
};
- let mut console = VirtIOConsole::<FakeHal, FakeTransport>::new(transport).unwrap();
+ let mut console = VirtIOConsole::<FakeHal, FakeTransport<Config>>::new(transport).unwrap();
// Nothing is available to receive.
assert_eq!(console.recv(false).unwrap(), None);
diff --git a/src/gpu.rs b/src/gpu.rs
index 26fb4d6..6c17f3a 100644
--- a/src/gpu.rs
+++ b/src/gpu.rs
@@ -1,7 +1,7 @@
use super::*;
use crate::queue::VirtQueue;
use crate::transport::Transport;
-use crate::volatile::{ReadOnly, Volatile, WriteOnly};
+use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
use bitflags::*;
use core::{fmt, hint::spin_loop};
use log::*;
@@ -43,9 +43,15 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
});
// read configuration space
- let config_space = transport.config_space().cast::<Config>();
- let config = unsafe { config_space.as_ref() };
- info!("Config: {:?}", config);
+ let config_space = transport.config_space::<Config>()?;
+ unsafe {
+ let events_read = volread!(config_space, events_read);
+ let num_scanouts = volread!(config_space, num_scanouts);
+ info!(
+ "events_read: {:#x}, num_scanouts: {:#x}",
+ events_read, num_scanouts
+ );
+ }
let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, 2)?;
let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, 2)?;
@@ -163,13 +169,16 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
unsafe {
(self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
}
- self.control_queue
- .add(&[self.queue_buf_send], &[self.queue_buf_recv])?;
- self.transport.notify(QUEUE_TRANSMIT as u32);
+ let token = unsafe {
+ self.control_queue
+ .add(&[self.queue_buf_send], &[self.queue_buf_recv])?
+ };
+ self.transport.notify(QUEUE_TRANSMIT);
while !self.control_queue.can_pop() {
spin_loop();
}
- self.control_queue.pop_used()?;
+ let (popped_token, _) = self.control_queue.pop_used()?;
+ assert_eq!(popped_token, token);
Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
}
@@ -178,12 +187,13 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
unsafe {
(self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
}
- self.cursor_queue.add(&[self.queue_buf_send], &[])?;
- self.transport.notify(QUEUE_CURSOR as u32);
+ let token = unsafe { self.cursor_queue.add(&[self.queue_buf_send], &[])? };
+ self.transport.notify(QUEUE_CURSOR);
while !self.cursor_queue.can_pop() {
spin_loop();
}
- self.cursor_queue.pop_used()?;
+ let (popped_token, _) = self.cursor_queue.pop_used()?;
+ assert_eq!(popped_token, token);
Ok(())
}
@@ -277,6 +287,15 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
}
}
+impl<H: Hal, T: Transport> Drop for VirtIOGpu<'_, H, T> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE_TRANSMIT);
+ self.transport.queue_unset(QUEUE_CURSOR);
+ }
+}
+
#[repr(C)]
struct Config {
/// Signals pending events to the driver。
@@ -483,8 +502,8 @@ struct UpdateCursor {
_padding: u32,
}
-const QUEUE_TRANSMIT: usize = 0;
-const QUEUE_CURSOR: usize = 1;
+const QUEUE_TRANSMIT: u16 = 0;
+const QUEUE_CURSOR: u16 = 1;
const SCANOUT_ID: u32 = 0;
const RESOURCE_ID_FB: u32 = 0xbabe;
diff --git a/src/input.rs b/src/input.rs
index aa3de36..70cef0f 100644
--- a/src/input.rs
+++ b/src/input.rs
@@ -3,6 +3,7 @@ use crate::transport::Transport;
use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
use alloc::boxed::Box;
use bitflags::*;
+use core::ptr::NonNull;
use log::*;
/// Virtual human interface devices such as keyboards, mice and tablets.
@@ -15,6 +16,7 @@ pub struct VirtIOInput<H: Hal, T: Transport> {
event_queue: VirtQueue<H>,
status_queue: VirtQueue<H>,
event_buf: Box<[InputEvent; 32]>,
+ config: NonNull<Config>,
}
impl<H: Hal, T: Transport> VirtIOInput<H, T> {
@@ -29,10 +31,13 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
(features & supported_features).bits()
});
+ let config = transport.config_space::<Config>()?;
+
let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, QUEUE_SIZE as u16)?;
let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
- let token = event_queue.add(&[], &[event.as_buf_mut()])?;
+ // Safe because the buffer lasts as long as the queue.
+ let token = unsafe { event_queue.add(&[], &[event.as_buf_mut()])? };
assert_eq!(token, i as u16);
}
@@ -43,6 +48,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
event_queue,
status_queue,
event_buf,
+ config,
})
}
@@ -56,7 +62,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
if let Ok((token, _)) = self.event_queue.pop_used() {
let event = &mut self.event_buf[token as usize];
// requeue
- if let Ok(new_token) = self.event_queue.add(&[], &[event.as_buf_mut()]) {
+ // Safe because buffer lasts as long as the queue.
+ if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_buf_mut()]) } {
// This only works because nothing happen between `pop_used` and `add` that affects
// the list of free descriptors in the queue, so `add` reuses the descriptor which
// was just freed by `pop_used`.
@@ -75,21 +82,29 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
subsel: u8,
out: &mut [u8],
) -> u8 {
- let config = self.transport.config_space().cast::<Config>();
let size;
let data;
// Safe because config points to a valid MMIO region for the config space.
unsafe {
- volwrite!(config, select, select as u8);
- volwrite!(config, subsel, subsel);
- size = volread!(config, size);
- data = volread!(config, data);
+ volwrite!(self.config, select, select as u8);
+ volwrite!(self.config, subsel, subsel);
+ size = volread!(self.config, size);
+ data = volread!(self.config, data);
}
out[..size as usize].copy_from_slice(&data[..size as usize]);
size
}
}
+impl<H: Hal, T: Transport> Drop for VirtIOInput<H, T> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE_EVENT);
+ self.transport.queue_unset(QUEUE_STATUS);
+ }
+}
+
/// Select value used for [`VirtIOInput::query_config_select()`].
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
@@ -178,8 +193,8 @@ bitflags! {
}
}
-const QUEUE_EVENT: usize = 0;
-const QUEUE_STATUS: usize = 1;
+const QUEUE_EVENT: u16 = 0;
+const QUEUE_STATUS: u16 = 1;
// a parameter that can change
const QUEUE_SIZE: usize = 32;
diff --git a/src/lib.rs b/src/lib.rs
index 291af7b..0209daa 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,12 +5,14 @@
#![allow(clippy::identity_op)]
#![allow(dead_code)]
+#[cfg(any(feature = "alloc", test))]
extern crate alloc;
mod blk;
mod console;
mod gpu;
mod hal;
+#[cfg(feature = "alloc")]
mod input;
mod net;
mod queue;
@@ -21,10 +23,12 @@ pub use self::blk::{BlkResp, RespStatus, VirtIOBlk};
pub use self::console::VirtIOConsole;
pub use self::gpu::VirtIOGpu;
pub use self::hal::{Hal, PhysAddr, VirtAddr};
+#[cfg(feature = "alloc")]
pub use self::input::{InputConfigSelect, InputEvent, VirtIOInput};
pub use self::net::VirtIONet;
use self::queue::VirtQueue;
pub use self::transport::mmio::{MmioError, MmioTransport, MmioVersion, VirtIOHeader};
+pub use self::transport::pci;
pub use self::transport::{DeviceStatus, DeviceType, Transport};
use core::mem::size_of;
use hal::*;
@@ -50,6 +54,10 @@ pub enum Error {
DmaError,
/// I/O Error
IoError,
+ /// The config space advertised by the device is smaller than the driver expected.
+ ConfigSpaceTooSmall,
+ /// The device doesn't have any config space, but the driver expects some.
+ ConfigSpaceMissing,
}
/// Align `size` up to a page.
diff --git a/src/net.rs b/src/net.rs
index 3641ac8..82abc60 100644
--- a/src/net.rs
+++ b/src/net.rs
@@ -2,9 +2,8 @@ use core::mem::{size_of, MaybeUninit};
use super::*;
use crate::transport::Transport;
-use crate::volatile::{volread, ReadOnly, Volatile};
+use crate::volatile::{volread, ReadOnly};
use bitflags::*;
-use core::hint::spin_loop;
use log::*;
/// The virtio network device is a virtual ethernet card.
@@ -31,7 +30,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
(features & supported_features).bits()
});
// read configuration space
- let config = transport.config_space().cast::<Config>();
+ let config = transport.config_space::<Config>()?;
let mac;
// Safe because config points to a valid MMIO region for the config space.
unsafe {
@@ -77,13 +76,9 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
let mut header = MaybeUninit::<Header>::uninit();
let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
- self.recv_queue.add(&[], &[header_buf, buf])?;
- self.transport.notify(QUEUE_RECEIVE as u32);
- while !self.recv_queue.can_pop() {
- spin_loop();
- }
-
- let (_, len) = self.recv_queue.pop_used()?;
+ let len =
+ self.recv_queue
+ .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
// let header = unsafe { header.assume_init() };
Ok(len as usize - size_of::<Header>())
}
@@ -91,16 +86,21 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
/// Send a packet.
pub fn send(&mut self, buf: &[u8]) -> Result {
let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
- self.send_queue.add(&[header.as_buf(), buf], &[])?;
- self.transport.notify(QUEUE_TRANSMIT as u32);
- while !self.send_queue.can_pop() {
- spin_loop();
- }
- self.send_queue.pop_used()?;
+ self.send_queue
+ .add_notify_wait_pop(&[header.as_buf(), buf], &[], &mut self.transport)?;
Ok(())
}
}
+impl<H: Hal, T: Transport> Drop for VirtIONet<H, T> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE_RECEIVE);
+ self.transport.queue_unset(QUEUE_TRANSMIT);
+ }
+}
+
bitflags! {
struct Features: u64 {
/// Device handles packets with partial checksum.
@@ -177,7 +177,6 @@ bitflags! {
}
#[repr(C)]
-#[derive(Debug)]
struct Config {
mac: ReadOnly<EthernetAddress>,
status: ReadOnly<Status>,
@@ -189,12 +188,12 @@ type EthernetAddress = [u8; 6];
#[repr(C)]
#[derive(Debug)]
struct Header {
- flags: Volatile<Flags>,
- gso_type: Volatile<GsoType>,
- hdr_len: Volatile<u16>, // cannot rely on this
- gso_size: Volatile<u16>,
- csum_start: Volatile<u16>,
- csum_offset: Volatile<u16>,
+ flags: Flags,
+ gso_type: GsoType,
+ hdr_len: u16, // cannot rely on this
+ gso_size: u16,
+ csum_start: u16,
+ csum_offset: u16,
// payload starts from here
}
@@ -218,5 +217,5 @@ enum GsoType {
ECN = 0x80,
}
-const QUEUE_RECEIVE: usize = 0;
-const QUEUE_TRANSMIT: usize = 1;
+const QUEUE_RECEIVE: u16 = 0;
+const QUEUE_TRANSMIT: u16 = 1;
diff --git a/src/queue.rs b/src/queue.rs
index 264c87e..4dc7c01 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -1,5 +1,6 @@
#[cfg(test)]
use core::cmp::min;
+use core::hint::spin_loop;
use core::mem::size_of;
use core::ptr::{self, addr_of_mut, NonNull};
use core::sync::atomic::{fence, Ordering};
@@ -23,7 +24,7 @@ pub struct VirtQueue<H: Hal> {
used: NonNull<UsedRing>,
/// The index of queue
- queue_idx: u32,
+ queue_idx: u16,
/// The size of the queue.
///
/// This is both the number of descriptors, and the number of slots in the available and used
@@ -39,8 +40,8 @@ pub struct VirtQueue<H: Hal> {
impl<H: Hal> VirtQueue<H> {
/// Create a new VirtQueue.
- pub fn new<T: Transport>(transport: &mut T, idx: usize, size: u16) -> Result<Self> {
- if transport.queue_used(idx as u32) {
+ pub fn new<T: Transport>(transport: &mut T, idx: u16, size: u16) -> Result<Self> {
+ if transport.queue_used(idx) {
return Err(Error::AlreadyUsed);
}
if !size.is_power_of_two() || transport.max_queue_size() < size as u32 {
@@ -51,7 +52,7 @@ impl<H: Hal> VirtQueue<H> {
let dma = DMA::new(layout.size / PAGE_SIZE)?;
transport.queue_set(
- idx as u32,
+ idx,
size as u32,
dma.paddr(),
dma.paddr() + layout.avail_offset,
@@ -81,7 +82,7 @@ impl<H: Hal> VirtQueue<H> {
avail,
used,
queue_size: size,
- queue_idx: idx as u32,
+ queue_idx: idx,
num_used: 0,
free_head: 0,
avail_idx: 0,
@@ -92,7 +93,11 @@ impl<H: Hal> VirtQueue<H> {
/// Add buffers to the virtqueue, return a token.
///
/// Ref: linux virtio_ring.c virtqueue_add
- pub fn add(&mut self, inputs: &[&[u8]], outputs: &[&mut [u8]]) -> Result<u16> {
+ ///
+ /// # Safety
+ ///
+ /// The input and output buffers must remain valid until the token is returned by `pop_used`.
+ pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
if inputs.is_empty() && outputs.is_empty() {
return Err(Error::InvalidParam);
}
@@ -109,14 +114,14 @@ impl<H: Hal> VirtQueue<H> {
unsafe {
for input in inputs.iter() {
let mut desc = self.desc_ptr(self.free_head);
- (*desc).set_buf::<H>(input);
+ (*desc).set_buf::<H>(NonNull::new(*input as *mut [u8]).unwrap());
(*desc).flags = DescFlags::NEXT;
last = self.free_head;
self.free_head = (*desc).next;
}
for output in outputs.iter() {
let desc = self.desc_ptr(self.free_head);
- (*desc).set_buf::<H>(output);
+ (*desc).set_buf::<H>(NonNull::new(*output).unwrap());
(*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
last = self.free_head;
self.free_head = (*desc).next;
@@ -150,6 +155,32 @@ impl<H: Hal> VirtQueue<H> {
Ok(head)
}
+ /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses
+ /// them, then pops them.
+ ///
+ /// This assumes that the device isn't processing any other buffers at the same time.
+ pub fn add_notify_wait_pop(
+ &mut self,
+ inputs: &[*const [u8]],
+ outputs: &[*mut [u8]],
+ transport: &mut impl Transport,
+ ) -> Result<u32> {
+ // Safe because we don't return until the same token has been popped, so they remain valid
+ // until then.
+ let token = unsafe { self.add(inputs, outputs) }?;
+
+ // Notify the queue.
+ transport.notify(self.queue_idx);
+
+ while !self.can_pop() {
+ spin_loop();
+ }
+ let (popped_token, length) = self.pop_used()?;
+ assert_eq!(popped_token, token);
+
+ Ok(length)
+ }
+
/// Returns a non-null pointer to the descriptor at the given index.
fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
// Safe because self.desc is properly aligned and dereferenceable.
@@ -263,8 +294,11 @@ pub(crate) struct Descriptor {
}
impl Descriptor {
- fn set_buf<H: Hal>(&mut self, buf: &[u8]) {
- self.addr = H::virt_to_phys(buf.as_ptr() as usize) as u64;
+ /// # Safety
+ ///
+ /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
+ unsafe fn set_buf<H: Hal>(&mut self, buf: NonNull<[u8]>) {
+ self.addr = H::virt_to_phys(buf.as_ptr() as *mut u8 as usize) as u64;
self.len = buf.len() as u32;
}
}
@@ -408,7 +442,10 @@ mod tests {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
- assert_eq!(queue.add(&[], &[]).unwrap_err(), Error::InvalidParam);
+ assert_eq!(
+ unsafe { queue.add(&[], &[]) }.unwrap_err(),
+ Error::InvalidParam
+ );
}
#[test]
@@ -418,9 +455,7 @@ mod tests {
let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
assert_eq!(queue.available_desc(), 4);
assert_eq!(
- queue
- .add(&[&[], &[], &[]], &[&mut [], &mut []])
- .unwrap_err(),
+ unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
Error::BufferTooSmall
);
}
@@ -435,9 +470,7 @@ mod tests {
// Add a buffer chain consisting of two device-readable parts followed by two
// device-writable parts.
- let token = queue
- .add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]])
- .unwrap();
+ let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
assert_eq!(queue.available_desc(), 0);
assert!(!queue.can_pop());
diff --git a/src/transport/fake.rs b/src/transport/fake.rs
index 9095292..40105ec 100644
--- a/src/transport/fake.rs
+++ b/src/transport/fake.rs
@@ -1,23 +1,23 @@
use super::{DeviceStatus, Transport};
use crate::{
queue::{fake_write_to_queue, Descriptor},
- DeviceType, PhysAddr,
+ DeviceType, PhysAddr, Result,
};
use alloc::{sync::Arc, vec::Vec};
-use core::ptr::NonNull;
+use core::{any::TypeId, ptr::NonNull};
use std::sync::Mutex;
/// A fake implementation of [`Transport`] for unit tests.
#[derive(Debug)]
-pub struct FakeTransport {
+pub struct FakeTransport<C: 'static> {
pub device_type: DeviceType,
pub max_queue_size: u32,
pub device_features: u64,
- pub config_space: NonNull<u64>,
+ pub config_space: NonNull<C>,
pub state: Arc<Mutex<State>>,
}
-impl Transport for FakeTransport {
+impl<C> Transport for FakeTransport<C> {
fn device_type(&self) -> DeviceType {
self.device_type
}
@@ -34,7 +34,7 @@ impl Transport for FakeTransport {
self.max_queue_size
}
- fn notify(&mut self, queue: u32) {
+ fn notify(&mut self, queue: u16) {
self.state.lock().unwrap().queues[queue as usize].notified = true;
}
@@ -48,7 +48,7 @@ impl Transport for FakeTransport {
fn queue_set(
&mut self,
- queue: u32,
+ queue: u16,
size: u32,
descriptors: PhysAddr,
driver_area: PhysAddr,
@@ -61,7 +61,15 @@ impl Transport for FakeTransport {
state.queues[queue as usize].device_area = device_area;
}
- fn queue_used(&mut self, queue: u32) -> bool {
+ fn queue_unset(&mut self, queue: u16) {
+ let mut state = self.state.lock().unwrap();
+ state.queues[queue as usize].size = 0;
+ state.queues[queue as usize].descriptors = 0;
+ state.queues[queue as usize].driver_area = 0;
+ state.queues[queue as usize].device_area = 0;
+ }
+
+ fn queue_used(&mut self, queue: u16) -> bool {
self.state.lock().unwrap().queues[queue as usize].descriptors != 0
}
@@ -74,8 +82,12 @@ impl Transport for FakeTransport {
pending
}
- fn config_space(&self) -> NonNull<u64> {
- self.config_space
+ fn config_space<T: 'static>(&self) -> Result<NonNull<T>> {
+ if TypeId::of::<T>() == TypeId::of::<C>() {
+ Ok(self.config_space.cast())
+ } else {
+ panic!("Unexpected config space type.");
+ }
}
}
@@ -92,8 +104,8 @@ impl State {
/// Simulates the device writing to the given queue.
///
/// The fake device always uses descriptors in order.
- pub fn write_to_queue(&mut self, queue_size: u16, queue_index: usize, data: &[u8]) {
- let receive_queue = &self.queues[queue_index];
+ pub fn write_to_queue(&mut self, queue_size: u16, queue_index: u16, data: &[u8]) {
+ let receive_queue = &self.queues[queue_index as usize];
assert_ne!(receive_queue.descriptors, 0);
fake_write_to_queue(
queue_size,
diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs
index 9b6a49c..8321ec6 100644
--- a/src/transport/mmio.rs
+++ b/src/transport/mmio.rs
@@ -3,12 +3,12 @@ use crate::{
align_up,
queue::Descriptor,
volatile::{volread, volwrite, ReadOnly, Volatile, WriteOnly},
- PhysAddr, PAGE_SIZE,
+ Error, PhysAddr, PAGE_SIZE,
};
use core::{
convert::{TryFrom, TryInto},
fmt::{self, Display, Formatter},
- mem::size_of,
+ mem::{align_of, size_of},
ptr::NonNull,
};
@@ -311,10 +311,8 @@ impl MmioTransport {
impl Transport for MmioTransport {
fn device_type(&self) -> DeviceType {
// Safe because self.header points to a valid VirtIO MMIO region.
- match unsafe { volread!(self.header, device_id) } {
- x @ 1..=13 | x @ 16..=24 => unsafe { core::mem::transmute(x as u8) },
- _ => DeviceType::Invalid,
- }
+ let device_id = unsafe { volread!(self.header, device_id) };
+ device_id.into()
}
fn read_device_features(&mut self) -> u64 {
@@ -343,10 +341,10 @@ impl Transport for MmioTransport {
unsafe { volread!(self.header, queue_num_max) }
}
- fn notify(&mut self, queue: u32) {
+ fn notify(&mut self, queue: u16) {
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
- volwrite!(self.header, queue_notify, queue);
+ volwrite!(self.header, queue_notify, queue.into());
}
}
@@ -373,7 +371,7 @@ impl Transport for MmioTransport {
fn queue_set(
&mut self,
- queue: u32,
+ queue: u16,
size: u32,
descriptors: PhysAddr,
driver_area: PhysAddr,
@@ -397,7 +395,7 @@ impl Transport for MmioTransport {
assert_eq!(pfn as usize * PAGE_SIZE, descriptors);
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
- volwrite!(self.header, queue_sel, queue);
+ volwrite!(self.header, queue_sel, queue.into());
volwrite!(self.header, queue_num, size);
volwrite!(self.header, legacy_queue_align, align);
volwrite!(self.header, legacy_queue_pfn, pfn);
@@ -406,7 +404,7 @@ impl Transport for MmioTransport {
MmioVersion::Modern => {
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
- volwrite!(self.header, queue_sel, queue);
+ volwrite!(self.header, queue_sel, queue.into());
volwrite!(self.header, queue_num, size);
volwrite!(self.header, queue_desc_low, descriptors as u32);
volwrite!(self.header, queue_desc_high, (descriptors >> 32) as u32);
@@ -420,10 +418,38 @@ impl Transport for MmioTransport {
}
}
- fn queue_used(&mut self, queue: u32) -> bool {
+ fn queue_unset(&mut self, queue: u16) {
+ match self.version {
+ MmioVersion::Legacy => {
+ // Safe because self.header points to a valid VirtIO MMIO region.
+ unsafe {
+ volwrite!(self.header, queue_sel, queue.into());
+ volwrite!(self.header, queue_num, 0);
+ volwrite!(self.header, legacy_queue_align, 0);
+ volwrite!(self.header, legacy_queue_pfn, 0);
+ }
+ }
+ MmioVersion::Modern => {
+ // Safe because self.header points to a valid VirtIO MMIO region.
+ unsafe {
+ volwrite!(self.header, queue_sel, queue.into());
+ volwrite!(self.header, queue_ready, 0);
+ volwrite!(self.header, queue_num, 0);
+ volwrite!(self.header, queue_desc_low, 0);
+ volwrite!(self.header, queue_desc_high, 0);
+ volwrite!(self.header, queue_driver_low, 9);
+ volwrite!(self.header, queue_driver_high, 0);
+ volwrite!(self.header, queue_device_low, 0);
+ volwrite!(self.header, queue_device_high, 0);
+ }
+ }
+ }
+ }
+
+ fn queue_used(&mut self, queue: u16) -> bool {
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
- volwrite!(self.header, queue_sel, queue);
+ volwrite!(self.header, queue_sel, queue.into());
match self.version {
MmioVersion::Legacy => volread!(self.header, legacy_queue_pfn) != 0,
MmioVersion::Modern => volread!(self.header, queue_ready) != 0,
@@ -444,7 +470,14 @@ impl Transport for MmioTransport {
}
}
- fn config_space(&self) -> NonNull<u64> {
- NonNull::new((self.header.as_ptr() as usize + CONFIG_SPACE_OFFSET) as _).unwrap()
+ fn config_space<T>(&self) -> Result<NonNull<T>, Error> {
+ if align_of::<T>() > 4 {
+ // Panic as this should only happen if the driver is written incorrectly.
+ panic!(
+ "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
+ align_of::<T>()
+ );
+ }
+ Ok(NonNull::new((self.header.as_ptr() as usize + CONFIG_SPACE_OFFSET) as _).unwrap())
}
}
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index 64a564a..1156c09 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -1,8 +1,9 @@
#[cfg(test)]
pub mod fake;
pub mod mmio;
+pub mod pci;
-use crate::{PhysAddr, PAGE_SIZE};
+use crate::{PhysAddr, Result, PAGE_SIZE};
use bitflags::bitflags;
use core::ptr::NonNull;
@@ -21,7 +22,7 @@ pub trait Transport {
fn max_queue_size(&self) -> u32;
/// Notifies the given queue on the device.
- fn notify(&mut self, queue: u32);
+ fn notify(&mut self, queue: u16);
/// Sets the device status.
fn set_status(&mut self, status: DeviceStatus);
@@ -32,15 +33,18 @@ pub trait Transport {
/// Sets up the given queue.
fn queue_set(
&mut self,
- queue: u32,
+ queue: u16,
size: u32,
descriptors: PhysAddr,
driver_area: PhysAddr,
device_area: PhysAddr,
);
+ /// Disables and resets the given queue.
+ fn queue_unset(&mut self, queue: u16);
+
/// Returns whether the queue is in use, i.e. has a nonzero PFN or is marked as ready.
- fn queue_used(&mut self, queue: u32) -> bool;
+ fn queue_used(&mut self, queue: u16) -> bool;
/// Acknowledges an interrupt.
///
@@ -51,23 +55,29 @@ pub trait Transport {
///
/// Ref: virtio 3.1.1 Device Initialization
fn begin_init(&mut self, negotiate_features: impl FnOnce(u64) -> u64) {
- self.set_status(DeviceStatus::ACKNOWLEDGE);
- self.set_status(DeviceStatus::DRIVER);
+ self.set_status(DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER);
let features = self.read_device_features();
self.write_driver_features(negotiate_features(features));
- self.set_status(DeviceStatus::FEATURES_OK);
+ self.set_status(
+ DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER | DeviceStatus::FEATURES_OK,
+ );
self.set_guest_page_size(PAGE_SIZE as u32);
}
/// Finishes initializing the device.
fn finish_init(&mut self) {
- self.set_status(DeviceStatus::DRIVER_OK);
+ self.set_status(
+ DeviceStatus::ACKNOWLEDGE
+ | DeviceStatus::DRIVER
+ | DeviceStatus::FEATURES_OK
+ | DeviceStatus::DRIVER_OK,
+ );
}
/// Gets the pointer to the config space.
- fn config_space(&self) -> NonNull<u64>;
+ fn config_space<T: 'static>(&self) -> Result<NonNull<T>>;
}
bitflags! {
@@ -129,3 +139,45 @@ pub enum DeviceType {
IOMMU = 23,
Memory = 24,
}
+
+impl From<u32> for DeviceType {
+ fn from(virtio_device_id: u32) -> Self {
+ match virtio_device_id {
+ 1 => DeviceType::Network,
+ 2 => DeviceType::Block,
+ 3 => DeviceType::Console,
+ 4 => DeviceType::EntropySource,
+ 5 => DeviceType::MemoryBalloon,
+ 6 => DeviceType::IoMemory,
+ 7 => DeviceType::Rpmsg,
+ 8 => DeviceType::ScsiHost,
+ 9 => DeviceType::_9P,
+ 10 => DeviceType::Mac80211,
+ 11 => DeviceType::RprocSerial,
+ 12 => DeviceType::VirtioCAIF,
+ 13 => DeviceType::MemoryBalloon,
+ 16 => DeviceType::GPU,
+ 17 => DeviceType::Timer,
+ 18 => DeviceType::Input,
+ 19 => DeviceType::Socket,
+ 20 => DeviceType::Crypto,
+ 21 => DeviceType::SignalDistributionModule,
+ 22 => DeviceType::Pstore,
+ 23 => DeviceType::IOMMU,
+ 24 => DeviceType::Memory,
+ _ => DeviceType::Invalid,
+ }
+ }
+}
+
+impl From<u16> for DeviceType {
+ fn from(virtio_device_id: u16) -> Self {
+ u32::from(virtio_device_id).into()
+ }
+}
+
+impl From<u8> for DeviceType {
+ fn from(virtio_device_id: u8) -> Self {
+ u32::from(virtio_device_id).into()
+ }
+}
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
new file mode 100644
index 0000000..58584cb
--- /dev/null
+++ b/src/transport/pci.rs
@@ -0,0 +1,551 @@
+//! PCI transport for VirtIO.
+
+pub mod bus;
+
+use self::bus::{DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR};
+use super::{DeviceStatus, DeviceType, Transport};
+use crate::{
+ hal::{Hal, PhysAddr, VirtAddr},
+ volatile::{
+ volread, volwrite, ReadOnly, Volatile, VolatileReadable, VolatileWritable, WriteOnly,
+ },
+ Error,
+};
+use core::{
+ fmt::{self, Display, Formatter},
+ mem::{align_of, size_of},
+ ptr::{self, addr_of_mut, NonNull},
+};
+
+/// The PCI vendor ID for VirtIO devices.
+const VIRTIO_VENDOR_ID: u16 = 0x1af4;
+
+/// The offset to add to a VirtIO device ID to get the corresponding PCI device ID.
+const PCI_DEVICE_ID_OFFSET: u16 = 0x1040;
+
+const TRANSITIONAL_NETWORK: u16 = 0x1000;
+const TRANSITIONAL_BLOCK: u16 = 0x1001;
+const TRANSITIONAL_MEMORY_BALLOONING: u16 = 0x1002;
+const TRANSITIONAL_CONSOLE: u16 = 0x1003;
+const TRANSITIONAL_SCSI_HOST: u16 = 0x1004;
+const TRANSITIONAL_ENTROPY_SOURCE: u16 = 0x1005;
+const TRANSITIONAL_9P_TRANSPORT: u16 = 0x1009;
+
+/// The offset of the bar field within `virtio_pci_cap`.
+const CAP_BAR_OFFSET: u8 = 4;
+/// The offset of the offset field with `virtio_pci_cap`.
+const CAP_BAR_OFFSET_OFFSET: u8 = 8;
+/// The offset of the `length` field within `virtio_pci_cap`.
+const CAP_LENGTH_OFFSET: u8 = 12;
+/// The offset of the`notify_off_multiplier` field within `virtio_pci_notify_cap`.
+const CAP_NOTIFY_OFF_MULTIPLIER_OFFSET: u8 = 16;
+
+/// Common configuration.
+const VIRTIO_PCI_CAP_COMMON_CFG: u8 = 1;
+/// Notifications.
+const VIRTIO_PCI_CAP_NOTIFY_CFG: u8 = 2;
+/// ISR Status.
+const VIRTIO_PCI_CAP_ISR_CFG: u8 = 3;
+/// Device specific configuration.
+const VIRTIO_PCI_CAP_DEVICE_CFG: u8 = 4;
+
+fn device_type(pci_device_id: u16) -> DeviceType {
+ match pci_device_id {
+ TRANSITIONAL_NETWORK => DeviceType::Network,
+ TRANSITIONAL_BLOCK => DeviceType::Block,
+ TRANSITIONAL_MEMORY_BALLOONING => DeviceType::MemoryBalloon,
+ TRANSITIONAL_CONSOLE => DeviceType::Console,
+ TRANSITIONAL_SCSI_HOST => DeviceType::ScsiHost,
+ TRANSITIONAL_ENTROPY_SOURCE => DeviceType::EntropySource,
+ TRANSITIONAL_9P_TRANSPORT => DeviceType::_9P,
+ id if id >= PCI_DEVICE_ID_OFFSET => DeviceType::from(id - PCI_DEVICE_ID_OFFSET),
+ _ => DeviceType::Invalid,
+ }
+}
+
+/// Returns the type of VirtIO device to which the given PCI vendor and device ID corresponds, or
+/// `None` if it is not a recognised VirtIO device.
+pub fn virtio_device_type(device_function_info: &DeviceFunctionInfo) -> Option<DeviceType> {
+ if device_function_info.vendor_id == VIRTIO_VENDOR_ID {
+ let device_type = device_type(device_function_info.device_id);
+ if device_type != DeviceType::Invalid {
+ return Some(device_type);
+ }
+ }
+ None
+}
+
+/// PCI transport for VirtIO.
+///
+/// Ref: 4.1 Virtio Over PCI Bus
+#[derive(Debug)]
+pub struct PciTransport {
+ device_type: DeviceType,
+ /// The bus, device and function identifier for the VirtIO device.
+ device_function: DeviceFunction,
+ /// The common configuration structure within some BAR.
+ common_cfg: NonNull<CommonCfg>,
+ // TODO: Use a raw slice, once they are supported by our MSRV.
+ /// The start of the queue notification region within some BAR.
+ notify_region: NonNull<[WriteOnly<u16>]>,
+ notify_off_multiplier: u32,
+ /// The ISR status register within some BAR.
+ isr_status: NonNull<Volatile<u8>>,
+ /// The VirtIO device-specific configuration within some BAR.
+ config_space: Option<NonNull<[u32]>>,
+}
+
+impl PciTransport {
+ /// Construct a new PCI VirtIO device driver for the given device function on the given PCI
+ /// root controller.
+ ///
+ /// The PCI device must already have had its BARs allocated.
+ pub fn new<H: Hal>(
+ root: &mut PciRoot,
+ device_function: DeviceFunction,
+ ) -> Result<Self, VirtioPciError> {
+ let device_vendor = root.config_read_word(device_function, 0);
+ let device_id = (device_vendor >> 16) as u16;
+ let vendor_id = device_vendor as u16;
+ if vendor_id != VIRTIO_VENDOR_ID {
+ return Err(VirtioPciError::InvalidVendorId(vendor_id));
+ }
+ let device_type = device_type(device_id);
+
+ // Find the PCI capabilities we need.
+ let mut common_cfg = None;
+ let mut notify_cfg = None;
+ let mut notify_off_multiplier = 0;
+ let mut isr_cfg = None;
+ let mut device_cfg = None;
+ for capability in root.capabilities(device_function) {
+ if capability.id != PCI_CAP_ID_VNDR {
+ continue;
+ }
+ let cap_len = capability.private_header as u8;
+ let cfg_type = (capability.private_header >> 8) as u8;
+ if cap_len < 16 {
+ continue;
+ }
+ let struct_info = VirtioCapabilityInfo {
+ bar: root.config_read_word(device_function, capability.offset + CAP_BAR_OFFSET)
+ as u8,
+ offset: root
+ .config_read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
+ length: root
+ .config_read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
+ };
+
+ match cfg_type {
+ VIRTIO_PCI_CAP_COMMON_CFG if common_cfg.is_none() => {
+ common_cfg = Some(struct_info);
+ }
+ VIRTIO_PCI_CAP_NOTIFY_CFG if cap_len >= 20 && notify_cfg.is_none() => {
+ notify_cfg = Some(struct_info);
+ notify_off_multiplier = root.config_read_word(
+ device_function,
+ capability.offset + CAP_NOTIFY_OFF_MULTIPLIER_OFFSET,
+ );
+ }
+ VIRTIO_PCI_CAP_ISR_CFG if isr_cfg.is_none() => {
+ isr_cfg = Some(struct_info);
+ }
+ VIRTIO_PCI_CAP_DEVICE_CFG if device_cfg.is_none() => {
+ device_cfg = Some(struct_info);
+ }
+ _ => {}
+ }
+ }
+
+ let common_cfg = get_bar_region::<H, _>(
+ root,
+ device_function,
+ &common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
+ )?;
+
+ let notify_cfg = notify_cfg.ok_or(VirtioPciError::MissingNotifyConfig)?;
+ if notify_off_multiplier % 2 != 0 {
+ return Err(VirtioPciError::InvalidNotifyOffMultiplier(
+ notify_off_multiplier,
+ ));
+ }
+ let notify_region = get_bar_region_slice::<H, _>(root, device_function, &notify_cfg)?;
+
+ let isr_status = get_bar_region::<H, _>(
+ root,
+ device_function,
+ &isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
+ )?;
+
+ let config_space = if let Some(device_cfg) = device_cfg {
+ Some(get_bar_region_slice::<H, _>(
+ root,
+ device_function,
+ &device_cfg,
+ )?)
+ } else {
+ None
+ };
+
+ Ok(Self {
+ device_type,
+ device_function,
+ common_cfg,
+ notify_region,
+ notify_off_multiplier,
+ isr_status,
+ config_space,
+ })
+ }
+}
+
+impl Transport for PciTransport {
+ fn device_type(&self) -> DeviceType {
+ self.device_type
+ }
+
+ fn read_device_features(&mut self) -> u64 {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, device_feature_select, 0);
+ let mut device_features_bits = volread!(self.common_cfg, device_feature) as u64;
+ volwrite!(self.common_cfg, device_feature_select, 1);
+ device_features_bits |= (volread!(self.common_cfg, device_feature) as u64) << 32;
+ device_features_bits
+ }
+ }
+
+ fn write_driver_features(&mut self, driver_features: u64) {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, driver_feature_select, 0);
+ volwrite!(self.common_cfg, driver_feature, driver_features as u32);
+ volwrite!(self.common_cfg, driver_feature_select, 1);
+ volwrite!(
+ self.common_cfg,
+ driver_feature,
+ (driver_features >> 32) as u32
+ );
+ }
+ }
+
+ fn max_queue_size(&self) -> u32 {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe { volread!(self.common_cfg, queue_size) }.into()
+ }
+
+ fn notify(&mut self, queue: u16) {
+ // Safe because the common config and notify region pointers are valid and we checked in
+ // get_bar_region that they were aligned.
+ unsafe {
+ volwrite!(self.common_cfg, queue_select, queue);
+ // TODO: Consider caching this somewhere (per queue).
+ let queue_notify_off = volread!(self.common_cfg, queue_notify_off);
+
+ let offset_bytes = usize::from(queue_notify_off) * self.notify_off_multiplier as usize;
+ let index = offset_bytes / size_of::<u16>();
+ addr_of_mut!((*self.notify_region.as_ptr())[index]).vwrite(queue);
+ }
+ }
+
+ fn set_status(&mut self, status: DeviceStatus) {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, device_status, status.bits() as u8);
+ }
+ }
+
+ fn set_guest_page_size(&mut self, _guest_page_size: u32) {
+ // No-op, the PCI transport doesn't care.
+ }
+
+ fn queue_set(
+ &mut self,
+ queue: u16,
+ size: u32,
+ descriptors: PhysAddr,
+ driver_area: PhysAddr,
+ device_area: PhysAddr,
+ ) {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, queue_select, queue);
+ volwrite!(self.common_cfg, queue_size, size as u16);
+ volwrite!(self.common_cfg, queue_desc, descriptors as u64);
+ volwrite!(self.common_cfg, queue_driver, driver_area as u64);
+ volwrite!(self.common_cfg, queue_device, device_area as u64);
+ volwrite!(self.common_cfg, queue_enable, 1);
+ }
+ }
+
+ fn queue_unset(&mut self, queue: u16) {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, queue_enable, 0);
+ volwrite!(self.common_cfg, queue_select, queue);
+ volwrite!(self.common_cfg, queue_size, 0);
+ volwrite!(self.common_cfg, queue_desc, 0);
+ volwrite!(self.common_cfg, queue_driver, 0);
+ volwrite!(self.common_cfg, queue_device, 0);
+ }
+ }
+
+ fn queue_used(&mut self, queue: u16) -> bool {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ unsafe {
+ volwrite!(self.common_cfg, queue_select, queue);
+ volread!(self.common_cfg, queue_enable) == 1
+ }
+ }
+
+ fn ack_interrupt(&mut self) -> bool {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ // Reading the ISR status resets it to 0 and causes the device to de-assert the interrupt.
+ let isr_status = unsafe { self.isr_status.as_ptr().vread() };
+ // TODO: Distinguish between queue interrupt and device configuration interrupt.
+ isr_status & 0x3 != 0
+ }
+
+ fn config_space<T>(&self) -> Result<NonNull<T>, Error> {
+ if let Some(config_space) = self.config_space {
+ if size_of::<T>() > config_space.len() * size_of::<u32>() {
+ Err(Error::ConfigSpaceTooSmall)
+ } else if align_of::<T>() > 4 {
+ // Panic as this should only happen if the driver is written incorrectly.
+ panic!(
+ "Driver expected config space alignment of {} bytes, but VirtIO only guarantees 4 byte alignment.",
+ align_of::<T>()
+ );
+ } else {
+ // TODO: Use NonNull::as_non_null_ptr once it is stable.
+ let config_space_ptr = NonNull::new(config_space.as_ptr() as *mut u32).unwrap();
+ Ok(config_space_ptr.cast())
+ }
+ } else {
+ Err(Error::ConfigSpaceMissing)
+ }
+ }
+}
+
+/// `virtio_pci_common_cfg`, see 4.1.4.3 "Common configuration structure layout".
+#[repr(C)]
+struct CommonCfg {
+ device_feature_select: Volatile<u32>,
+ device_feature: ReadOnly<u32>,
+ driver_feature_select: Volatile<u32>,
+ driver_feature: Volatile<u32>,
+ msix_config: Volatile<u16>,
+ num_queues: ReadOnly<u16>,
+ device_status: Volatile<u8>,
+ config_generation: ReadOnly<u8>,
+ queue_select: Volatile<u16>,
+ queue_size: Volatile<u16>,
+ queue_msix_vector: Volatile<u16>,
+ queue_enable: Volatile<u16>,
+ queue_notify_off: Volatile<u16>,
+ queue_desc: Volatile<u64>,
+ queue_driver: Volatile<u64>,
+ queue_device: Volatile<u64>,
+}
+
+/// Information about a VirtIO structure within some BAR, as provided by a `virtio_pci_cap`.
+#[derive(Clone, Debug, Eq, PartialEq)]
+struct VirtioCapabilityInfo {
+ /// The bar in which the structure can be found.
+ bar: u8,
+ /// The offset within the bar.
+ offset: u32,
+ /// The length in bytes of the structure within the bar.
+ length: u32,
+}
+
+fn get_bar_region<H: Hal, T>(
+ root: &mut PciRoot,
+ device_function: DeviceFunction,
+ struct_info: &VirtioCapabilityInfo,
+) -> Result<NonNull<T>, VirtioPciError> {
+ let bar_info = root.bar_info(device_function, struct_info.bar)?;
+ let (bar_address, bar_size) = bar_info
+ .memory_address_size()
+ .ok_or(VirtioPciError::UnexpectedIoBar)?;
+ if bar_address == 0 {
+ return Err(VirtioPciError::BarNotAllocated(struct_info.bar));
+ }
+ if struct_info.offset + struct_info.length > bar_size
+ || size_of::<T>() > struct_info.length as usize
+ {
+ return Err(VirtioPciError::BarOffsetOutOfRange);
+ }
+ let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
+ let vaddr = H::phys_to_virt(paddr);
+ if vaddr % align_of::<T>() != 0 {
+ return Err(VirtioPciError::Misaligned {
+ vaddr,
+ alignment: align_of::<T>(),
+ });
+ }
+ Ok(NonNull::new(vaddr as _).unwrap())
+}
+
+fn get_bar_region_slice<H: Hal, T>(
+ root: &mut PciRoot,
+ device_function: DeviceFunction,
+ struct_info: &VirtioCapabilityInfo,
+) -> Result<NonNull<[T]>, VirtioPciError> {
+ let ptr = get_bar_region::<H, T>(root, device_function, struct_info)?;
+ let raw_slice =
+ ptr::slice_from_raw_parts_mut(ptr.as_ptr(), struct_info.length as usize / size_of::<T>());
+ Ok(NonNull::new(raw_slice).unwrap())
+}
+
+/// An error encountered initialising a VirtIO PCI transport.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum VirtioPciError {
+ /// PCI device vender ID was not the VirtIO vendor ID.
+ InvalidVendorId(u16),
+ /// No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found.
+ MissingCommonConfig,
+ /// No valid `VIRTIO_PCI_CAP_NOTIFY_CFG` capability was found.
+ MissingNotifyConfig,
+ /// `VIRTIO_PCI_CAP_NOTIFY_CFG` capability has a `notify_off_multiplier` that is not a multiple
+ /// of 2.
+ InvalidNotifyOffMultiplier(u32),
+ /// No valid `VIRTIO_PCI_CAP_ISR_CFG` capability was found.
+ MissingIsrConfig,
+ /// An IO BAR was provided rather than a memory BAR.
+ UnexpectedIoBar,
+ /// A BAR which we need was not allocated an address.
+ BarNotAllocated(u8),
+ /// The offset for some capability was greater than the length of the BAR.
+ BarOffsetOutOfRange,
+ /// The virtual address was not aligned as expected.
+ Misaligned {
+ /// The virtual address in question.
+ vaddr: VirtAddr,
+ /// The expected alignment in bytes.
+ alignment: usize,
+ },
+ /// A generic PCI error,
+ Pci(PciError),
+}
+
+impl Display for VirtioPciError {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match self {
+ Self::InvalidVendorId(vendor_id) => write!(
+ f,
+ "PCI device vender ID {:#06x} was not the VirtIO vendor ID {:#06x}.",
+ vendor_id, VIRTIO_VENDOR_ID
+ ),
+ Self::MissingCommonConfig => write!(
+ f,
+ "No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found."
+ ),
+ Self::MissingNotifyConfig => write!(
+ f,
+ "No valid `VIRTIO_PCI_CAP_NOTIFY_CFG` capability was found."
+ ),
+ Self::InvalidNotifyOffMultiplier(notify_off_multiplier) => {
+ write!(
+ f,
+ "`VIRTIO_PCI_CAP_NOTIFY_CFG` capability has a `notify_off_multiplier` that is not a multiple of 2: {}",
+ notify_off_multiplier
+ )
+ }
+ Self::MissingIsrConfig => {
+ write!(f, "No valid `VIRTIO_PCI_CAP_ISR_CFG` capability was found.")
+ }
+ Self::UnexpectedIoBar => write!(f, "Unexpected IO BAR (expected memory BAR)."),
+ Self::BarNotAllocated(bar_index) => write!(f, "Bar {} not allocated.", bar_index),
+ Self::BarOffsetOutOfRange => write!(f, "Capability offset greater than BAR length."),
+ Self::Misaligned { vaddr, alignment } => write!(
+ f,
+ "Virtual address {:#018x} was not aligned to a {} byte boundary as expected.",
+ vaddr, alignment
+ ),
+ Self::Pci(pci_error) => pci_error.fmt(f),
+ }
+ }
+}
+
+impl From<PciError> for VirtioPciError {
+ fn from(error: PciError) -> Self {
+ Self::Pci(error)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn transitional_device_ids() {
+ assert_eq!(device_type(0x1000), DeviceType::Network);
+ assert_eq!(device_type(0x1002), DeviceType::MemoryBalloon);
+ assert_eq!(device_type(0x1009), DeviceType::_9P);
+ }
+
+ #[test]
+ fn offset_device_ids() {
+ assert_eq!(device_type(0x1045), DeviceType::MemoryBalloon);
+ assert_eq!(device_type(0x1049), DeviceType::_9P);
+ assert_eq!(device_type(0x1058), DeviceType::Memory);
+ assert_eq!(device_type(0x1040), DeviceType::Invalid);
+ assert_eq!(device_type(0x1059), DeviceType::Invalid);
+ }
+
+ #[test]
+ fn virtio_device_type_valid() {
+ assert_eq!(
+ virtio_device_type(&DeviceFunctionInfo {
+ vendor_id: VIRTIO_VENDOR_ID,
+ device_id: TRANSITIONAL_BLOCK,
+ class: 0,
+ subclass: 0,
+ prog_if: 0,
+ revision: 0,
+ header_type: bus::HeaderType::Standard,
+ }),
+ Some(DeviceType::Block)
+ );
+ }
+
+ #[test]
+ fn virtio_device_type_invalid() {
+ // Non-VirtIO vendor ID.
+ assert_eq!(
+ virtio_device_type(&DeviceFunctionInfo {
+ vendor_id: 0x1234,
+ device_id: TRANSITIONAL_BLOCK,
+ class: 0,
+ subclass: 0,
+ prog_if: 0,
+ revision: 0,
+ header_type: bus::HeaderType::Standard,
+ }),
+ None
+ );
+
+ // Invalid device ID.
+ assert_eq!(
+ virtio_device_type(&DeviceFunctionInfo {
+ vendor_id: VIRTIO_VENDOR_ID,
+ device_id: 0x1040,
+ class: 0,
+ subclass: 0,
+ prog_if: 0,
+ revision: 0,
+ header_type: bus::HeaderType::Standard,
+ }),
+ None
+ );
+ }
+}
diff --git a/src/transport/pci/bus.rs b/src/transport/pci/bus.rs
new file mode 100644
index 0000000..a1abb23
--- /dev/null
+++ b/src/transport/pci/bus.rs
@@ -0,0 +1,599 @@
+//! Module for dealing with a PCI bus in general, without anything specific to VirtIO.
+
+use bitflags::bitflags;
+use core::{
+ convert::TryFrom,
+ fmt::{self, Display, Formatter},
+};
+use log::warn;
+
+const INVALID_READ: u32 = 0xffffffff;
+// PCI MMIO configuration region size.
+const AARCH64_PCI_CFG_SIZE: u32 = 0x1000000;
+// PCIe MMIO configuration region size.
+const AARCH64_PCIE_CFG_SIZE: u32 = 0x10000000;
+
+/// The maximum number of devices on a bus.
+const MAX_DEVICES: u8 = 32;
+/// The maximum number of functions on a device.
+const MAX_FUNCTIONS: u8 = 8;
+
+/// The offset in bytes to the status and command fields within PCI configuration space.
+const STATUS_COMMAND_OFFSET: u8 = 0x04;
+/// The offset in bytes to BAR0 within PCI configuration space.
+const BAR0_OFFSET: u8 = 0x10;
+
+/// ID for vendor-specific PCI capabilities.
+pub const PCI_CAP_ID_VNDR: u8 = 0x09;
+
+bitflags! {
+ /// The status register in PCI configuration space.
+ pub struct Status: u16 {
+ // Bits 0-2 are reserved.
+ /// The state of the device's INTx# signal.
+ const INTERRUPT_STATUS = 1 << 3;
+ /// The device has a linked list of capabilities.
+ const CAPABILITIES_LIST = 1 << 4;
+ /// The device is capabile of running at 66 MHz rather than 33 MHz.
+ const MHZ_66_CAPABLE = 1 << 5;
+ // Bit 6 is reserved.
+ /// The device can accept fast back-to-back transactions not from the same agent.
+ const FAST_BACK_TO_BACK_CAPABLE = 1 << 7;
+ /// The bus agent observed a parity error (if parity error handling is enabled).
+ const MASTER_DATA_PARITY_ERROR = 1 << 8;
+ // Bits 9-10 are DEVSEL timing.
+ /// A target device terminated a transaction with target-abort.
+ const SIGNALED_TARGET_ABORT = 1 << 11;
+ /// A master device transaction was terminated with target-abort.
+ const RECEIVED_TARGET_ABORT = 1 << 12;
+ /// A master device transaction was terminated with master-abort.
+ const RECEIVED_MASTER_ABORT = 1 << 13;
+ /// A device asserts SERR#.
+ const SIGNALED_SYSTEM_ERROR = 1 << 14;
+ /// The device detects a parity error, even if parity error handling is disabled.
+ const DETECTED_PARITY_ERROR = 1 << 15;
+ }
+}
+
+bitflags! {
+ /// The command register in PCI configuration space.
+ pub struct Command: u16 {
+ /// The device can respond to I/O Space accesses.
+ const IO_SPACE = 1 << 0;
+ /// The device can respond to Memory Space accesses.
+ const MEMORY_SPACE = 1 << 1;
+ /// The device can behave as a bus master.
+ const BUS_MASTER = 1 << 2;
+ /// The device can monitor Special Cycle operations.
+ const SPECIAL_CYCLES = 1 << 3;
+ /// The device can generate the Memory Write and Invalidate command.
+ const MEMORY_WRITE_AND_INVALIDATE_ENABLE = 1 << 4;
+ /// The device will snoop palette register data.
+ const VGA_PALETTE_SNOOP = 1 << 5;
+ /// The device should take its normal action when a parity error is detected.
+ const PARITY_ERROR_RESPONSE = 1 << 6;
+ // Bit 7 is reserved.
+ /// The SERR# driver is enabled.
+ const SERR_ENABLE = 1 << 8;
+ /// The device is allowed to generate fast back-to-back transactions.
+ const FAST_BACK_TO_BACK_ENABLE = 1 << 9;
+ /// Assertion of the device's INTx# signal is disabled.
+ const INTERRUPT_DISABLE = 1 << 10;
+ }
+}
+
+/// Errors accessing a PCI device.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum PciError {
+ /// The device reported an invalid BAR type.
+ InvalidBarType,
+}
+
+impl Display for PciError {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match self {
+ Self::InvalidBarType => write!(f, "Invalid PCI BAR type."),
+ }
+ }
+}
+
+/// The root complex of a PCI bus.
+#[derive(Debug)]
+pub struct PciRoot {
+ mmio_base: *mut u32,
+ cam: Cam,
+}
+
+/// A PCI Configuration Access Mechanism.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum Cam {
+ /// The PCI memory-mapped Configuration Access Mechanism.
+ ///
+ /// This provides access to 256 bytes of configuration space per device function.
+ MmioCam,
+ /// The PCIe memory-mapped Enhanced Configuration Access Mechanism.
+ ///
+ /// This provides access to 4 KiB of configuration space per device function.
+ Ecam,
+}
+
+impl PciRoot {
+ /// Wraps the PCI root complex with the given MMIO base address.
+ ///
+ /// Panics if the base address is not aligned to a 4-byte boundary.
+ ///
+ /// # Safety
+ ///
+ /// `mmio_base` must be a valid pointer to an appropriately-mapped MMIO region of at least
+ /// 16 MiB (if `cam == Cam::MmioCam`) or 256 MiB (if `cam == Cam::Ecam`). The pointer must be
+ /// valid for the entire lifetime of the program (i.e. `'static`), which implies that no Rust
+ /// references may be used to access any of the memory region at any point.
+ pub unsafe fn new(mmio_base: *mut u8, cam: Cam) -> Self {
+ assert!(mmio_base as usize & 0x3 == 0);
+ Self {
+ mmio_base: mmio_base as *mut u32,
+ cam,
+ }
+ }
+
+ /// Makes a clone of the `PciRoot`, pointing at the same MMIO region.
+ ///
+ /// # Safety
+ ///
+ /// This function allows concurrent mutable access to the PCI CAM. To avoid this causing
+ /// problems, the returned `PciRoot` instance must only be used to read read-only fields.
+ unsafe fn unsafe_clone(&self) -> Self {
+ Self {
+ mmio_base: self.mmio_base,
+ cam: self.cam,
+ }
+ }
+
+ fn cam_offset(&self, device_function: DeviceFunction, register_offset: u8) -> u32 {
+ assert!(device_function.valid());
+
+ let bdf = (device_function.bus as u32) << 8
+ | (device_function.device as u32) << 3
+ | device_function.function as u32;
+ let address;
+ match self.cam {
+ Cam::MmioCam => {
+ address = bdf << 8 | register_offset as u32;
+ // Ensure that address is within range.
+ assert!(address < AARCH64_PCI_CFG_SIZE);
+ }
+ Cam::Ecam => {
+ address = bdf << 12 | register_offset as u32;
+ // Ensure that address is within range.
+ assert!(address < AARCH64_PCIE_CFG_SIZE);
+ }
+ }
+ // Ensure that address is word-aligned.
+ assert!(address & 0x3 == 0);
+ address
+ }
+
+ /// Reads 4 bytes from configuration space using the appropriate CAM.
+ pub(crate) fn config_read_word(
+ &self,
+ device_function: DeviceFunction,
+ register_offset: u8,
+ ) -> u32 {
+ let address = self.cam_offset(device_function, register_offset);
+ // Safe because both the `mmio_base` and the address offset are properly aligned, and the
+ // resulting pointer is within the MMIO range of the CAM.
+ unsafe {
+ // Right shift to convert from byte offset to word offset.
+ (self.mmio_base.add((address >> 2) as usize)).read_volatile()
+ }
+ }
+
+ /// Writes 4 bytes to configuration space using the appropriate CAM.
+ pub(crate) fn config_write_word(
+ &mut self,
+ device_function: DeviceFunction,
+ register_offset: u8,
+ data: u32,
+ ) {
+ let address = self.cam_offset(device_function, register_offset);
+ // Safe because both the `mmio_base` and the address offset are properly aligned, and the
+ // resulting pointer is within the MMIO range of the CAM.
+ unsafe {
+ // Right shift to convert from byte offset to word offset.
+ (self.mmio_base.add((address >> 2) as usize)).write_volatile(data)
+ }
+ }
+
+ /// Enumerates PCI devices on the given bus.
+ pub fn enumerate_bus(&self, bus: u8) -> BusDeviceIterator {
+ // Safe because the BusDeviceIterator only reads read-only fields.
+ let root = unsafe { self.unsafe_clone() };
+ BusDeviceIterator {
+ root,
+ next: DeviceFunction {
+ bus,
+ device: 0,
+ function: 0,
+ },
+ }
+ }
+
+ /// Reads the status and command registers of the given device function.
+ pub fn get_status_command(&self, device_function: DeviceFunction) -> (Status, Command) {
+ let status_command = self.config_read_word(device_function, STATUS_COMMAND_OFFSET);
+ let status = Status::from_bits_truncate((status_command >> 16) as u16);
+ let command = Command::from_bits_truncate(status_command as u16);
+ (status, command)
+ }
+
+ /// Sets the command register of the given device function.
+ pub fn set_command(&mut self, device_function: DeviceFunction, command: Command) {
+ self.config_write_word(
+ device_function,
+ STATUS_COMMAND_OFFSET,
+ command.bits().into(),
+ );
+ }
+
+ /// Gets an iterator over the capabilities of the given device function.
+ pub fn capabilities(&self, device_function: DeviceFunction) -> CapabilityIterator {
+ CapabilityIterator {
+ root: self,
+ device_function,
+ next_capability_offset: self.capabilities_offset(device_function),
+ }
+ }
+
+ /// Gets information about the given BAR of the given device function.
+ pub fn bar_info(
+ &mut self,
+ device_function: DeviceFunction,
+ bar_index: u8,
+ ) -> Result<BarInfo, PciError> {
+ let bar_orig = self.config_read_word(device_function, BAR0_OFFSET + 4 * bar_index);
+
+ // Get the size of the BAR.
+ self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, 0xffffffff);
+ let size_mask = self.config_read_word(device_function, BAR0_OFFSET + 4 * bar_index);
+ let size = !(size_mask & 0xfffffff0) + 1;
+
+ // Restore the original value.
+ self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, bar_orig);
+
+ if bar_orig & 0x00000001 == 0x00000001 {
+ // I/O space
+ let address = bar_orig & 0xfffffffc;
+ Ok(BarInfo::IO { address, size })
+ } else {
+ // Memory space
+ let mut address = u64::from(bar_orig & 0xfffffff0);
+ let prefetchable = bar_orig & 0x00000008 != 0;
+ let address_type = MemoryBarType::try_from(((bar_orig & 0x00000006) >> 1) as u8)?;
+ if address_type == MemoryBarType::Width64 {
+ if bar_index >= 5 {
+ return Err(PciError::InvalidBarType);
+ }
+ let address_top =
+ self.config_read_word(device_function, BAR0_OFFSET + 4 * (bar_index + 1));
+ address |= u64::from(address_top) << 32;
+ }
+ Ok(BarInfo::Memory {
+ address_type,
+ prefetchable,
+ address,
+ size,
+ })
+ }
+ }
+
+ /// Sets the address of the given 32-bit memory or I/O BAR of the given device function.
+ pub fn set_bar_32(&mut self, device_function: DeviceFunction, bar_index: u8, address: u32) {
+ self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, address);
+ }
+
+ /// Sets the address of the given 64-bit memory BAR of the given device function.
+ pub fn set_bar_64(&mut self, device_function: DeviceFunction, bar_index: u8, address: u64) {
+ self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, address as u32);
+ self.config_write_word(
+ device_function,
+ BAR0_OFFSET + 4 * (bar_index + 1),
+ (address >> 32) as u32,
+ );
+ }
+
+ /// Gets the capabilities 'pointer' for the device function, if any.
+ fn capabilities_offset(&self, device_function: DeviceFunction) -> Option<u8> {
+ let (status, _) = self.get_status_command(device_function);
+ if status.contains(Status::CAPABILITIES_LIST) {
+ Some((self.config_read_word(device_function, 0x34) & 0xFC) as u8)
+ } else {
+ None
+ }
+ }
+}
+
+/// Information about a PCI Base Address Register.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum BarInfo {
+ /// The BAR is for a memory region.
+ Memory {
+ /// The size of the BAR address and where it can be located.
+ address_type: MemoryBarType,
+ /// If true, then reading from the region doesn't have side effects. The CPU may cache reads
+ /// and merge repeated stores.
+ prefetchable: bool,
+ /// The memory address, always 16-byte aligned.
+ address: u64,
+ /// The size of the BAR in bytes.
+ size: u32,
+ },
+ /// The BAR is for an I/O region.
+ IO {
+ /// The I/O address, always 4-byte aligned.
+ address: u32,
+ /// The size of the BAR in bytes.
+ size: u32,
+ },
+}
+
+impl BarInfo {
+ /// Returns whether this BAR is a 64-bit memory region, and so takes two entries in the table in
+ /// configuration space.
+ pub fn takes_two_entries(&self) -> bool {
+ matches!(
+ self,
+ BarInfo::Memory {
+ address_type: MemoryBarType::Width64,
+ ..
+ }
+ )
+ }
+
+ /// Returns the address and size of this BAR if it is a memory bar, or `None` if it is an IO
+ /// BAR.
+ pub fn memory_address_size(&self) -> Option<(u64, u32)> {
+ if let Self::Memory { address, size, .. } = self {
+ Some((*address, *size))
+ } else {
+ None
+ }
+ }
+}
+
+impl Display for BarInfo {
+ fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+ match self {
+ Self::Memory {
+ address_type,
+ prefetchable,
+ address,
+ size,
+ } => write!(
+ f,
+ "Memory space at {:#010x}, size {}, type {:?}, prefetchable {}",
+ address, size, address_type, prefetchable
+ ),
+ Self::IO { address, size } => {
+ write!(f, "I/O space at {:#010x}, size {}", address, size)
+ }
+ }
+ }
+}
+
+/// The location allowed for a memory BAR.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum MemoryBarType {
+ /// The BAR has a 32-bit address and can be mapped anywhere in 32-bit address space.
+ Width32,
+ /// The BAR must be mapped below 1MiB.
+ Below1MiB,
+ /// The BAR has a 64-bit address and can be mapped anywhere in 64-bit address space.
+ Width64,
+}
+
+impl From<MemoryBarType> for u8 {
+ fn from(bar_type: MemoryBarType) -> Self {
+ match bar_type {
+ MemoryBarType::Width32 => 0,
+ MemoryBarType::Below1MiB => 1,
+ MemoryBarType::Width64 => 2,
+ }
+ }
+}
+
+impl TryFrom<u8> for MemoryBarType {
+ type Error = PciError;
+
+ fn try_from(value: u8) -> Result<Self, Self::Error> {
+ match value {
+ 0 => Ok(Self::Width32),
+ 1 => Ok(Self::Below1MiB),
+ 2 => Ok(Self::Width64),
+ _ => Err(PciError::InvalidBarType),
+ }
+ }
+}
+
+/// Iterator over capabilities for a device.
+#[derive(Debug)]
+pub struct CapabilityIterator<'a> {
+ root: &'a PciRoot,
+ device_function: DeviceFunction,
+ next_capability_offset: Option<u8>,
+}
+
+impl<'a> Iterator for CapabilityIterator<'a> {
+ type Item = CapabilityInfo;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ let offset = self.next_capability_offset?;
+
+ // Read the first 4 bytes of the capability.
+ let capability_header = self.root.config_read_word(self.device_function, offset);
+ let id = capability_header as u8;
+ let next_offset = (capability_header >> 8) as u8;
+ let private_header = (capability_header >> 16) as u16;
+
+ self.next_capability_offset = if next_offset == 0 {
+ None
+ } else if next_offset < 64 || next_offset & 0x3 != 0 {
+ warn!("Invalid next capability offset {:#04x}", next_offset);
+ None
+ } else {
+ Some(next_offset)
+ };
+
+ Some(CapabilityInfo {
+ offset,
+ id,
+ private_header,
+ })
+ }
+}
+
+/// Information about a PCI device capability.
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub struct CapabilityInfo {
+ /// The offset of the capability in the PCI configuration space of the device function.
+ pub offset: u8,
+ /// The ID of the capability.
+ pub id: u8,
+ /// The third and fourth bytes of the capability, to save reading them again.
+ pub private_header: u16,
+}
+
+/// An iterator which enumerates PCI devices and functions on a given bus.
+#[derive(Debug)]
+pub struct BusDeviceIterator {
+ /// This must only be used to read read-only fields, and must not be exposed outside this
+ /// module, because it uses the same CAM as the main `PciRoot` instance.
+ root: PciRoot,
+ next: DeviceFunction,
+}
+
+impl Iterator for BusDeviceIterator {
+ type Item = (DeviceFunction, DeviceFunctionInfo);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ while self.next.device < MAX_DEVICES {
+ // Read the header for the current device and function.
+ let current = self.next;
+ let device_vendor = self.root.config_read_word(current, 0);
+
+ // Advance to the next device or function.
+ self.next.function += 1;
+ if self.next.function >= MAX_FUNCTIONS {
+ self.next.function = 0;
+ self.next.device += 1;
+ }
+
+ if device_vendor != INVALID_READ {
+ let class_revision = self.root.config_read_word(current, 8);
+ let device_id = (device_vendor >> 16) as u16;
+ let vendor_id = device_vendor as u16;
+ let class = (class_revision >> 24) as u8;
+ let subclass = (class_revision >> 16) as u8;
+ let prog_if = (class_revision >> 8) as u8;
+ let revision = class_revision as u8;
+ let bist_type_latency_cache = self.root.config_read_word(current, 12);
+ let header_type = HeaderType::from((bist_type_latency_cache >> 16) as u8 & 0x7f);
+ return Some((
+ current,
+ DeviceFunctionInfo {
+ vendor_id,
+ device_id,
+ class,
+ subclass,
+ prog_if,
+ revision,
+ header_type,
+ },
+ ));
+ }
+ }
+ None
+ }
+}
+
+/// An identifier for a PCI bus, device and function.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub struct DeviceFunction {
+ /// The PCI bus number, between 0 and 255.
+ pub bus: u8,
+ /// The device number on the bus, between 0 and 31.
+ pub device: u8,
+ /// The function number of the device, between 0 and 7.
+ pub function: u8,
+}
+
+impl DeviceFunction {
+ /// Returns whether the device and function numbers are valid, i.e. the device is between 0 and
+ /// 31, and the function is between 0 and 7.
+ pub fn valid(&self) -> bool {
+ self.device < 32 && self.function < 8
+ }
+}
+
+impl Display for DeviceFunction {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{:02x}:{:02x}.{}", self.bus, self.device, self.function)
+ }
+}
+
+/// Information about a PCI device function.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct DeviceFunctionInfo {
+ /// The PCI vendor ID.
+ pub vendor_id: u16,
+ /// The PCI device ID.
+ pub device_id: u16,
+ /// The PCI class.
+ pub class: u8,
+ /// The PCI subclass.
+ pub subclass: u8,
+ /// The PCI programming interface byte.
+ pub prog_if: u8,
+ /// The PCI revision ID.
+ pub revision: u8,
+ /// The type of PCI device.
+ pub header_type: HeaderType,
+}
+
+impl Display for DeviceFunctionInfo {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(
+ f,
+ "{:04x}:{:04x} (class {:02x}.{:02x}, rev {:02x}) {:?}",
+ self.vendor_id,
+ self.device_id,
+ self.class,
+ self.subclass,
+ self.revision,
+ self.header_type,
+ )
+ }
+}
+
+/// The type of a PCI device function header.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum HeaderType {
+ /// A normal PCI device.
+ Standard,
+ /// A PCI to PCI bridge.
+ PciPciBridge,
+ /// A PCI to CardBus bridge.
+ PciCardbusBridge,
+ /// Unrecognised header type.
+ Unrecognised(u8),
+}
+
+impl From<u8> for HeaderType {
+ fn from(value: u8) -> Self {
+ match value {
+ 0x00 => Self::Standard,
+ 0x01 => Self::PciPciBridge,
+ 0x02 => Self::PciCardbusBridge,
+ _ => Self::Unrecognised(value),
+ }
+ }
+}