aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: David Brazdil <dbrazdil@google.com> 2023-05-11 17:14:49 +0000
committer: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> 2023-05-11 17:14:49 +0000
commit: 1cc8ff475bc8ab6edcf501a03034cfe399c034a2 (patch)
tree: dcd29542fb0b356ead9c857c54aac66d77a8c530
parent: 9e41804cb697578cc4b60302e3190a86dd8abafd (diff)
parent: be908da35b79969eb11a853cb2090f26346cd821 (diff)
download: virtio-drivers-android14-d2-s2-release.tar.gz
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/rust/crates/virtio-drivers/+/23121941 Change-Id: I32e6e4e9831e5f093800a1d2b7c3c5876fa6f717 Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
-rw-r--r--.cargo_vcs_info.json2
-rw-r--r--.github/workflows/main.yml2
-rw-r--r--Android.bp4
-rw-r--r--Cargo.toml2
-rw-r--r--Cargo.toml.orig2
-rw-r--r--METADATA8
-rw-r--r--README.md1
-rw-r--r--src/device/blk.rs12
-rw-r--r--src/device/common.rs23
-rw-r--r--src/device/console.rs17
-rw-r--r--src/device/gpu.rs131
-rw-r--r--src/device/input.rs40
-rw-r--r--src/device/mod.rs4
-rw-r--r--src/device/net.rs226
-rw-r--r--src/device/socket/error.rs69
-rw-r--r--src/device/socket/mod.rs8
-rw-r--r--src/device/socket/protocol.rs184
-rw-r--r--src/device/socket/vsock.rs596
-rw-r--r--src/hal.rs58
-rw-r--r--src/hal/fake.rs12
-rw-r--r--src/lib.rs9
-rw-r--r--src/queue.rs148
-rw-r--r--src/transport/fake.rs4
-rw-r--r--src/transport/mmio.rs9
-rw-r--r--src/transport/mod.rs3
-rw-r--r--src/transport/pci.rs27
26 files changed, 1386 insertions, 215 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
index dc95465..d2f328c 100644
--- a/.cargo_vcs_info.json
+++ b/.cargo_vcs_info.json
@@ -1,6 +1,6 @@
{
"git": {
- "sha1": "70f3f76626420e854f1d7cd1dbc8060c27d848cf"
+ "sha1": "8e52adace55c5e082ba2effffcb70bf480d76ec0"
},
"path_in_vcs": ""
} \ No newline at end of file
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index b2ab149..025b7e9 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -78,7 +78,7 @@ jobs:
steps:
- uses: actions/checkout@v2
- name: Install QEMU
- run: sudo apt update && sudo apt install ${{ matrix.packages }}
+ run: sudo apt update && sudo apt install ${{ matrix.packages }} && sudo chmod 666 /dev/vhost-vsock
- uses: actions-rs/toolchain@v1
with:
profile: minimal
diff --git a/Android.bp b/Android.bp
index b8ee9ba..4e47048 100644
--- a/Android.bp
+++ b/Android.bp
@@ -24,7 +24,7 @@ rust_library {
name: "libvirtio_drivers",
crate_name: "virtio_drivers",
cargo_env_compat: true,
- cargo_pkg_version: "0.3.0",
+ cargo_pkg_version: "0.4.0",
srcs: ["src/lib.rs"],
edition: "2018",
no_stdlibs: true,
@@ -43,7 +43,7 @@ rust_test {
name: "virtio-drivers_test_src_lib",
crate_name: "virtio_drivers",
cargo_env_compat: true,
- cargo_pkg_version: "0.3.0",
+ cargo_pkg_version: "0.4.0",
srcs: ["src/lib.rs"],
test_suites: ["general-tests"],
auto_gen_config: true,
diff --git a/Cargo.toml b/Cargo.toml
index 17673fb..7f9968a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@
[package]
edition = "2018"
name = "virtio-drivers"
-version = "0.3.0"
+version = "0.4.0"
authors = [
"Jiajie Chen <noc@jiegec.ac.cn>",
"Runji Wang <wangrunji0408@163.com>",
diff --git a/Cargo.toml.orig b/Cargo.toml.orig
index 01bb35c..431ccd2 100644
--- a/Cargo.toml.orig
+++ b/Cargo.toml.orig
@@ -1,6 +1,6 @@
[package]
name = "virtio-drivers"
-version = "0.3.0"
+version = "0.4.0"
license = "MIT"
authors = [
"Jiajie Chen <noc@jiegec.ac.cn>",
diff --git a/METADATA b/METADATA
index 3623029..b865aa4 100644
--- a/METADATA
+++ b/METADATA
@@ -11,13 +11,13 @@ third_party {
}
url {
type: ARCHIVE
- value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.3.0.crate"
+ value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.4.0.crate"
}
- version: "0.3.0"
+ version: "0.4.0"
license_type: NOTICE
last_upgrade_date {
year: 2023
- month: 1
- day: 24
+ month: 4
+ day: 19
}
}
diff --git a/README.md b/README.md
index fb1cda5..fdb61d8 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ VirtIO guest drivers in Rust. For **no_std** environment.
| GPU | ✅ |
| Input | ✅ |
| Console | ✅ |
+| Socket | ✅ |
| ... | ❌ |
### Transports
diff --git a/src/device/blk.rs b/src/device/blk.rs
index 69528b6..d095047 100644
--- a/src/device/blk.rs
+++ b/src/device/blk.rs
@@ -109,7 +109,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
let mut resp = BlkResp::default();
self.queue.add_notify_wait_pop(
&[req.as_bytes()],
- &[buf, resp.as_bytes_mut()],
+ &mut [buf, resp.as_bytes_mut()],
&mut self.transport,
)?;
resp.status.into()
@@ -187,7 +187,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
};
let token = self
.queue
- .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+ .add(&[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
if self.queue.should_notify() {
self.transport.notify(QUEUE);
}
@@ -208,7 +208,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
resp: &mut BlkResp,
) -> Result<()> {
self.queue
- .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+ .pop_used(token, &[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
resp.status.into()
}
@@ -225,7 +225,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
let mut resp = BlkResp::default();
self.queue.add_notify_wait_pop(
&[req.as_bytes(), buf],
- &[resp.as_bytes_mut()],
+ &mut [resp.as_bytes_mut()],
&mut self.transport,
)?;
resp.status.into()
@@ -268,7 +268,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
};
let token = self
.queue
- .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+ .add(&[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
if self.queue.should_notify() {
self.transport.notify(QUEUE);
}
@@ -289,7 +289,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
resp: &mut BlkResp,
) -> Result<()> {
self.queue
- .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+ .pop_used(token, &[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
resp.status.into()
}
diff --git a/src/device/common.rs b/src/device/common.rs
new file mode 100644
index 0000000..2c8be3e
--- /dev/null
+++ b/src/device/common.rs
@@ -0,0 +1,23 @@
+//! Common part shared across all the devices.
+
+use bitflags::bitflags;
+
+bitflags! {
+ pub(crate) struct Feature: u64 {
+ // device independent
+ const NOTIFY_ON_EMPTY = 1 << 24; // legacy
+ const ANY_LAYOUT = 1 << 27; // legacy
+ const RING_INDIRECT_DESC = 1 << 28;
+ const RING_EVENT_IDX = 1 << 29;
+ const UNUSED = 1 << 30; // legacy
+ const VERSION_1 = 1 << 32; // detect legacy
+
+ // since virtio v1.1
+ const ACCESS_PLATFORM = 1 << 33;
+ const RING_PACKED = 1 << 34;
+ const IN_ORDER = 1 << 35;
+ const ORDER_PLATFORM = 1 << 36;
+ const SR_IOV = 1 << 37;
+ const NOTIFICATION_DATA = 1 << 38;
+ }
+}
diff --git a/src/device/console.rs b/src/device/console.rs
index 749ebc1..e0b0356 100644
--- a/src/device/console.rs
+++ b/src/device/console.rs
@@ -118,7 +118,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
if self.receive_token.is_none() && self.cursor == self.pending_len {
// Safe because the buffer lasts at least as long as the queue, and there are no other
// outstanding requests using the buffer.
- self.receive_token = Some(unsafe { self.receiveq.add(&[], &[self.queue_buf_rx]) }?);
+ self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?);
if self.receiveq.should_notify() {
self.transport.notify(QUEUE_RECEIVEQ_PORT_0);
}
@@ -145,13 +145,19 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
let mut flag = false;
if let Some(receive_token) = self.receive_token {
if self.receive_token == self.receiveq.peek_used() {
- let len = self
- .receiveq
- .pop_used(receive_token, &[], &[self.queue_buf_rx])?;
+ // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in
+ // `poll_retrieve` and it is still valid.
+ let len = unsafe {
+ self.receiveq
+ .pop_used(receive_token, &[], &mut [self.queue_buf_rx])?
+ };
flag = true;
assert_ne!(len, 0);
self.cursor = 0;
self.pending_len = len as usize;
+ // Clear `receive_token` so that when the buffer is used up the next call to
+ // `poll_retrieve` will add a new pending request.
+ self.receive_token.take();
}
}
Ok(flag)
@@ -176,9 +182,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
/// Sends a character to the console.
pub fn send(&mut self, chr: u8) -> Result<()> {
let buf: [u8; 1] = [chr];
- // Safe because the buffer is valid until we pop_used below.
self.transmitq
- .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
+ .add_notify_wait_pop(&[&buf], &mut [], &mut self.transport)?;
Ok(())
}
}
diff --git a/src/device/gpu.rs b/src/device/gpu.rs
index eabf2d4..b1b53bd 100644
--- a/src/device/gpu.rs
+++ b/src/device/gpu.rs
@@ -7,6 +7,7 @@ use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
use crate::{pages, Error, Result};
use bitflags::bitflags;
use log::info;
+use zerocopy::{AsBytes, FromBytes};
const QUEUE_SIZE: u16 = 2;
@@ -173,86 +174,86 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
}
/// Send a request to the device and block for a response.
- fn request<Req, Rsp>(&mut self, req: Req) -> Result<Rsp> {
- unsafe {
- (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
- }
+ fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> {
+ req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
self.control_queue.add_notify_wait_pop(
&[self.queue_buf_send],
- &[self.queue_buf_recv],
+ &mut [self.queue_buf_recv],
&mut self.transport,
)?;
- Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
+ Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())
}
/// Send a mouse cursor operation request to the device and block for a response.
- fn cursor_request<Req>(&mut self, req: Req) -> Result {
- unsafe {
- (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
- }
- self.cursor_queue
- .add_notify_wait_pop(&[self.queue_buf_send], &[], &mut self.transport)?;
+ fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result {
+ req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
+ self.cursor_queue.add_notify_wait_pop(
+ &[self.queue_buf_send],
+ &mut [],
+ &mut self.transport,
+ )?;
Ok(())
}
fn get_display_info(&mut self) -> Result<RespDisplayInfo> {
- let info: RespDisplayInfo = self.request(CtrlHeader::with_type(Command::GetDisplayInfo))?;
- info.header.check_type(Command::OkDisplayInfo)?;
+ let info: RespDisplayInfo =
+ self.request(CtrlHeader::with_type(Command::GET_DISPLAY_INFO))?;
+ info.header.check_type(Command::OK_DISPLAY_INFO)?;
Ok(info)
}
fn resource_create_2d(&mut self, resource_id: u32, width: u32, height: u32) -> Result {
let rsp: CtrlHeader = self.request(ResourceCreate2D {
- header: CtrlHeader::with_type(Command::ResourceCreate2d),
+ header: CtrlHeader::with_type(Command::RESOURCE_CREATE_2D),
resource_id,
format: Format::B8G8R8A8UNORM,
width,
height,
})?;
- rsp.check_type(Command::OkNodata)
+ rsp.check_type(Command::OK_NODATA)
}
fn set_scanout(&mut self, rect: Rect, scanout_id: u32, resource_id: u32) -> Result {
let rsp: CtrlHeader = self.request(SetScanout {
- header: CtrlHeader::with_type(Command::SetScanout),
+ header: CtrlHeader::with_type(Command::SET_SCANOUT),
rect,
scanout_id,
resource_id,
})?;
- rsp.check_type(Command::OkNodata)
+ rsp.check_type(Command::OK_NODATA)
}
fn resource_flush(&mut self, rect: Rect, resource_id: u32) -> Result {
let rsp: CtrlHeader = self.request(ResourceFlush {
- header: CtrlHeader::with_type(Command::ResourceFlush),
+ header: CtrlHeader::with_type(Command::RESOURCE_FLUSH),
rect,
resource_id,
_padding: 0,
})?;
- rsp.check_type(Command::OkNodata)
+ rsp.check_type(Command::OK_NODATA)
}
fn transfer_to_host_2d(&mut self, rect: Rect, offset: u64, resource_id: u32) -> Result {
let rsp: CtrlHeader = self.request(TransferToHost2D {
- header: CtrlHeader::with_type(Command::TransferToHost2d),
+ header: CtrlHeader::with_type(Command::TRANSFER_TO_HOST_2D),
rect,
offset,
resource_id,
_padding: 0,
})?;
- rsp.check_type(Command::OkNodata)
+ rsp.check_type(Command::OK_NODATA)
}
fn resource_attach_backing(&mut self, resource_id: u32, paddr: u64, length: u32) -> Result {
let rsp: CtrlHeader = self.request(ResourceAttachBacking {
- header: CtrlHeader::with_type(Command::ResourceAttachBacking),
+ header: CtrlHeader::with_type(Command::RESOURCE_ATTACH_BACKING),
resource_id,
nr_entries: 1,
addr: paddr,
length,
_padding: 0,
})?;
- rsp.check_type(Command::OkNodata)
+ rsp.check_type(Command::OK_NODATA)
}
fn update_cursor(
@@ -267,9 +268,9 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
) -> Result {
self.cursor_request(UpdateCursor {
header: if is_move {
- CtrlHeader::with_type(Command::MoveCursor)
+ CtrlHeader::with_type(Command::MOVE_CURSOR)
} else {
- CtrlHeader::with_type(Command::UpdateCursor)
+ CtrlHeader::with_type(Command::UPDATE_CURSOR)
},
pos: CursorPos {
scanout_id,
@@ -336,39 +337,41 @@ bitflags! {
}
}
-#[repr(u32)]
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-enum Command {
- GetDisplayInfo = 0x100,
- ResourceCreate2d = 0x101,
- ResourceUnref = 0x102,
- SetScanout = 0x103,
- ResourceFlush = 0x104,
- TransferToHost2d = 0x105,
- ResourceAttachBacking = 0x106,
- ResourceDetachBacking = 0x107,
- GetCapsetInfo = 0x108,
- GetCapset = 0x109,
- GetEdid = 0x10a,
-
- UpdateCursor = 0x300,
- MoveCursor = 0x301,
-
- OkNodata = 0x1100,
- OkDisplayInfo = 0x1101,
- OkCapsetInfo = 0x1102,
- OkCapset = 0x1103,
- OkEdid = 0x1104,
-
- ErrUnspec = 0x1200,
- ErrOutOfMemory = 0x1201,
- ErrInvalidScanoutId = 0x1202,
+#[repr(transparent)]
+#[derive(AsBytes, Clone, Copy, Debug, Eq, PartialEq, FromBytes)]
+struct Command(u32);
+
+impl Command {
+ const GET_DISPLAY_INFO: Command = Command(0x100);
+ const RESOURCE_CREATE_2D: Command = Command(0x101);
+ const RESOURCE_UNREF: Command = Command(0x102);
+ const SET_SCANOUT: Command = Command(0x103);
+ const RESOURCE_FLUSH: Command = Command(0x104);
+ const TRANSFER_TO_HOST_2D: Command = Command(0x105);
+ const RESOURCE_ATTACH_BACKING: Command = Command(0x106);
+ const RESOURCE_DETACH_BACKING: Command = Command(0x107);
+ const GET_CAPSET_INFO: Command = Command(0x108);
+ const GET_CAPSET: Command = Command(0x109);
+ const GET_EDID: Command = Command(0x10a);
+
+ const UPDATE_CURSOR: Command = Command(0x300);
+ const MOVE_CURSOR: Command = Command(0x301);
+
+ const OK_NODATA: Command = Command(0x1100);
+ const OK_DISPLAY_INFO: Command = Command(0x1101);
+ const OK_CAPSET_INFO: Command = Command(0x1102);
+ const OK_CAPSET: Command = Command(0x1103);
+ const OK_EDID: Command = Command(0x1104);
+
+ const ERR_UNSPEC: Command = Command(0x1200);
+ const ERR_OUT_OF_MEMORY: Command = Command(0x1201);
+ const ERR_INVALID_SCANOUT_ID: Command = Command(0x1202);
}
const GPU_FLAG_FENCE: u32 = 1 << 0;
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy, FromBytes)]
struct CtrlHeader {
hdr_type: Command,
flags: u32,
@@ -399,7 +402,7 @@ impl CtrlHeader {
}
#[repr(C)]
-#[derive(Debug, Copy, Clone, Default)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, FromBytes)]
struct Rect {
x: u32,
y: u32,
@@ -408,7 +411,7 @@ struct Rect {
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(Debug, FromBytes)]
struct RespDisplayInfo {
header: CtrlHeader,
rect: Rect,
@@ -417,7 +420,7 @@ struct RespDisplayInfo {
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
struct ResourceCreate2D {
header: CtrlHeader,
resource_id: u32,
@@ -427,13 +430,13 @@ struct ResourceCreate2D {
}
#[repr(u32)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
enum Format {
B8G8R8A8UNORM = 1,
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
struct ResourceAttachBacking {
header: CtrlHeader,
resource_id: u32,
@@ -444,7 +447,7 @@ struct ResourceAttachBacking {
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
struct SetScanout {
header: CtrlHeader,
rect: Rect,
@@ -453,7 +456,7 @@ struct SetScanout {
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
struct TransferToHost2D {
header: CtrlHeader,
rect: Rect,
@@ -463,7 +466,7 @@ struct TransferToHost2D {
}
#[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
struct ResourceFlush {
header: CtrlHeader,
rect: Rect,
@@ -472,7 +475,7 @@ struct ResourceFlush {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy)]
struct CursorPos {
scanout_id: u32,
x: u32,
@@ -481,7 +484,7 @@ struct CursorPos {
}
#[repr(C)]
-#[derive(Debug, Clone, Copy)]
+#[derive(AsBytes, Debug, Clone, Copy)]
struct UpdateCursor {
header: CtrlHeader,
pos: CursorPos,
diff --git a/src/device/input.rs b/src/device/input.rs
index 8554282..dee2fec 100644
--- a/src/device/input.rs
+++ b/src/device/input.rs
@@ -1,12 +1,12 @@
//! Driver for VirtIO input devices.
+use super::common::Feature;
use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
use crate::Result;
use alloc::boxed::Box;
-use bitflags::bitflags;
use core::ptr::NonNull;
use log::info;
use zerocopy::{AsBytes, FromBytes};
@@ -42,7 +42,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
// Safe because the buffer lasts as long as the queue.
- let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };
+ let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };
assert_eq!(token, i as u16);
}
if event_queue.should_notify() {
@@ -69,12 +69,18 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
if let Some(token) = self.event_queue.peek_used() {
let event = &mut self.event_buf[token as usize];
- self.event_queue
- .pop_used(token, &[], &[event.as_bytes_mut()])
- .ok()?;
+ // Safe because we are passing the same buffer as we passed to `VirtQueue::add` and it
+ // is still valid.
+ unsafe {
+ self.event_queue
+ .pop_used(token, &[], &mut [event.as_bytes_mut()])
+ .ok()?;
+ }
+ let event_saved = *event;
// requeue
// Safe because buffer lasts as long as the queue.
- if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {
+ if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_bytes_mut()]) }
+ {
// This only works because nothing happen between `pop_used` and `add` that affects
// the list of free descriptors in the queue, so `add` reuses the descriptor which
// was just freed by `pop_used`.
@@ -82,7 +88,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
if self.event_queue.should_notify() {
self.transport.notify(QUEUE_EVENT);
}
- return Some(*event);
+ return Some(event_saved);
}
}
None
@@ -185,26 +191,6 @@ pub struct InputEvent {
pub value: u32,
}
-bitflags! {
- struct Feature: u64 {
- // device independent
- const NOTIFY_ON_EMPTY = 1 << 24; // legacy
- const ANY_LAYOUT = 1 << 27; // legacy
- const RING_INDIRECT_DESC = 1 << 28;
- const RING_EVENT_IDX = 1 << 29;
- const UNUSED = 1 << 30; // legacy
- const VERSION_1 = 1 << 32; // detect legacy
-
- // since virtio v1.1
- const ACCESS_PLATFORM = 1 << 33;
- const RING_PACKED = 1 << 34;
- const IN_ORDER = 1 << 35;
- const ORDER_PLATFORM = 1 << 36;
- const SR_IOV = 1 << 37;
- const NOTIFICATION_DATA = 1 << 38;
- }
-}
-
const QUEUE_EVENT: u16 = 0;
const QUEUE_STATUS: u16 = 1;
diff --git a/src/device/mod.rs b/src/device/mod.rs
index f3e4f66..ca68901 100644
--- a/src/device/mod.rs
+++ b/src/device/mod.rs
@@ -5,4 +5,8 @@ pub mod console;
pub mod gpu;
#[cfg(feature = "alloc")]
pub mod input;
+#[cfg(feature = "alloc")]
pub mod net;
+pub mod socket;
+
+pub(crate) mod common;
diff --git a/src/device/net.rs b/src/device/net.rs
index 7ca487e..4441f63 100644
--- a/src/device/net.rs
+++ b/src/device/net.rs
@@ -4,13 +4,95 @@ use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::transport::Transport;
use crate::volatile::{volread, ReadOnly};
-use crate::Result;
+use crate::{Error, Result};
+use alloc::{vec, vec::Vec};
use bitflags::bitflags;
-use core::mem::{size_of, MaybeUninit};
-use log::{debug, info};
+use core::{convert::TryInto, mem::size_of};
+use log::{debug, info, warn};
use zerocopy::{AsBytes, FromBytes};
-const QUEUE_SIZE: u16 = 2;
+const MAX_BUFFER_LEN: usize = 65535;
+const MIN_BUFFER_LEN: usize = 1526;
+const NET_HDR_SIZE: usize = size_of::<VirtioNetHdr>();
+
+/// A buffer used for transmitting.
+pub struct TxBuffer(Vec<u8>);
+
+/// A buffer used for receiving.
+pub struct RxBuffer {
+ buf: Vec<usize>, // for alignment
+ packet_len: usize,
+ idx: u16,
+}
+
+impl TxBuffer {
+ /// Constructs the buffer from the given slice.
+ pub fn from(buf: &[u8]) -> Self {
+ Self(Vec::from(buf))
+ }
+
+ /// Returns the network packet length.
+ pub fn packet_len(&self) -> usize {
+ self.0.len()
+ }
+
+ /// Returns the network packet as a slice.
+ pub fn packet(&self) -> &[u8] {
+ self.0.as_slice()
+ }
+
+ /// Returns the network packet as a mutable slice.
+ pub fn packet_mut(&mut self) -> &mut [u8] {
+ self.0.as_mut_slice()
+ }
+}
+
+impl RxBuffer {
+ /// Allocates a new buffer with length `buf_len`.
+ fn new(idx: usize, buf_len: usize) -> Self {
+ Self {
+ buf: vec![0; buf_len / size_of::<usize>()],
+ packet_len: 0,
+ idx: idx.try_into().unwrap(),
+ }
+ }
+
+ /// Set the network packet length.
+ fn set_packet_len(&mut self, packet_len: usize) {
+ self.packet_len = packet_len
+ }
+
+ /// Returns the network packet length (without header).
+ pub const fn packet_len(&self) -> usize {
+ self.packet_len
+ }
+
+ /// Returns all data in the buffer, including both the header and the packet.
+ pub fn as_bytes(&self) -> &[u8] {
+ self.buf.as_bytes()
+ }
+
+ /// Returns all data in the buffer as a mutable slice,
+ /// including both the header and the packet.
+ pub fn as_bytes_mut(&mut self) -> &mut [u8] {
+ self.buf.as_bytes_mut()
+ }
+
+ /// Returns a reference to the header.
+ pub fn header(&self) -> &VirtioNetHdr {
+ unsafe { &*(self.buf.as_ptr() as *const VirtioNetHdr) }
+ }
+
+ /// Returns the network packet as a slice.
+ pub fn packet(&self) -> &[u8] {
+ &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+ }
+
+ /// Returns the network packet as a mutable slice.
+ pub fn packet_mut(&mut self) -> &mut [u8] {
+ &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+ }
+}
/// The virtio network device is a virtual ethernet card.
///
@@ -19,16 +101,17 @@ const QUEUE_SIZE: u16 = 2;
/// Empty buffers are placed in one virtqueue for receiving packets, and
/// outgoing packets are enqueued into another for transmission in that order.
/// A third command queue is used to control advanced filtering features.
-pub struct VirtIONet<H: Hal, T: Transport> {
+pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
transport: T,
mac: EthernetAddress,
- recv_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
- send_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
+ recv_queue: VirtQueue<H, QUEUE_SIZE>,
+ send_queue: VirtQueue<H, QUEUE_SIZE>,
+ rx_buffers: [Option<RxBuffer>; QUEUE_SIZE],
}
-impl<H: Hal, T: Transport> VirtIONet<H, T> {
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
/// Create a new VirtIO-Net driver.
- pub fn new(mut transport: T) -> Result<Self> {
+ pub fn new(mut transport: T, buf_len: usize) -> Result<Self> {
transport.begin_init(|features| {
let features = Features::from_bits_truncate(features);
info!("Device features {:?}", features);
@@ -41,11 +124,37 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
// Safe because config points to a valid MMIO region for the config space.
unsafe {
mac = volread!(config, mac);
- debug!("Got MAC={:?}, status={:?}", mac, volread!(config, status));
+ debug!(
+ "Got MAC={:02x?}, status={:?}",
+ mac,
+ volread!(config, status)
+ );
+ }
+
+ if !(MIN_BUFFER_LEN..=MAX_BUFFER_LEN).contains(&buf_len) {
+ warn!(
+ "Receive buffer len {} is not in range [{}, {}]",
+ buf_len, MIN_BUFFER_LEN, MAX_BUFFER_LEN
+ );
+ return Err(Error::InvalidParam);
}
- let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
+ let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
+
+ const NONE_BUF: Option<RxBuffer> = None;
+ let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
+ for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() {
+ let mut rx_buf = RxBuffer::new(i, buf_len);
+ // Safe because the buffer lives as long as the queue.
+ let token = unsafe { recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()])? };
+ assert_eq!(token, i as u16);
+ *rx_buf_place = Some(rx_buf);
+ }
+
+ if recv_queue.should_notify() {
+ transport.notify(QUEUE_RECEIVE);
+ }
transport.finish_init();
@@ -54,6 +163,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
mac,
recv_queue,
send_queue,
+ rx_buffers,
})
}
@@ -63,7 +173,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
}
/// Get MAC address.
- pub fn mac(&self) -> EthernetAddress {
+ pub fn mac_address(&self) -> EthernetAddress {
self.mac
}
@@ -77,27 +187,72 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
self.recv_queue.can_pop()
}
- /// Receive a packet.
- pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
- let mut header = MaybeUninit::<Header>::uninit();
- let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() };
- let len =
- self.recv_queue
- .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
- // let header = unsafe { header.assume_init() };
- Ok(len as usize - size_of::<Header>())
+ /// Receives an [`RxBuffer`] from the network. If no data is currently available,
+ /// returns an error of type [`Error::NotReady`].
+ ///
+ /// It will try to pop a buffer that completed data reception in the
+ /// NIC queue.
+ pub fn receive(&mut self) -> Result<RxBuffer> {
+ if let Some(token) = self.recv_queue.peek_used() {
+ let mut rx_buf = self.rx_buffers[token as usize]
+ .take()
+ .ok_or(Error::WrongToken)?;
+ if token != rx_buf.idx {
+ return Err(Error::WrongToken);
+ }
+
+ // Safe because `token` == `rx_buf.idx`, we are passing the same
+ // buffer as we passed to `VirtQueue::add` and it is still valid.
+ let len = unsafe {
+ self.recv_queue
+ .pop_used(token, &[], &mut [rx_buf.as_bytes_mut()])?
+ } as usize;
+ rx_buf.set_packet_len(len.checked_sub(NET_HDR_SIZE).ok_or(Error::IoError)?);
+ Ok(rx_buf)
+ } else {
+ Err(Error::NotReady)
+ }
}
- /// Send a packet.
- pub fn send(&mut self, buf: &[u8]) -> Result {
- let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
- self.send_queue
- .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?;
+ /// Gives back the ownership of `rx_buf`, and recycles it for next use.
+ ///
+ /// It will add the buffer back to the NIC queue.
+ pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result {
+ // Safe because we take the ownership of `rx_buf` back to `rx_buffers`,
+ // it lives as long as the queue.
+ let new_token = unsafe { self.recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()]) }?;
+ // `rx_buffers[new_token]` is expected to be `None` since it was taken
+ // away at `Self::receive()` and has not been added back.
+ if self.rx_buffers[new_token as usize].is_some() {
+ return Err(Error::WrongToken);
+ }
+ rx_buf.idx = new_token;
+ self.rx_buffers[new_token as usize] = Some(rx_buf);
+ if self.recv_queue.should_notify() {
+ self.transport.notify(QUEUE_RECEIVE);
+ }
+ Ok(())
+ }
+
+ /// Allocate a new buffer for transmitting.
+ pub fn new_tx_buffer(&self, buf_len: usize) -> TxBuffer {
+ TxBuffer(vec![0; buf_len])
+ }
+
+ /// Sends a [`TxBuffer`] to the network, and blocks until the request
+ /// has completed.
+ pub fn send(&mut self, tx_buf: TxBuffer) -> Result {
+ let header = VirtioNetHdr::default();
+ self.send_queue.add_notify_wait_pop(
+ &[header.as_bytes(), tx_buf.packet()],
+ &mut [],
+ &mut self.transport,
+ )?;
Ok(())
}
}
-impl<H: Hal, T: Transport> Drop for VirtIONet<H, T> {
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUEUE_SIZE> {
fn drop(&mut self) {
// Clear any pointers pointing to DMA regions, so the device doesn't try to access them
// after they have been freed.
@@ -185,26 +340,33 @@ bitflags! {
struct Config {
mac: ReadOnly<EthernetAddress>,
status: ReadOnly<Status>,
+ max_virtqueue_pairs: ReadOnly<u16>,
+ mtu: ReadOnly<u16>,
}
type EthernetAddress = [u8; 6];
-// virtio 5.1.6 Device Operation
+/// VirtIO 5.1.6 Device Operation:
+///
+/// Packets are transmitted by placing them in the transmitq1. . .transmitqN,
+/// and buffers for incoming packets are placed in the receiveq1. . .receiveqN.
+/// In each case, the packet itself is preceded by a header.
#[repr(C)]
-#[derive(AsBytes, Debug, FromBytes)]
-struct Header {
+#[derive(AsBytes, Debug, Default, FromBytes)]
+pub struct VirtioNetHdr {
flags: Flags,
gso_type: GsoType,
hdr_len: u16, // cannot rely on this
gso_size: u16,
csum_start: u16,
csum_offset: u16,
+ // num_buffers: u16, // only available when the feature MRG_RXBUF is negotiated.
// payload starts from here
}
bitflags! {
#[repr(transparent)]
- #[derive(AsBytes, FromBytes)]
+ #[derive(AsBytes, Default, FromBytes)]
struct Flags: u8 {
const NEEDS_CSUM = 1;
const DATA_VALID = 2;
@@ -213,7 +375,7 @@ bitflags! {
}
#[repr(transparent)]
-#[derive(AsBytes, Debug, Copy, Clone, Eq, FromBytes, PartialEq)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, PartialEq)]
struct GsoType(u8);
impl GsoType {
diff --git a/src/device/socket/error.rs b/src/device/socket/error.rs
new file mode 100644
index 0000000..4beec38
--- /dev/null
+++ b/src/device/socket/error.rs
@@ -0,0 +1,69 @@
+//! This module contains the error type of the VirtIO socket driver.
+
+use core::{fmt, result};
+
+/// The error type of VirtIO socket driver.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum SocketError {
+ /// There is an existing connection.
+ ConnectionExists,
+ /// Failed to establish the connection.
+ ConnectionFailed,
+ /// The device is not connected to any peer.
+ NotConnected,
+ /// Peer socket is shutdown.
+ PeerSocketShutdown,
+ /// No response received.
+ NoResponseReceived,
+ /// The given buffer is shorter than expected.
+ BufferTooShort,
+ /// The given buffer for output is shorter than expected.
+ OutputBufferTooShort(usize),
+ /// The given buffer has exceeded the maximum buffer size.
+ BufferTooLong(usize, usize),
+ /// Unknown operation.
+ UnknownOperation(u16),
+ /// Invalid operation.
+ InvalidOperation,
+ /// Invalid number.
+ InvalidNumber,
+ /// Unexpected data in packet.
+ UnexpectedDataInPacket,
+ /// Peer has insufficient buffer space, try again later.
+ InsufficientBufferSpaceInPeer,
+ /// Recycled a wrong buffer.
+ RecycledWrongBuffer,
+}
+
+impl fmt::Display for SocketError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ Self::ConnectionExists => write!(
+ f,
+ "There is an existing connection. Please close the current connection before attempting to connect again."),
+ Self::ConnectionFailed => write!(
+ f, "Failed to establish the connection. The packet sent may have an unknown type value"
+ ),
+ Self::NotConnected => write!(f, "The device is not connected to any peer. Please connect it to a peer first."),
+ Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."),
+ Self::NoResponseReceived => write!(f, "No response received"),
+ Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"),
+ Self::BufferTooLong(actual, max) => {
+ write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'")
+ }
+ Self::OutputBufferTooShort(expected) => {
+ write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.")
+ }
+ Self::UnknownOperation(op) => {
+ write!(f, "The operation code '{op}' is unknown")
+ }
+ Self::InvalidOperation => write!(f, "Invalid operation"),
+ Self::InvalidNumber => write!(f, "Invalid number"),
+ Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"),
+ Self::InsufficientBufferSpaceInPeer => write!(f, "Peer has insufficient buffer space, try again later"),
+ Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"),
+ }
+ }
+}
+
+pub type Result<T> = result::Result<T, SocketError>;
diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs
new file mode 100644
index 0000000..65280aa
--- /dev/null
+++ b/src/device/socket/mod.rs
@@ -0,0 +1,8 @@
+//! This module implements the virtio vsock device.
+
+mod error;
+mod protocol;
+mod vsock;
+
+pub use error::SocketError;
+pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};
diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs
new file mode 100644
index 0000000..abc1702
--- /dev/null
+++ b/src/device/socket/protocol.rs
@@ -0,0 +1,184 @@
+//! This module defines the socket device protocol according to the virtio spec v1.1 5.10 Socket Device
+
+use super::error::{self, SocketError};
+use crate::volatile::ReadOnly;
+use core::{
+ convert::{TryFrom, TryInto},
+ fmt,
+};
+use zerocopy::{
+ byteorder::{LittleEndian, U16, U32, U64},
+ AsBytes, FromBytes,
+};
+
+/// The type of a socket, as written into the `socket_type` field of a packet header.
+///
+/// Currently only stream sockets are supported; the value for stream sockets is 1.
+#[derive(Copy, Clone, Debug)]
+#[repr(u16)]
+pub enum SocketType {
+ /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries.
+ Stream = 1,
+}
+
+impl From<SocketType> for U16<LittleEndian> {
+    /// Converts the socket type into its little-endian 16-bit representation.
+    fn from(socket_type: SocketType) -> Self {
+        let raw = socket_type as u16;
+        Self::from(raw)
+    }
+}
+
+/// VirtioVsockConfig is the vsock device configuration space.
+#[repr(C)]
+pub struct VirtioVsockConfig {
+ /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+ /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+ ///
+ /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
+ /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
+ /// So we need to split the u64 guest_cid into two parts.
+ ///
+ /// This field holds the low 32 bits of the guest CID.
+ pub guest_cid_low: ReadOnly<u32>,
+ /// The high 32 bits of the guest CID; see `guest_cid_low` for why the value is split.
+ pub guest_cid_high: ReadOnly<u32>,
+}
+
+/// The message header for data packets sent on the tx/rx queues
+///
+/// All fields are stored little-endian, and the struct is packed so that it can be read from
+/// and written to a queue buffer as raw bytes.
+#[repr(packed)]
+#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+pub struct VirtioVsockHdr {
+ /// Context ID of the sender.
+ pub src_cid: U64<LittleEndian>,
+ /// Context ID of the receiver.
+ pub dst_cid: U64<LittleEndian>,
+ /// Port number on the sender's side.
+ pub src_port: U32<LittleEndian>,
+ /// Port number on the receiver's side.
+ pub dst_port: U32<LittleEndian>,
+ /// Length in bytes of the data which follows the header.
+ pub len: U32<LittleEndian>,
+ /// The socket type; see `SocketType`.
+ pub socket_type: U16<LittleEndian>,
+ /// The operation code; see `VirtioVsockOp`.
+ pub op: U16<LittleEndian>,
+ // NOTE(review): the driver only ever writes 0 here; the meaning of individual flag bits is
+ // defined by the virtio spec, not by this file.
+ pub flags: U32<LittleEndian>,
+ /// Total receive buffer space for this socket. This includes both free and in-use buffers.
+ pub buf_alloc: U32<LittleEndian>,
+ /// Free-running bytes received counter.
+ pub fwd_cnt: U32<LittleEndian>,
+}
+
+impl Default for VirtioVsockHdr {
+    /// Returns a header whose numeric fields are all zero and whose socket type is
+    /// `SocketType::Stream`.
+    fn default() -> Self {
+        // The little-endian wrappers are `Copy`, so one zero of each width serves every field.
+        let zero64: U64<LittleEndian> = 0.into();
+        let zero32: U32<LittleEndian> = 0.into();
+        let zero16: U16<LittleEndian> = 0.into();
+        Self {
+            src_cid: zero64,
+            dst_cid: zero64,
+            src_port: zero32,
+            dst_port: zero32,
+            len: zero32,
+            socket_type: SocketType::Stream.into(),
+            op: zero16,
+            flags: zero32,
+            buf_alloc: zero32,
+            fwd_cnt: zero32,
+        }
+    }
+}
+
+impl VirtioVsockHdr {
+    /// Returns the length of the data.
+    pub fn len(&self) -> u32 {
+        self.len.get()
+    }
+
+    /// Decodes the `op` field, failing if it holds a value which is not a known operation.
+    pub fn op(&self) -> error::Result<VirtioVsockOp> {
+        VirtioVsockOp::try_from(self.op)
+    }
+
+    /// Returns the address (CID and port) the packet was sent from.
+    pub fn source(&self) -> VsockAddr {
+        let cid = self.src_cid.get();
+        let port = self.src_port.get();
+        VsockAddr { cid, port }
+    }
+
+    /// Returns the address (CID and port) the packet is addressed to.
+    pub fn destination(&self) -> VsockAddr {
+        let cid = self.dst_cid.get();
+        let port = self.dst_port.get();
+        VsockAddr { cid, port }
+    }
+
+    /// Succeeds only if the header declares a data length of zero; otherwise returns
+    /// `SocketError::UnexpectedDataInPacket`.
+    pub fn check_data_is_empty(&self) -> error::Result<()> {
+        match self.len() {
+            0 => Ok(()),
+            _ => Err(SocketError::UnexpectedDataInPacket),
+        }
+    }
+}
+
+/// Socket address: the context ID and port which together identify one end of a connection.
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
+pub struct VsockAddr {
+ /// Context Identifier.
+ pub cid: u64,
+ /// Port number.
+ pub port: u32,
+}
+
+/// An event sent to the event queue
+#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
+#[repr(C)]
+pub struct VirtioVsockEvent {
+ /// ID from the virtio_vsock_event_id struct in the virtio spec
+ pub id: U32<LittleEndian>,
+}
+
+/// The operation requested by a packet, as stored in the `op` field of the packet header.
+///
+/// The discriminants are the values which appear on the wire.
+#[derive(Copy, Clone, Eq, PartialEq)]
+#[repr(u16)]
+pub enum VirtioVsockOp {
+ /// Not a valid operation; a received packet carrying it is reported as an error.
+ Invalid = 0,
+
+ /* Connect operations */
+ /// Request to open a connection.
+ Request = 1,
+ /// Reply accepting a connection request.
+ Response = 2,
+ /// Reset: the connection is rejected or forcibly closed.
+ Rst = 3,
+ /// Request to shut the connection down.
+ Shutdown = 4,
+
+ /* To send payload */
+ /// Packet carrying data.
+ Rw = 5,
+
+ /* Tell the peer our credit info */
+ CreditUpdate = 6,
+ /* Request the peer to send the credit info to us */
+ CreditRequest = 7,
+}
+
+impl From<VirtioVsockOp> for U16<LittleEndian> {
+    /// Converts the operation into its little-endian 16-bit representation.
+    fn from(op: VirtioVsockOp) -> Self {
+        let raw = op as u16;
+        Self::from(raw)
+    }
+}
+
+impl TryFrom<U16<LittleEndian>> for VirtioVsockOp {
+    type Error = SocketError;
+
+    /// Decodes a raw operation code, returning `SocketError::UnknownOperation` with the
+    /// offending value if it does not correspond to any variant.
+    fn try_from(v: U16<LittleEndian>) -> Result<Self, Self::Error> {
+        match u16::from(v) {
+            0 => Ok(Self::Invalid),
+            1 => Ok(Self::Request),
+            2 => Ok(Self::Response),
+            3 => Ok(Self::Rst),
+            4 => Ok(Self::Shutdown),
+            5 => Ok(Self::Rw),
+            6 => Ok(Self::CreditUpdate),
+            7 => Ok(Self::CreditRequest),
+            _ => Err(SocketError::UnknownOperation(v.into())),
+        }
+    }
+}
+
+impl fmt::Debug for VirtioVsockOp {
+    /// Formats the operation as the name of the matching `VIRTIO_VSOCK_OP_*` constant.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let name = match self {
+            Self::Invalid => "VIRTIO_VSOCK_OP_INVALID",
+            Self::Request => "VIRTIO_VSOCK_OP_REQUEST",
+            Self::Response => "VIRTIO_VSOCK_OP_RESPONSE",
+            Self::Rst => "VIRTIO_VSOCK_OP_RST",
+            Self::Shutdown => "VIRTIO_VSOCK_OP_SHUTDOWN",
+            Self::Rw => "VIRTIO_VSOCK_OP_RW",
+            Self::CreditUpdate => "VIRTIO_VSOCK_OP_CREDIT_UPDATE",
+            Self::CreditRequest => "VIRTIO_VSOCK_OP_CREDIT_REQUEST",
+        };
+        f.write_str(name)
+    }
+}
diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs
new file mode 100644
index 0000000..686d7a6
--- /dev/null
+++ b/src/device/socket/vsock.rs
@@ -0,0 +1,596 @@
+//! Driver for VirtIO socket devices.
+#![deny(unsafe_op_in_unsafe_fn)]
+
+use super::error::SocketError;
+use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
+use crate::device::common::Feature;
+use crate::hal::{BufferDirection, Dma, Hal};
+use crate::queue::VirtQueue;
+use crate::transport::Transport;
+use crate::volatile::volread;
+use crate::Result;
+use core::hint::spin_loop;
+use core::mem::size_of;
+use core::ptr::NonNull;
+use log::{debug, info};
+use zerocopy::{AsBytes, FromBytes};
+
+/// Index of the virtqueue on which packets are received.
+const RX_QUEUE_IDX: u16 = 0;
+/// Index of the virtqueue on which packets are sent.
+const TX_QUEUE_IDX: u16 = 1;
+/// Index of the virtqueue on which events from the device are received.
+const EVENT_QUEUE_IDX: u16 = 2;
+
+/// Size parameter of each virtqueue, and the number of equal slices the RX buffer is split into.
+const QUEUE_SIZE: usize = 8;
+
+/// The driver's bookkeeping for its single current connection.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+struct ConnectionInfo {
+ /// The address (CID and port) of the peer.
+ dst: VsockAddr,
+ /// The port number used on our side of the connection.
+ src_port: u32,
+ /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
+ /// bytes it has allocated for packet bodies.
+ peer_buf_alloc: u32,
+ /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
+ /// has finished processing.
+ peer_fwd_cnt: u32,
+ /// The number of bytes of packet bodies which we have sent to the peer.
+ tx_cnt: u32,
+ /// The number of bytes of packet bodies which we have received from the peer and handled.
+ fwd_cnt: u32,
+ /// Whether we have recently requested credit from the peer.
+ ///
+ /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
+ /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
+ has_pending_credit_request: bool,
+}
+
+impl ConnectionInfo {
+    /// Returns how many more bytes of packet bodies the peer can currently accept, based on the
+    /// last credit information it sent us.
+    ///
+    /// `tx_cnt` and `peer_fwd_cnt` are free-running counters which are expected to wrap around,
+    /// so the number of bytes still in flight is computed with wrapping arithmetic. If the peer
+    /// has advertised less buffer space than is already in flight, the result is 0 rather than
+    /// an arithmetic overflow.
+    fn peer_free(&self) -> u32 {
+        let in_flight = self.tx_cnt.wrapping_sub(self.peer_fwd_cnt);
+        self.peer_buf_alloc.saturating_sub(in_flight)
+    }
+
+    /// Builds a packet header addressed from `src_cid` and our port to the peer of this
+    /// connection, carrying our current `fwd_cnt`. All other fields take their default values,
+    /// so callers set `op` (and `len`/`buf_alloc` where relevant) themselves.
+    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
+        VirtioVsockHdr {
+            src_cid: src_cid.into(),
+            dst_cid: self.dst.cid.into(),
+            src_port: self.src_port.into(),
+            dst_port: self.dst.port.into(),
+            fwd_cnt: self.fwd_cnt.into(),
+            ..Default::default()
+        }
+    }
+}
+
+/// An event received from a VirtIO socket device.
+///
+/// Events are produced by `poll_recv` and `wait_for_recv` when a packet for the current
+/// connection arrives.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockEvent {
+ /// The source of the event, i.e. the peer who sent it.
+ pub source: VsockAddr,
+ /// The destination of the event, i.e. the CID and port on our side.
+ pub destination: VsockAddr,
+ /// The type of event.
+ pub event_type: VsockEventType,
+}
+
+/// The reason why a vsock connection was closed.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum DisconnectReason {
+ /// The peer has either closed the connection in response to our shutdown request, or forcibly
+ /// closed it of its own accord.
+ ///
+ /// Reported when a `VIRTIO_VSOCK_OP_RST` packet is received.
+ Reset,
+ /// The peer asked to shut down the connection.
+ ///
+ /// Reported when a `VIRTIO_VSOCK_OP_SHUTDOWN` packet is received.
+ Shutdown,
+}
+
+/// Details of the type of an event received from a VirtIO socket.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum VsockEventType {
+ /// The connection was successfully established.
+ ///
+ /// Reported when the peer answers with a `VIRTIO_VSOCK_OP_RESPONSE` packet.
+ Connected,
+ /// The connection was closed.
+ Disconnected {
+ /// The reason for the disconnection.
+ reason: DisconnectReason,
+ },
+ /// Data was received on the connection.
+ ///
+ /// The data has been copied to the start of the buffer which was passed to the polling call.
+ Received {
+ /// The length of the data in bytes.
+ length: usize,
+ },
+}
+
+/// Driver for a VirtIO socket device.
+pub struct VirtIOSocket<H: Hal, T: Transport> {
+ /// The transport used to talk to the device.
+ transport: T,
+ /// Virtqueue to receive packets.
+ rx: VirtQueue<H, { QUEUE_SIZE }>,
+ /// Virtqueue to send packets.
+ tx: VirtQueue<H, { QUEUE_SIZE }>,
+ /// Virtqueue to receive events from the device.
+ event: VirtQueue<H, { QUEUE_SIZE }>,
+ /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+ /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+ guest_cid: u64,
+ /// DMA memory backing the buffers which are kept in the `rx` queue. It is split into
+ /// `QUEUE_SIZE` equal slices, one per queue entry.
+ rx_buf_dma: Dma<H>,
+
+ /// Currently the device is only allowed to be connected to one destination at a time.
+ connection_info: Option<ConnectionInfo>,
+}
+
+impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
+    fn drop(&mut self) {
+        // Unset every queue so that the device no longer holds pointers into DMA regions which
+        // are about to be freed.
+        for queue_idx in [RX_QUEUE_IDX, TX_QUEUE_IDX, EVENT_QUEUE_IDX] {
+            self.transport.queue_unset(queue_idx);
+        }
+    }
+}
+
+impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
+ /// Create a new VirtIO Vsock driver.
+ ///
+ /// Negotiates features, reads the guest CID from the device configuration space, sets up the
+ /// RX, TX and event virtqueues and fills the RX queue with buffers before finishing
+ /// initialisation of the device.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if the configuration space cannot be accessed, or if creating a
+ /// virtqueue, allocating the RX buffer or filling the RX queue fails.
+ pub fn new(mut transport: T) -> Result<Self> {
+ transport.begin_init(|features| {
+ let features = Feature::from_bits_truncate(features);
+ info!("Device features: {:?}", features);
+ // negotiate these flags only
+ // (the supported set is empty, so no optional features are accepted)
+ let supported_features = Feature::empty();
+ (features & supported_features).bits()
+ });
+
+ let config = transport.config_space::<VirtioVsockConfig>()?;
+ info!("config: {:?}", config);
+ // Safe because config is a valid pointer to the device configuration space.
+ // The 64-bit CID is assembled from two separate 32-bit reads.
+ let guest_cid = unsafe {
+ volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
+ };
+ info!("guest cid: {guest_cid:?}");
+
+ let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
+ let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
+ let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
+
+ // Allocates 4 KiB memory as the rx buffer.
+ let rx_buf_dma = Dma::new(
+ 1, // pages
+ BufferDirection::DeviceToDriver,
+ )?;
+ let rx_buf = rx_buf_dma.raw_slice();
+ // Safe because `rx_buf` lives as long as the `rx` queue.
+ // (both `rx` and `rx_buf_dma` are moved into the returned driver below)
+ unsafe {
+ Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
+ }
+ transport.finish_init();
+
+ Ok(Self {
+ transport,
+ rx,
+ tx,
+ event,
+ guest_cid,
+ rx_buf_dma,
+ connection_info: None,
+ })
+ }
+
+ /// Fills the `rx` queue with the buffer `rx_buf`.
+ ///
+ /// The buffer is split into `QUEUE_SIZE` equal slices and slice `i` is added to the queue as
+ /// a device-writable buffer, which must receive token `i`.
+ ///
+ /// # Errors
+ ///
+ /// Returns `SocketError::BufferTooShort` if `rx_buf` is too small to hold one packet header
+ /// per queue entry, or any error from adding a buffer to the queue.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the queue hands out a token other than the index of the slice being added.
+ ///
+ /// # Safety
+ ///
+ /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
+ /// in the queue must not be used anywhere else at the same time.
+ unsafe fn fill_rx_queue(
+ rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
+ rx_buf: NonNull<[u8]>,
+ transport: &mut T,
+ ) -> Result {
+ if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
+ return Err(SocketError::BufferTooShort.into());
+ }
+ for i in 0..QUEUE_SIZE {
+ // Safe because the buffer lives as long as the queue, as specified in the function
+ // safety requirement, and we don't access it until it is popped.
+ unsafe {
+ let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
+ let token = rx.add(&[], &mut [buffer])?;
+ // Received packets are mapped back to their slice by token, so the token must be
+ // the slice index.
+ assert_eq!(i, token.into());
+ }
+ }
+
+ if rx.should_notify() {
+ transport.notify(RX_QUEUE_IDX);
+ }
+ Ok(())
+ }
+
+ /// Asks the peer at the given destination to open a connection.
+ ///
+ /// The call returns once the request packet has been sent. The connection is only usable
+ /// after `poll_recv` has returned a `VsockEventType::Connected` event, which indicates that
+ /// the peer accepted it; do not send data before then.
+ ///
+ /// Fails with `SocketError::ConnectionExists` if a connection is already being tracked.
+ pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+     if self.connection_info.is_some() {
+         return Err(SocketError::ConnectionExists.into());
+     }
+
+     let dst = VsockAddr {
+         cid: dst_cid,
+         port: dst_port,
+     };
+     let pending = ConnectionInfo {
+         dst,
+         src_port,
+         ..Default::default()
+     };
+
+     // A connection request is a header-only packet sent to the listening socket at the
+     // destination.
+     let mut request = pending.new_header(self.guest_cid);
+     request.op = VirtioVsockOp::Request.into();
+     self.send_packet_to_tx_queue(&request, &[])?;
+
+     // Only start tracking the connection once the request has actually gone out.
+     self.connection_info = Some(pending);
+     debug!("Connection requested: {:?}", self.connection_info);
+     Ok(())
+ }
+
+ /// Blocks until the next event arrives and interprets it as the answer to our connection
+ /// request: a `VIRTIO_VSOCK_OP_RESPONSE` from the peer means it was accepted, a
+ /// `VIRTIO_VSOCK_OP_RST` means it was rejected.
+ ///
+ /// Receiving data instead is reported as `SocketError::InvalidOperation`.
+ pub fn wait_for_connect(&mut self) -> Result {
+     // No data is expected while connecting, so an empty receive buffer is enough.
+     let event = self.wait_for_recv(&mut [])?;
+     match event.event_type {
+         VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()),
+         VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
+         VsockEventType::Connected => Ok(()),
+     }
+ }
+
+ /// Sends a header-only `CreditRequest` packet, asking the peer of the current connection to
+ /// send us a credit update.
+ fn request_credit(&mut self) -> Result {
+     let mut header = self.connection_info()?.new_header(self.guest_cid);
+     header.op = VirtioVsockOp::CreditRequest.into();
+     self.send_packet_to_tx_queue(&header, &[])
+ }
+
+ /// Sends the buffer to the destination.
+ ///
+ /// # Errors
+ ///
+ /// Returns `SocketError::NotConnected` if there is no current connection, and
+ /// `SocketError::InsufficientBufferSpaceInPeer` if the peer has not advertised enough free
+ /// buffer space for the whole of `buffer`; in that case a credit request is sent to the peer
+ /// (unless one is already pending) and the caller should try again later.
+ pub fn send(&mut self, buffer: &[u8]) -> Result {
+     let mut connection_info = self.connection_info()?;
+
+     if let Err(e) = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len()) {
+         // Keep the updated `has_pending_credit_request` flag even though we are bailing out.
+         self.connection_info = Some(connection_info);
+         return Err(e);
+     }
+
+     // This cannot truncate: the check above only passes if `buffer.len()` is no bigger than
+     // the peer's free space, which is a `u32`.
+     let len = buffer.len() as u32;
+     let header = VirtioVsockHdr {
+         op: VirtioVsockOp::Rw.into(),
+         len: len.into(),
+         buf_alloc: 0.into(),
+         ..connection_info.new_header(self.guest_cid)
+     };
+     // `tx_cnt` is a free-running counter, so it must wrap around rather than overflow.
+     connection_info.tx_cnt = connection_info.tx_cnt.wrapping_add(len);
+     self.connection_info = Some(connection_info);
+     self.send_packet_to_tx_queue(&header, buffer)
+ }
+
+ /// Checks whether the peer has advertised enough free buffer space for `buffer_len` bytes.
+ ///
+ /// If it has not, a credit request is sent to the peer (unless `connection_info` records
+ /// that one is already pending, in which case the flag is left set) and
+ /// `SocketError::InsufficientBufferSpaceInPeer` is returned so that the caller can retry
+ /// later.
+ fn check_peer_buffer_is_sufficient(
+     &mut self,
+     connection_info: &mut ConnectionInfo,
+     buffer_len: usize,
+ ) -> Result {
+     let peer_has_room = connection_info.peer_free() as usize >= buffer_len;
+     if peer_has_room {
+         return Ok(());
+     }
+
+     // Ask for fresh credit information at most once per update received from the peer.
+     if !connection_info.has_pending_credit_request {
+         self.request_credit()?;
+         connection_info.has_pending_credit_request = true;
+     }
+     Err(SocketError::InsufficientBufferSpaceInPeer.into())
+ }
+
+ /// Polls the vsock device to receive data or other updates.
+ ///
+ /// A buffer must be provided to put the data in if there is some to
+ /// receive.
+ ///
+ /// Before checking for packets, the size of `buffer` is advertised to the peer in a credit
+ /// update. Returns `Ok(None)` if nothing which generates an event is pending.
+ ///
+ /// # Errors
+ ///
+ /// Returns `SocketError::NotConnected` if there is no current connection, or any error from
+ /// sending the credit update or processing received packets.
+ pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
+     let connection_info = self.connection_info()?;
+
+     // `buf_alloc` is a 32-bit field. Saturate rather than truncate, so that a buffer of
+     // 4 GiB or more is never advertised as a tiny (or empty) one.
+     let buf_alloc = buffer.len().min(u32::MAX as usize) as u32;
+
+     // Tell the peer that we have space to receive some data.
+     let header = VirtioVsockHdr {
+         op: VirtioVsockOp::CreditUpdate.into(),
+         buf_alloc: buf_alloc.into(),
+         ..connection_info.new_header(self.guest_cid)
+     };
+     self.send_packet_to_tx_queue(&header, &[])?;
+
+     // Handle entries from the RX virtqueue until we find one that generates an event.
+     let event = self.poll_rx_queue(buffer)?;
+
+     // Buffers were added back to the RX queue while polling, so the device may need to be
+     // told about them.
+     if self.rx.should_notify() {
+         self.transport.notify(RX_QUEUE_IDX);
+     }
+
+     Ok(event)
+ }
+
+ /// Busy-waits, polling the vsock device, until it produces an event.
+ ///
+ /// `buffer` is where received data is placed, so it must be provided even if the caller
+ /// only expects other kinds of event.
+ pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+     loop {
+         match self.poll_recv(buffer)? {
+             Some(event) => return Ok(event),
+             None => spin_loop(),
+         }
+     }
+ }
+
+ /// Asks the peer to shut the connection down cleanly.
+ ///
+ /// The call returns once the request packet has been sent. To find out whether the peer has
+ /// acknowledged the shutdown, keep calling `poll_recv` until it returns a
+ /// `VsockEventType::Disconnected` event.
+ pub fn shutdown(&mut self) -> Result {
+     let mut header = self.connection_info()?.new_header(self.guest_cid);
+     header.op = VirtioVsockOp::Shutdown.into();
+     self.send_packet_to_tx_queue(&header, &[])
+ }
+
+ /// Closes the connection immediately by sending a reset to the peer, without waiting for
+ /// any acknowledgement. The connection is forgotten only if the reset was sent successfully.
+ pub fn force_close(&mut self) -> Result {
+     let mut header = self.connection_info()?.new_header(self.guest_cid);
+     header.op = VirtioVsockOp::Rst.into();
+     self.send_packet_to_tx_queue(&header, &[])?;
+
+     self.connection_info = None;
+     Ok(())
+ }
+
+ /// Puts `header` followed by `buffer` on the TX queue as a single packet, notifies the
+ /// device and waits until the packet has been popped from the queue again.
+ fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
+     let packet = [header.as_bytes(), buffer];
+     self.tx
+         .add_notify_wait_pop(&packet, &mut [], &mut self.transport)
+         // The length reported back by the queue is of no interest for a transmitted packet.
+         .map(|_len| ())
+ }
+
+ /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
+ /// which generates a `VsockEvent`.
+ ///
+ /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
+ /// don't result in any events to return.
+ ///
+ /// The body of every packet popped from the queue is copied into `body`. Packets which are
+ /// not addressed from the current peer to our CID and port are skipped. For all others the
+ /// peer's credit information is recorded before the operation is handled.
+ ///
+ /// # Errors
+ ///
+ /// Returns an error if a packet cannot be read, carries an unknown or invalid operation, or
+ /// carries data with an operation which must not have any.
+ fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
+     loop {
+         let mut connection_info = self.connection_info.clone().unwrap_or_default();
+         let Some(header) = self.pop_packet_from_rx_queue(body)? else {
+             return Ok(None);
+         };
+
+         let op = header.op()?;
+
+         // Skip packets which don't match our current connection.
+         if header.source() != connection_info.dst
+             || header.dst_cid.get() != self.guest_cid
+             || header.dst_port.get() != connection_info.src_port
+         {
+             debug!(
+                 "Skipping {:?} as connection is {:?}",
+                 header, connection_info
+             );
+             continue;
+         }
+
+         connection_info.peer_buf_alloc = header.buf_alloc.into();
+         connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+         if self.connection_info.is_some() {
+             self.connection_info = Some(connection_info.clone());
+             debug!("Connection info updated: {:?}", self.connection_info);
+         }
+
+         // Every event reported below is from the peer to our side of the connection.
+         let source = connection_info.dst;
+         let destination = VsockAddr {
+             cid: self.guest_cid,
+             port: connection_info.src_port,
+         };
+
+         match op {
+             VirtioVsockOp::Request => {
+                 header.check_data_is_empty()?;
+                 // TODO: Send a Rst, or support listening.
+             }
+             VirtioVsockOp::Response => {
+                 header.check_data_is_empty()?;
+                 return Ok(Some(VsockEvent {
+                     source,
+                     destination,
+                     event_type: VsockEventType::Connected,
+                 }));
+             }
+             VirtioVsockOp::CreditUpdate => {
+                 header.check_data_is_empty()?;
+                 connection_info.has_pending_credit_request = false;
+                 if self.connection_info.is_some() {
+                     self.connection_info = Some(connection_info.clone());
+                 }
+
+                 // Virtio v1.1 5.10.6.3
+                 // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
+                 // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
+                 // any time a change in buffer space occurs.
+                 continue;
+             }
+             VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+                 header.check_data_is_empty()?;
+
+                 self.connection_info = None;
+                 info!("Disconnected from the peer");
+
+                 let reason = if op == VirtioVsockOp::Rst {
+                     DisconnectReason::Reset
+                 } else {
+                     DisconnectReason::Shutdown
+                 };
+                 return Ok(Some(VsockEvent {
+                     source,
+                     destination,
+                     event_type: VsockEventType::Disconnected { reason },
+                 }));
+             }
+             VirtioVsockOp::Rw => {
+                 let length = header.len();
+                 // `fwd_cnt` is a free-running counter, so it must wrap around rather than
+                 // overflow. There is nothing to account for if no connection is tracked.
+                 if let Some(info) = self.connection_info.as_mut() {
+                     info.fwd_cnt = info.fwd_cnt.wrapping_add(length);
+                 }
+                 return Ok(Some(VsockEvent {
+                     source,
+                     destination,
+                     event_type: VsockEventType::Received {
+                         length: length as usize,
+                     },
+                 }));
+             }
+             VirtioVsockOp::CreditRequest => {
+                 header.check_data_is_empty()?;
+                 // TODO: Send a credit update.
+             }
+             VirtioVsockOp::Invalid => {
+                 return Err(SocketError::InvalidOperation.into());
+             }
+         }
+     }
+ }
+
+ /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
+ /// the body into the given buffer.
+ ///
+ /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
+ /// buffer supplied.
+ ///
+ /// The slice of the RX buffer which held the packet is added back to the RX queue before
+ /// returning, whether or not the packet could be read successfully.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the slice is given a different token when it is added back to the queue.
+ fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
+ let Some(token) = self.rx.peek_used() else {
+ return Ok(None);
+ };
+
+ // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
+ // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
+ // buffer back to the RX queue then we don't access it again until next time it is popped.
+ let header = unsafe {
+ // The token doubles as the index of the slice of the RX buffer.
+ let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?;
+ let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
+
+ // Read the header and body from the buffer. Don't check the result yet, because we need
+ // to add the buffer back to the queue either way.
+ let header_result = read_header_and_body(buffer, body);
+
+ // Add the buffer back to the RX queue.
+ let new_token = self.rx.add(&[], &mut [buffer])?;
+ // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+ // are broken and we can't safely continue to do anything with the device.
+ assert_eq!(new_token, token);
+
+ header_result
+ }?;
+
+ debug!("Received packet {:?}. Op {:?}", header, header.op());
+ Ok(Some(header))
+ }
+
+ /// Returns a copy of the state of the current connection, or `SocketError::NotConnected` if
+ /// there is none.
+ fn connection_info(&self) -> Result<ConnectionInfo> {
+     match self.connection_info.as_ref() {
+         Some(info) => Ok(info.clone()),
+         None => Err(SocketError::NotConnected.into()),
+     }
+ }
+
+ /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
+ /// virtqueue.
+ ///
+ /// The RX buffer is treated as `QUEUE_SIZE` consecutive slices of equal size, and the slice
+ /// with index `i` is returned.
+ ///
+ /// # Errors
+ ///
+ /// Returns `SocketError::InvalidNumber` if computing the bounds of the slice overflows, and
+ /// `SocketError::BufferTooShort` if the slice would extend beyond the end of `rx_buf`.
+ ///
+ /// # Safety
+ ///
+ /// `rx_buf` must be a valid dereferenceable pointer.
+ /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
+ /// any other references to the same subslice of the RX buffer or outlive the buffer.
+ unsafe fn as_mut_sub_rx_buffer<'a>(
+ mut rx_buf: NonNull<[u8]>,
+ i: usize,
+ ) -> Result<&'a mut [u8]> {
+ // Any remainder of the division is left unused at the end of the buffer.
+ let buffer_size = rx_buf.len() / QUEUE_SIZE;
+ let start = buffer_size
+ .checked_mul(i)
+ .ok_or(SocketError::InvalidNumber)?;
+ let end = start
+ .checked_add(buffer_size)
+ .ok_or(SocketError::InvalidNumber)?;
+ // Safe because no alignment or initialisation is required for [u8], and our caller assures
+ // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
+ // won't overlap with any other references to the same slice or outlive it.
+ unsafe {
+ rx_buf
+ .as_mut()
+ .get_mut(start..end)
+ .ok_or(SocketError::BufferTooShort.into())
+ }
+ }
+}
+
+/// Parses a packet header from the start of `buffer` and copies the packet body which follows
+/// it to the start of `body`.
+///
+/// Fails with `SocketError::BufferTooShort` if `buffer` cannot hold the header or the body
+/// length declared by the header, with `SocketError::InvalidNumber` if the end of the body
+/// cannot be represented, and with `SocketError::OutputBufferTooShort` if `body` is too small
+/// for the packet body.
+fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    let header = match VirtioVsockHdr::read_from_prefix(buffer) {
+        Some(header) => header,
+        None => return Err(SocketError::BufferTooShort.into()),
+    };
+    let body_length = header.len() as usize;
+
+    // Locate the body, which immediately follows the header.
+    let header_size = size_of::<VirtioVsockHdr>();
+    let payload_end = header_size
+        .checked_add(body_length)
+        .ok_or(SocketError::InvalidNumber)?;
+    let payload = buffer
+        .get(header_size..payload_end)
+        .ok_or(SocketError::BufferTooShort)?;
+
+    let destination = body
+        .get_mut(..body_length)
+        .ok_or(SocketError::OutputBufferTooShort(body_length))?;
+    destination.copy_from_slice(payload);
+    Ok(header)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::volatile::ReadOnly;
+ use crate::{
+ hal::fake::FakeHal,
+ transport::{
+ fake::{FakeTransport, QueueStatus, State},
+ DeviceStatus, DeviceType,
+ },
+ };
+ use alloc::{sync::Arc, vec};
+ use core::ptr::NonNull;
+ use std::sync::Mutex;
+
+ /// Checks that the driver can be created on a fake transport and that it reads the guest
+ /// CID from the two halves in the configuration space.
+ #[test]
+ fn config() {
+ let mut config_space = VirtioVsockConfig {
+ guest_cid_low: ReadOnly::new(66),
+ guest_cid_high: ReadOnly::new(0),
+ };
+ // One queue status for each of the RX, TX and event queues.
+ let state = Arc::new(Mutex::new(State {
+ status: DeviceStatus::empty(),
+ driver_features: 0,
+ guest_page_size: 0,
+ interrupt_pending: false,
+ queues: vec![QueueStatus::default(); 3],
+ }));
+ let transport = FakeTransport {
+ device_type: DeviceType::Socket,
+ max_queue_size: 32,
+ device_features: 0,
+ config_space: NonNull::from(&mut config_space),
+ state: state.clone(),
+ };
+ let socket =
+ VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
+ // 0x42 == 66, the low half configured above; the high half is zero.
+ assert_eq!(socket.guest_cid, 0x00_0000_0042);
+ }
+}
diff --git a/src/hal.rs b/src/hal.rs
index fd8c435..6295f5f 100644
--- a/src/hal.rs
+++ b/src/hal.rs
@@ -53,35 +53,81 @@ impl<H: Hal> Dma<H> {
impl<H: Hal> Drop for Dma<H> {
fn drop(&mut self) {
- let err = H::dma_dealloc(self.paddr, self.vaddr, self.pages);
+ // Safe because the memory was previously allocated by `dma_alloc` in `Dma::new`, not yet
+ // deallocated, and we are passing the values from then.
+ let err = unsafe { H::dma_dealloc(self.paddr, self.vaddr, self.pages) };
assert_eq!(err, 0, "failed to deallocate DMA");
}
}
/// The interface which a particular hardware implementation must implement.
-pub trait Hal {
+///
+/// # Safety
+///
+/// Implementations of this trait must follow the "implementation safety" requirements documented
+/// for each method. Callers must follow the safety requirements documented for the unsafe methods.
+pub unsafe trait Hal {
/// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use.
///
/// Returns both the physical address which the device can use to access the memory, and a
/// pointer to the start of it which the driver can use to access it.
+ ///
+ /// # Implementation safety
+ ///
+ /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+ /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to
+ /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it
+ /// is deallocated by `dma_dealloc`.
fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>);
+
/// Deallocates the given contiguous physical DMA memory pages.
- fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+ ///
+ /// # Safety
+ ///
+ /// The memory must have been allocated by `dma_alloc` on the same `Hal` implementation, and not
+ /// yet deallocated. `pages` must be the same number passed to `dma_alloc` originally, and both
+ /// `paddr` and `vaddr` must be the values returned by `dma_alloc`.
+ unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+
/// Converts a physical address used for MMIO to a virtual address which the driver can access.
///
/// This is only used for MMIO addresses within BARs read from the device, for the PCI
/// transport. It may check that the address range up to the given size is within the region
/// expected for MMIO.
- fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+ ///
+ /// # Implementation safety
+ ///
+ /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+ /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, and won't alias any
+ /// other allocations or references in the program.
+ ///
+ /// # Safety
+ ///
+ /// The `paddr` and `size` must describe a valid MMIO region. The implementation may validate it
+ /// in some way (and panic if it is invalid) but is not guaranteed to.
+ unsafe fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+
/// Shares the given memory range with the device, and returns the physical address that the
/// device can use to access it.
///
/// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
/// memory, or copying it to a special region where it can be accessed.
- fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+ ///
+ /// # Safety
+ ///
+ /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+ /// for the duration of this method call.
+ unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+
/// Unshares the given memory range from the device and (if necessary) copies it back to the
/// original buffer.
- fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+ ///
+ /// # Safety
+ ///
+ /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+ /// for the duration of this method call. The `paddr` must be the value previously returned by
+ /// the corresponding `share` call.
+ unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
}
/// The direction in which a buffer is passed.
diff --git a/src/hal/fake.rs b/src/hal/fake.rs
index 2af60a9..5d46835 100644
--- a/src/hal/fake.rs
+++ b/src/hal/fake.rs
@@ -1,5 +1,7 @@
//! Fake HAL implementation for tests.
+#![deny(unsafe_op_in_unsafe_fn)]
+
use crate::{BufferDirection, Hal, PhysAddr, PAGE_SIZE};
use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
use core::{alloc::Layout, ptr::NonNull};
@@ -8,7 +10,7 @@ use core::{alloc::Layout, ptr::NonNull};
pub struct FakeHal;
/// Fake HAL implementation for use in unit tests.
-impl Hal for FakeHal {
+unsafe impl Hal for FakeHal {
fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
assert_ne!(pages, 0);
let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
@@ -21,7 +23,7 @@ impl Hal for FakeHal {
}
}
- fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
+ unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
assert_ne!(pages, 0);
let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
// Safe because the layout is the same as was used when the memory was allocated by
@@ -32,17 +34,17 @@ impl Hal for FakeHal {
0
}
- fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+ unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
NonNull::new(paddr as _).unwrap()
}
- fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+ unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
let vaddr = buffer.as_ptr() as *mut u8 as usize;
// Nothing to do, as the host already has access to all memory.
virt_to_phys(vaddr)
}
- fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+ unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
// Nothing to do, as the host already has access to all memory and we didn't copy the buffer
// anywhere else.
}
diff --git a/src/lib.rs b/src/lib.rs
index 6a12401..754dd51 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -89,6 +89,8 @@ pub enum Error {
ConfigSpaceTooSmall,
/// The device doesn't have any config space, but the driver expects some.
ConfigSpaceMissing,
+ /// Error from the socket device.
+ SocketDeviceError(device::socket::SocketError),
}
impl Display for Error {
@@ -115,10 +117,17 @@ impl Display for Error {
"The device doesn't have any config space, but the driver expects some"
)
}
+ Self::SocketDeviceError(e) => write!(f, "Error from the socket device: {e:?}"),
}
}
}
+impl From<device::socket::SocketError> for Error {
+    /// Wraps an error from the socket device in the crate-wide error type.
+    fn from(socket_error: device::socket::SocketError) -> Self {
+        Error::SocketDeviceError(socket_error)
+    }
+}
+
/// Align `size` up to a page.
fn align_up(size: usize) -> usize {
(size + PAGE_SIZE) & !(PAGE_SIZE - 1)
diff --git a/src/queue.rs b/src/queue.rs
index f45da11..d6baf17 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -1,3 +1,5 @@
+#![deny(unsafe_op_in_unsafe_fn)]
+
use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
use crate::transport::Transport;
use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
@@ -5,7 +7,7 @@ use bitflags::bitflags;
#[cfg(test)]
use core::cmp::min;
use core::hint::spin_loop;
-use core::mem::size_of;
+use core::mem::{size_of, take};
#[cfg(test)]
use core::ptr;
use core::ptr::NonNull;
@@ -114,8 +116,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
///
/// # Safety
///
- /// The input and output buffers must remain valid until the token is returned by `pop_used`.
- pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
+ /// The input and output buffers must remain valid and not be accessed until a call to
+ /// `pop_used` with the returned token succeeds.
+ pub unsafe fn add<'a, 'b>(
+ &mut self,
+ inputs: &'a [&'b [u8]],
+ outputs: &'a mut [&'b mut [u8]],
+ ) -> Result<u16> {
if inputs.is_empty() && outputs.is_empty() {
return Err(Error::InvalidParam);
}
@@ -127,10 +134,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
let head = self.free_head;
let mut last = self.free_head;
- for (buffer, direction) in input_output_iter(inputs, outputs) {
+ for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
// Write to desc_shadow then copy.
let desc = &mut self.desc_shadow[usize::from(self.free_head)];
- desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+ // Safe because our caller promises that the buffers live at least until `pop_used`
+ // returns them.
+ unsafe {
+ desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+ }
last = self.free_head;
self.free_head = desc.next;
@@ -172,14 +183,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
/// them, then pops them.
///
/// This assumes that the device isn't processing any other buffers at the same time.
- pub fn add_notify_wait_pop(
+ pub fn add_notify_wait_pop<'a>(
&mut self,
- inputs: &[*const [u8]],
- outputs: &[*mut [u8]],
+ inputs: &'a [&'a [u8]],
+ outputs: &'a mut [&'a mut [u8]],
transport: &mut impl Transport,
) -> Result<u32> {
- // Safe because we don't return until the same token has been popped, so they remain valid
- // until then.
+ // Safe because we don't return until the same token has been popped, so the buffers remain
+ // valid and are not otherwise accessed until then.
let token = unsafe { self.add(inputs, outputs) }?;
// Notify the queue.
@@ -192,7 +203,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
spin_loop();
}
- self.pop_used(token, inputs, outputs)
+ // Safe because these are the same buffers as we passed to `add` above and they are still
+ // valid.
+ unsafe { self.pop_used(token, inputs, outputs) }
}
/// Returns whether the driver should notify the device after adding a new buffer to the
@@ -252,12 +265,22 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
/// passed in too.
///
/// This will push all linked descriptors at the front of the free list.
- fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
+ ///
+ /// # Safety
+ ///
+ /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+ /// queue by `add`.
+ unsafe fn recycle_descriptors<'a>(
+ &mut self,
+ head: u16,
+ inputs: &'a [&'a [u8]],
+ outputs: &'a mut [&'a mut [u8]],
+ ) {
let original_free_head = self.free_head;
self.free_head = head;
let mut next = Some(head);
- for (buffer, direction) in input_output_iter(inputs, outputs) {
+ for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
let desc_index = next.expect("Descriptor chain was shorter than expected.");
let desc = &mut self.desc_shadow[usize::from(desc_index)];
@@ -271,8 +294,12 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
self.write_desc(desc_index);
- // Unshare the buffer (and perhaps copy its contents back to the original buffer).
- H::unshare(paddr as usize, buffer, direction);
+ // Safe because the caller ensures that the buffer is valid and matches the descriptor
+ // from which we got `paddr`.
+ unsafe {
+ // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+ H::unshare(paddr as usize, buffer, direction);
+ }
}
if next.is_some() {
@@ -284,11 +311,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
/// length which was used (written) by the device.
///
/// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
- pub fn pop_used(
+ ///
+ /// # Safety
+ ///
+ /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+ /// queue by `add` when it returned the token being passed in here.
+ pub unsafe fn pop_used<'a>(
&mut self,
token: u16,
- inputs: &[*const [u8]],
- outputs: &[*mut [u8]],
+ inputs: &'a [&'a [u8]],
+ outputs: &'a mut [&'a mut [u8]],
) -> Result<u32> {
if !self.can_pop() {
return Err(Error::NotReady);
@@ -311,7 +343,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
return Err(Error::WrongToken);
}
- self.recycle_descriptors(index, inputs, outputs);
+ // Safe because the caller ensures the buffers are valid and match the descriptor.
+ unsafe {
+ self.recycle_descriptors(index, inputs, outputs);
+ }
self.last_used_idx = self.last_used_idx.wrapping_add(1);
Ok(len)
@@ -486,7 +521,10 @@ impl Descriptor {
direction: BufferDirection,
extra_flags: DescFlags,
) {
- self.addr = H::share(buf, direction) as u64;
+ // Safe because our caller promises that the buffer is valid.
+ unsafe {
+ self.addr = H::share(buf, direction) as u64;
+ }
self.len = buf.len() as u32;
self.flags = extra_flags
| match direction {
@@ -558,6 +596,46 @@ struct UsedElem {
len: u32,
}
+struct InputOutputIter<'a, 'b> {
+ inputs: &'a [&'b [u8]],
+ outputs: &'a mut [&'b mut [u8]],
+}
+
+impl<'a, 'b> InputOutputIter<'a, 'b> {
+ fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
+ Self { inputs, outputs }
+ }
+}
+
+impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> {
+ type Item = (NonNull<[u8]>, BufferDirection);
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if let Some(input) = take_first(&mut self.inputs) {
+ Some(((*input).into(), BufferDirection::DriverToDevice))
+ } else {
+ let output = take_first_mut(&mut self.outputs)?;
+ Some(((*output).into(), BufferDirection::DeviceToDriver))
+ }
+ }
+}
+
+// TODO: Use `slice::take_first` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
+ let (first, rem) = slice.split_first()?;
+ *slice = rem;
+ Some(first)
+}
+
+// TODO: Use `slice::take_first_mut` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
+ let (first, rem) = take(slice).split_first_mut()?;
+ *slice = rem;
+ Some(first)
+}
+
/// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
///
/// The fake device always uses descriptors in order.
@@ -680,7 +758,7 @@ mod tests {
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
assert_eq!(
- unsafe { queue.add(&[], &[]) }.unwrap_err(),
+ unsafe { queue.add(&[], &mut []) }.unwrap_err(),
Error::InvalidParam
);
}
@@ -692,7 +770,7 @@ mod tests {
let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
assert_eq!(queue.available_desc(), 4);
assert_eq!(
- unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
+ unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
Error::QueueFull
);
}
@@ -706,7 +784,7 @@ mod tests {
// Add a buffer chain consisting of two device-readable parts followed by two
// device-writable parts.
- let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
+ let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
assert_eq!(queue.available_desc(), 0);
assert!(!queue.can_pop());
@@ -757,27 +835,3 @@ mod tests {
}
}
}
-
-/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the
-/// corresponding `BufferDirection`.
-///
-/// Panics if any of the buffer pointers is null.
-fn input_output_iter<'a>(
- inputs: &'a [*const [u8]],
- outputs: &'a [*mut [u8]],
-) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
- inputs
- .iter()
- .map(|input| {
- (
- NonNull::new(*input as *mut [u8]).unwrap(),
- BufferDirection::DriverToDevice,
- )
- })
- .chain(outputs.iter().map(|output| {
- (
- NonNull::new(*output).unwrap(),
- BufferDirection::DeviceToDriver,
- )
- }))
-}
diff --git a/src/transport/fake.rs b/src/transport/fake.rs
index 1d599d8..a578db2 100644
--- a/src/transport/fake.rs
+++ b/src/transport/fake.rs
@@ -38,6 +38,10 @@ impl<C> Transport for FakeTransport<C> {
self.state.lock().unwrap().queues[queue as usize].notified = true;
}
+ fn get_status(&self) -> DeviceStatus {
+ self.state.lock().unwrap().status
+ }
+
fn set_status(&mut self, status: DeviceStatus) {
self.state.lock().unwrap().status = status;
}
diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs
index a6d421e..026646b 100644
--- a/src/transport/mmio.rs
+++ b/src/transport/mmio.rs
@@ -350,6 +350,11 @@ impl Transport for MmioTransport {
}
}
+ fn get_status(&self) -> DeviceStatus {
+ // Safe because self.header points to a valid VirtIO MMIO region.
+ unsafe { volread!(self.header, status) }
+ }
+
fn set_status(&mut self, status: DeviceStatus) {
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
@@ -442,7 +447,11 @@ impl Transport for MmioTransport {
// Safe because self.header points to a valid VirtIO MMIO region.
unsafe {
volwrite!(self.header, queue_sel, queue.into());
+
volwrite!(self.header, queue_ready, 0);
+ // Wait until we read the same value back, to ensure synchronisation (see 4.2.2.2).
+ while volread!(self.header, queue_ready) != 0 {}
+
volwrite!(self.header, queue_num, 0);
volwrite!(self.header, queue_desc_low, 0);
volwrite!(self.header, queue_desc_high, 0);
diff --git a/src/transport/mod.rs b/src/transport/mod.rs
index 013fa27..f88293c 100644
--- a/src/transport/mod.rs
+++ b/src/transport/mod.rs
@@ -26,6 +26,9 @@ pub trait Transport {
/// Notifies the given queue on the device.
fn notify(&mut self, queue: u16);
+ /// Gets the device status.
+ fn get_status(&self) -> DeviceStatus;
+
/// Sets the device status.
fn set_status(&mut self, status: DeviceStatus);
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
index f6473f8..b8bcb15 100644
--- a/src/transport/pci.rs
+++ b/src/transport/pci.rs
@@ -251,6 +251,13 @@ impl Transport for PciTransport {
}
}
+ fn get_status(&self) -> DeviceStatus {
+ // Safe because the common config pointer is valid and we checked in get_bar_region that it
+ // was aligned.
+ let status = unsafe { volread!(self.common_cfg, device_status) };
+ DeviceStatus::from_bits_truncate(status.into())
+ }
+
fn set_status(&mut self, status: DeviceStatus) {
// Safe because the common config pointer is valid and we checked in get_bar_region that it
// was aligned.
@@ -287,16 +294,9 @@ impl Transport for PciTransport {
}
}
- fn queue_unset(&mut self, queue: u16) {
- // Safe because the common config pointer is valid and we checked in get_bar_region that it
- // was aligned.
- unsafe {
- volwrite!(self.common_cfg, queue_select, queue);
- volwrite!(self.common_cfg, queue_size, 0);
- volwrite!(self.common_cfg, queue_desc, 0);
- volwrite!(self.common_cfg, queue_driver, 0);
- volwrite!(self.common_cfg, queue_device, 0);
- }
+ fn queue_unset(&mut self, _queue: u16) {
+ // The VirtIO spec doesn't allow queues to be unset once they have been set up for the PCI
+ // transport, so this is a no-op.
}
fn queue_used(&mut self, queue: u16) -> bool {
@@ -341,7 +341,8 @@ impl Transport for PciTransport {
impl Drop for PciTransport {
fn drop(&mut self) {
// Reset the device when the transport is dropped.
- self.set_status(DeviceStatus::empty())
+ self.set_status(DeviceStatus::empty());
+ while self.get_status() != DeviceStatus::empty() {}
}
}
@@ -395,7 +396,9 @@ fn get_bar_region<H: Hal, T>(
return Err(VirtioPciError::BarOffsetOutOfRange);
}
let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
- let vaddr = H::mmio_phys_to_virt(paddr, struct_info.length as usize);
+ // Safe because the paddr and size describe a valid MMIO region, at least according to the PCI
+ // bus.
+ let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
if vaddr.as_ptr() as usize % align_of::<T>() != 0 {
return Err(VirtioPciError::Misaligned {
vaddr,