diff options
author | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 05:23:55 +0000 |
---|---|---|
committer | Android Build Coastguard Worker <android-build-coastguard-worker@google.com> | 2023-07-07 05:23:55 +0000 |
commit | 6a94222a62ac87785195e74ed683def9723926e8 (patch) | |
tree | 3870c0c02033031bf9d4b1739f6dab9f16c9ac9f | |
parent | a0678c2a14af25ab2bb0cadd74ae8a929e628e54 (diff) | |
parent | 66d53d7360e4e07d84e688cd116b96ebc8c7d4fb (diff) | |
download | virtio-drivers-aml_tz5_341510010.tar.gz |
Snap for 10453563 from 66d53d7360e4e07d84e688cd116b96ebc8c7d4fb to mainline-tzdata5-release aml_tz5_341510070 aml_tz5_341510050 aml_tz5_341510010 aml_tz5_341510010
Change-Id: Idc8859ad3ba028191e62f27c4306977de752247a
-rw-r--r-- | .cargo_vcs_info.json | 2 | ||||
-rw-r--r-- | .github/workflows/main.yml | 2 | ||||
-rw-r--r-- | Android.bp | 17 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | Cargo.toml.orig | 2 | ||||
-rw-r--r-- | METADATA | 8 | ||||
-rw-r--r-- | README.md | 1 | ||||
-rw-r--r-- | cargo2android.json | 2 | ||||
-rw-r--r-- | patches/Android.bp.patch | 19 | ||||
-rw-r--r-- | src/device/blk.rs | 12 | ||||
-rw-r--r-- | src/device/common.rs | 23 | ||||
-rw-r--r-- | src/device/console.rs | 17 | ||||
-rw-r--r-- | src/device/gpu.rs | 131 | ||||
-rw-r--r-- | src/device/input.rs | 40 | ||||
-rw-r--r-- | src/device/mod.rs | 4 | ||||
-rw-r--r-- | src/device/net.rs | 226 | ||||
-rw-r--r-- | src/device/socket/error.rs | 69 | ||||
-rw-r--r-- | src/device/socket/mod.rs | 8 | ||||
-rw-r--r-- | src/device/socket/protocol.rs | 184 | ||||
-rw-r--r-- | src/device/socket/vsock.rs | 596 | ||||
-rw-r--r-- | src/hal.rs | 58 | ||||
-rw-r--r-- | src/hal/fake.rs | 12 | ||||
-rw-r--r-- | src/lib.rs | 9 | ||||
-rw-r--r-- | src/queue.rs | 148 | ||||
-rw-r--r-- | src/transport/fake.rs | 4 | ||||
-rw-r--r-- | src/transport/mmio.rs | 9 | ||||
-rw-r--r-- | src/transport/mod.rs | 3 | ||||
-rw-r--r-- | src/transport/pci.rs | 27 |
28 files changed, 1411 insertions, 224 deletions
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json index dc95465..d2f328c 100644 --- a/.cargo_vcs_info.json +++ b/.cargo_vcs_info.json @@ -1,6 +1,6 @@ { "git": { - "sha1": "70f3f76626420e854f1d7cd1dbc8060c27d848cf" + "sha1": "8e52adace55c5e082ba2effffcb70bf480d76ec0" }, "path_in_vcs": "" }
\ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b2ab149..025b7e9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -78,7 +78,7 @@ jobs: steps: - uses: actions/checkout@v2 - name: Install QEMU - run: sudo apt update && sudo apt install ${{ matrix.packages }} + run: sudo apt update && sudo apt install ${{ matrix.packages }} && sudo chmod 666 /dev/vhost-vsock - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -20,16 +20,15 @@ license { ], } -rust_library { +rust_library_rlib { name: "libvirtio_drivers", crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.3.0", + cargo_pkg_version: "0.4.0", srcs: ["src/lib.rs"], edition: "2018", - no_stdlibs: true, rustlibs: [ - "libbitflags", + "libbitflags-1.3.2", "liblog_rust_nostd", "libzerocopy_nostd", ], @@ -37,19 +36,25 @@ rust_library { "//apex_available:platform", "//apex_available:anyapex", ], + prefer_rlib: true, + no_stdlibs: true, + stdlibs: [ + "libcompiler_builtins.rust_sysroot", + "libcore.rust_sysroot", + ], } rust_test { name: "virtio-drivers_test_src_lib", crate_name: "virtio_drivers", cargo_env_compat: true, - cargo_pkg_version: "0.3.0", + cargo_pkg_version: "0.4.0", srcs: ["src/lib.rs"], test_suites: ["general-tests"], auto_gen_config: true, edition: "2018", rustlibs: [ - "libbitflags", + "libbitflags-1.3.2", "liblog_rust", "libzerocopy", ], @@ -12,7 +12,7 @@ [package] edition = "2018" name = "virtio-drivers" -version = "0.3.0" +version = "0.4.0" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", "Runji Wang <wangrunji0408@163.com>", diff --git a/Cargo.toml.orig b/Cargo.toml.orig index 01bb35c..431ccd2 100644 --- a/Cargo.toml.orig +++ b/Cargo.toml.orig @@ -1,6 +1,6 @@ [package] name = "virtio-drivers" -version = "0.3.0" +version = "0.4.0" license = "MIT" authors = [ "Jiajie Chen <noc@jiegec.ac.cn>", @@ -11,13 +11,13 @@ third_party { } url { type: ARCHIVE - value: 
"https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.3.0.crate" + value: "https://static.crates.io/crates/virtio-drivers/virtio-drivers-0.4.0.crate" } - version: "0.3.0" + version: "0.4.0" license_type: NOTICE last_upgrade_date { year: 2023 - month: 1 - day: 24 + month: 4 + day: 19 } } @@ -17,6 +17,7 @@ VirtIO guest drivers in Rust. For **no_std** environment. | GPU | ✅ | | Input | ✅ | | Console | ✅ | +| Socket | ✅ | | ... | ❌ | ### Transports diff --git a/cargo2android.json b/cargo2android.json index d7a5b20..8bd0ae6 100644 --- a/cargo2android.json +++ b/cargo2android.json @@ -2,7 +2,9 @@ "dependencies": true, "device": true, "features": "", + "force-rlib": true, "no-host": "true", + "no-std": true, "patch": "patches/Android.bp.patch", "run": true, "tests": true, diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch index cd5208e..8896c52 100644 --- a/patches/Android.bp.patch +++ b/patches/Android.bp.patch @@ -1,17 +1,26 @@ diff --git a/Android.bp b/Android.bp -index b08769d..230c659 100644 +index f635115..d111544 100644 --- a/Android.bp +++ b/Android.bp -@@ -29,9 +29,10 @@ rust_library { - cargo_pkg_version: "0.2.0", +@@ -28,9 +28,9 @@ rust_library_rlib { srcs: ["src/lib.rs"], edition: "2018", -+ no_stdlibs: true, rustlibs: [ - "libbitflags", +- "libbitflags", - "liblog_rust", - "libzerocopy", ++ "libbitflags-1.3.2", + "liblog_rust_nostd", + "libzerocopy_nostd", ], apex_available: [ + "//apex_available:platform", +@@ -54,7 +54,7 @@ rust_test { + auto_gen_config: true, + edition: "2018", + rustlibs: [ +- "libbitflags", ++ "libbitflags-1.3.2", + "liblog_rust", + "libzerocopy", + ], diff --git a/src/device/blk.rs b/src/device/blk.rs index 69528b6..d095047 100644 --- a/src/device/blk.rs +++ b/src/device/blk.rs @@ -109,7 +109,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { let mut resp = BlkResp::default(); self.queue.add_notify_wait_pop( &[req.as_bytes()], - &[buf, resp.as_bytes_mut()], + &mut [buf, resp.as_bytes_mut()], &mut self.transport, 
)?; resp.status.into() @@ -187,7 +187,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { }; let token = self .queue - .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?; + .add(&[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?; if self.queue.should_notify() { self.transport.notify(QUEUE); } @@ -208,7 +208,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { resp: &mut BlkResp, ) -> Result<()> { self.queue - .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?; + .pop_used(token, &[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?; resp.status.into() } @@ -225,7 +225,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { let mut resp = BlkResp::default(); self.queue.add_notify_wait_pop( &[req.as_bytes(), buf], - &[resp.as_bytes_mut()], + &mut [resp.as_bytes_mut()], &mut self.transport, )?; resp.status.into() @@ -268,7 +268,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { }; let token = self .queue - .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?; + .add(&[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?; if self.queue.should_notify() { self.transport.notify(QUEUE); } @@ -289,7 +289,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> { resp: &mut BlkResp, ) -> Result<()> { self.queue - .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?; + .pop_used(token, &[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?; resp.status.into() } diff --git a/src/device/common.rs b/src/device/common.rs new file mode 100644 index 0000000..2c8be3e --- /dev/null +++ b/src/device/common.rs @@ -0,0 +1,23 @@ +//! Common part shared across all the devices. + +use bitflags::bitflags; + +bitflags! 
{ + pub(crate) struct Feature: u64 { + // device independent + const NOTIFY_ON_EMPTY = 1 << 24; // legacy + const ANY_LAYOUT = 1 << 27; // legacy + const RING_INDIRECT_DESC = 1 << 28; + const RING_EVENT_IDX = 1 << 29; + const UNUSED = 1 << 30; // legacy + const VERSION_1 = 1 << 32; // detect legacy + + // since virtio v1.1 + const ACCESS_PLATFORM = 1 << 33; + const RING_PACKED = 1 << 34; + const IN_ORDER = 1 << 35; + const ORDER_PLATFORM = 1 << 36; + const SR_IOV = 1 << 37; + const NOTIFICATION_DATA = 1 << 38; + } +} diff --git a/src/device/console.rs b/src/device/console.rs index 749ebc1..e0b0356 100644 --- a/src/device/console.rs +++ b/src/device/console.rs @@ -118,7 +118,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { if self.receive_token.is_none() && self.cursor == self.pending_len { // Safe because the buffer lasts at least as long as the queue, and there are no other // outstanding requests using the buffer. - self.receive_token = Some(unsafe { self.receiveq.add(&[], &[self.queue_buf_rx]) }?); + self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?); if self.receiveq.should_notify() { self.transport.notify(QUEUE_RECEIVEQ_PORT_0); } @@ -145,13 +145,19 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { let mut flag = false; if let Some(receive_token) = self.receive_token { if self.receive_token == self.receiveq.peek_used() { - let len = self - .receiveq - .pop_used(receive_token, &[], &[self.queue_buf_rx])?; + // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in + // `poll_retrieve` and it is still valid. + let len = unsafe { + self.receiveq + .pop_used(receive_token, &[], &mut [self.queue_buf_rx])? + }; flag = true; assert_ne!(len, 0); self.cursor = 0; self.pending_len = len as usize; + // Clear `receive_token` so that when the buffer is used up the next call to + // `poll_retrieve` will add a new pending request. 
+ self.receive_token.take(); } } Ok(flag) @@ -176,9 +182,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> { /// Sends a character to the console. pub fn send(&mut self, chr: u8) -> Result<()> { let buf: [u8; 1] = [chr]; - // Safe because the buffer is valid until we pop_used below. self.transmitq - .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?; + .add_notify_wait_pop(&[&buf], &mut [], &mut self.transport)?; Ok(()) } } diff --git a/src/device/gpu.rs b/src/device/gpu.rs index eabf2d4..b1b53bd 100644 --- a/src/device/gpu.rs +++ b/src/device/gpu.rs @@ -7,6 +7,7 @@ use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly}; use crate::{pages, Error, Result}; use bitflags::bitflags; use log::info; +use zerocopy::{AsBytes, FromBytes}; const QUEUE_SIZE: u16 = 2; @@ -173,86 +174,86 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { } /// Send a request to the device and block for a response. - fn request<Req, Rsp>(&mut self, req: Req) -> Result<Rsp> { - unsafe { - (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req); - } + fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> { + req.write_to_prefix(&mut *self.queue_buf_send).unwrap(); self.control_queue.add_notify_wait_pop( &[self.queue_buf_send], - &[self.queue_buf_recv], + &mut [self.queue_buf_recv], &mut self.transport, )?; - Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() }) + Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap()) } /// Send a mouse cursor operation request to the device and block for a response. 
- fn cursor_request<Req>(&mut self, req: Req) -> Result { - unsafe { - (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req); - } - self.cursor_queue - .add_notify_wait_pop(&[self.queue_buf_send], &[], &mut self.transport)?; + fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result { + req.write_to_prefix(&mut *self.queue_buf_send).unwrap(); + self.cursor_queue.add_notify_wait_pop( + &[self.queue_buf_send], + &mut [], + &mut self.transport, + )?; Ok(()) } fn get_display_info(&mut self) -> Result<RespDisplayInfo> { - let info: RespDisplayInfo = self.request(CtrlHeader::with_type(Command::GetDisplayInfo))?; - info.header.check_type(Command::OkDisplayInfo)?; + let info: RespDisplayInfo = + self.request(CtrlHeader::with_type(Command::GET_DISPLAY_INFO))?; + info.header.check_type(Command::OK_DISPLAY_INFO)?; Ok(info) } fn resource_create_2d(&mut self, resource_id: u32, width: u32, height: u32) -> Result { let rsp: CtrlHeader = self.request(ResourceCreate2D { - header: CtrlHeader::with_type(Command::ResourceCreate2d), + header: CtrlHeader::with_type(Command::RESOURCE_CREATE_2D), resource_id, format: Format::B8G8R8A8UNORM, width, height, })?; - rsp.check_type(Command::OkNodata) + rsp.check_type(Command::OK_NODATA) } fn set_scanout(&mut self, rect: Rect, scanout_id: u32, resource_id: u32) -> Result { let rsp: CtrlHeader = self.request(SetScanout { - header: CtrlHeader::with_type(Command::SetScanout), + header: CtrlHeader::with_type(Command::SET_SCANOUT), rect, scanout_id, resource_id, })?; - rsp.check_type(Command::OkNodata) + rsp.check_type(Command::OK_NODATA) } fn resource_flush(&mut self, rect: Rect, resource_id: u32) -> Result { let rsp: CtrlHeader = self.request(ResourceFlush { - header: CtrlHeader::with_type(Command::ResourceFlush), + header: CtrlHeader::with_type(Command::RESOURCE_FLUSH), rect, resource_id, _padding: 0, })?; - rsp.check_type(Command::OkNodata) + rsp.check_type(Command::OK_NODATA) } fn transfer_to_host_2d(&mut self, rect: Rect, offset: u64, 
resource_id: u32) -> Result { let rsp: CtrlHeader = self.request(TransferToHost2D { - header: CtrlHeader::with_type(Command::TransferToHost2d), + header: CtrlHeader::with_type(Command::TRANSFER_TO_HOST_2D), rect, offset, resource_id, _padding: 0, })?; - rsp.check_type(Command::OkNodata) + rsp.check_type(Command::OK_NODATA) } fn resource_attach_backing(&mut self, resource_id: u32, paddr: u64, length: u32) -> Result { let rsp: CtrlHeader = self.request(ResourceAttachBacking { - header: CtrlHeader::with_type(Command::ResourceAttachBacking), + header: CtrlHeader::with_type(Command::RESOURCE_ATTACH_BACKING), resource_id, nr_entries: 1, addr: paddr, length, _padding: 0, })?; - rsp.check_type(Command::OkNodata) + rsp.check_type(Command::OK_NODATA) } fn update_cursor( @@ -267,9 +268,9 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> { ) -> Result { self.cursor_request(UpdateCursor { header: if is_move { - CtrlHeader::with_type(Command::MoveCursor) + CtrlHeader::with_type(Command::MOVE_CURSOR) } else { - CtrlHeader::with_type(Command::UpdateCursor) + CtrlHeader::with_type(Command::UPDATE_CURSOR) }, pos: CursorPos { scanout_id, @@ -336,39 +337,41 @@ bitflags! 
{ } } -#[repr(u32)] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum Command { - GetDisplayInfo = 0x100, - ResourceCreate2d = 0x101, - ResourceUnref = 0x102, - SetScanout = 0x103, - ResourceFlush = 0x104, - TransferToHost2d = 0x105, - ResourceAttachBacking = 0x106, - ResourceDetachBacking = 0x107, - GetCapsetInfo = 0x108, - GetCapset = 0x109, - GetEdid = 0x10a, - - UpdateCursor = 0x300, - MoveCursor = 0x301, - - OkNodata = 0x1100, - OkDisplayInfo = 0x1101, - OkCapsetInfo = 0x1102, - OkCapset = 0x1103, - OkEdid = 0x1104, - - ErrUnspec = 0x1200, - ErrOutOfMemory = 0x1201, - ErrInvalidScanoutId = 0x1202, +#[repr(transparent)] +#[derive(AsBytes, Clone, Copy, Debug, Eq, PartialEq, FromBytes)] +struct Command(u32); + +impl Command { + const GET_DISPLAY_INFO: Command = Command(0x100); + const RESOURCE_CREATE_2D: Command = Command(0x101); + const RESOURCE_UNREF: Command = Command(0x102); + const SET_SCANOUT: Command = Command(0x103); + const RESOURCE_FLUSH: Command = Command(0x104); + const TRANSFER_TO_HOST_2D: Command = Command(0x105); + const RESOURCE_ATTACH_BACKING: Command = Command(0x106); + const RESOURCE_DETACH_BACKING: Command = Command(0x107); + const GET_CAPSET_INFO: Command = Command(0x108); + const GET_CAPSET: Command = Command(0x109); + const GET_EDID: Command = Command(0x10a); + + const UPDATE_CURSOR: Command = Command(0x300); + const MOVE_CURSOR: Command = Command(0x301); + + const OK_NODATA: Command = Command(0x1100); + const OK_DISPLAY_INFO: Command = Command(0x1101); + const OK_CAPSET_INFO: Command = Command(0x1102); + const OK_CAPSET: Command = Command(0x1103); + const OK_EDID: Command = Command(0x1104); + + const ERR_UNSPEC: Command = Command(0x1200); + const ERR_OUT_OF_MEMORY: Command = Command(0x1201); + const ERR_INVALID_SCANOUT_ID: Command = Command(0x1202); } const GPU_FLAG_FENCE: u32 = 1 << 0; #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(AsBytes, Debug, Clone, Copy, FromBytes)] struct CtrlHeader { hdr_type: Command, flags: u32, @@ -399,7 
+402,7 @@ impl CtrlHeader { } #[repr(C)] -#[derive(Debug, Copy, Clone, Default)] +#[derive(AsBytes, Debug, Copy, Clone, Default, FromBytes)] struct Rect { x: u32, y: u32, @@ -408,7 +411,7 @@ struct Rect { } #[repr(C)] -#[derive(Debug)] +#[derive(Debug, FromBytes)] struct RespDisplayInfo { header: CtrlHeader, rect: Rect, @@ -417,7 +420,7 @@ struct RespDisplayInfo { } #[repr(C)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] struct ResourceCreate2D { header: CtrlHeader, resource_id: u32, @@ -427,13 +430,13 @@ struct ResourceCreate2D { } #[repr(u32)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] enum Format { B8G8R8A8UNORM = 1, } #[repr(C)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] struct ResourceAttachBacking { header: CtrlHeader, resource_id: u32, @@ -444,7 +447,7 @@ struct ResourceAttachBacking { } #[repr(C)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] struct SetScanout { header: CtrlHeader, rect: Rect, @@ -453,7 +456,7 @@ struct SetScanout { } #[repr(C)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] struct TransferToHost2D { header: CtrlHeader, rect: Rect, @@ -463,7 +466,7 @@ struct TransferToHost2D { } #[repr(C)] -#[derive(Debug)] +#[derive(AsBytes, Debug)] struct ResourceFlush { header: CtrlHeader, rect: Rect, @@ -472,7 +475,7 @@ struct ResourceFlush { } #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(AsBytes, Debug, Clone, Copy)] struct CursorPos { scanout_id: u32, x: u32, @@ -481,7 +484,7 @@ struct CursorPos { } #[repr(C)] -#[derive(Debug, Clone, Copy)] +#[derive(AsBytes, Debug, Clone, Copy)] struct UpdateCursor { header: CtrlHeader, pos: CursorPos, diff --git a/src/device/input.rs b/src/device/input.rs index 8554282..dee2fec 100644 --- a/src/device/input.rs +++ b/src/device/input.rs @@ -1,12 +1,12 @@ //! Driver for VirtIO input devices. 
+use super::common::Feature; use crate::hal::Hal; use crate::queue::VirtQueue; use crate::transport::Transport; use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly}; use crate::Result; use alloc::boxed::Box; -use bitflags::bitflags; use core::ptr::NonNull; use log::info; use zerocopy::{AsBytes, FromBytes}; @@ -42,7 +42,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> { let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?; for (i, event) in event_buf.as_mut().iter_mut().enumerate() { // Safe because the buffer lasts as long as the queue. - let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? }; + let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? }; assert_eq!(token, i as u16); } if event_queue.should_notify() { @@ -69,12 +69,18 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> { pub fn pop_pending_event(&mut self) -> Option<InputEvent> { if let Some(token) = self.event_queue.peek_used() { let event = &mut self.event_buf[token as usize]; - self.event_queue - .pop_used(token, &[], &[event.as_bytes_mut()]) - .ok()?; + // Safe because we are passing the same buffer as we passed to `VirtQueue::add` and it + // is still valid. + unsafe { + self.event_queue + .pop_used(token, &[], &mut [event.as_bytes_mut()]) + .ok()?; + } + let event_saved = *event; // requeue // Safe because buffer lasts as long as the queue. - if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } { + if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_bytes_mut()]) } + { // This only works because nothing happen between `pop_used` and `add` that affects // the list of free descriptors in the queue, so `add` reuses the descriptor which // was just freed by `pop_used`. 
@@ -82,7 +88,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> { if self.event_queue.should_notify() { self.transport.notify(QUEUE_EVENT); } - return Some(*event); + return Some(event_saved); } } None @@ -185,26 +191,6 @@ pub struct InputEvent { pub value: u32, } -bitflags! { - struct Feature: u64 { - // device independent - const NOTIFY_ON_EMPTY = 1 << 24; // legacy - const ANY_LAYOUT = 1 << 27; // legacy - const RING_INDIRECT_DESC = 1 << 28; - const RING_EVENT_IDX = 1 << 29; - const UNUSED = 1 << 30; // legacy - const VERSION_1 = 1 << 32; // detect legacy - - // since virtio v1.1 - const ACCESS_PLATFORM = 1 << 33; - const RING_PACKED = 1 << 34; - const IN_ORDER = 1 << 35; - const ORDER_PLATFORM = 1 << 36; - const SR_IOV = 1 << 37; - const NOTIFICATION_DATA = 1 << 38; - } -} - const QUEUE_EVENT: u16 = 0; const QUEUE_STATUS: u16 = 1; diff --git a/src/device/mod.rs b/src/device/mod.rs index f3e4f66..ca68901 100644 --- a/src/device/mod.rs +++ b/src/device/mod.rs @@ -5,4 +5,8 @@ pub mod console; pub mod gpu; #[cfg(feature = "alloc")] pub mod input; +#[cfg(feature = "alloc")] pub mod net; +pub mod socket; + +pub(crate) mod common; diff --git a/src/device/net.rs b/src/device/net.rs index 7ca487e..4441f63 100644 --- a/src/device/net.rs +++ b/src/device/net.rs @@ -4,13 +4,95 @@ use crate::hal::Hal; use crate::queue::VirtQueue; use crate::transport::Transport; use crate::volatile::{volread, ReadOnly}; -use crate::Result; +use crate::{Error, Result}; +use alloc::{vec, vec::Vec}; use bitflags::bitflags; -use core::mem::{size_of, MaybeUninit}; -use log::{debug, info}; +use core::{convert::TryInto, mem::size_of}; +use log::{debug, info, warn}; use zerocopy::{AsBytes, FromBytes}; -const QUEUE_SIZE: u16 = 2; +const MAX_BUFFER_LEN: usize = 65535; +const MIN_BUFFER_LEN: usize = 1526; +const NET_HDR_SIZE: usize = size_of::<VirtioNetHdr>(); + +/// A buffer used for transmitting. +pub struct TxBuffer(Vec<u8>); + +/// A buffer used for receiving. 
+pub struct RxBuffer { + buf: Vec<usize>, // for alignment + packet_len: usize, + idx: u16, +} + +impl TxBuffer { + /// Constructs the buffer from the given slice. + pub fn from(buf: &[u8]) -> Self { + Self(Vec::from(buf)) + } + + /// Returns the network packet length. + pub fn packet_len(&self) -> usize { + self.0.len() + } + + /// Returns the network packet as a slice. + pub fn packet(&self) -> &[u8] { + self.0.as_slice() + } + + /// Returns the network packet as a mutable slice. + pub fn packet_mut(&mut self) -> &mut [u8] { + self.0.as_mut_slice() + } +} + +impl RxBuffer { + /// Allocates a new buffer with length `buf_len`. + fn new(idx: usize, buf_len: usize) -> Self { + Self { + buf: vec![0; buf_len / size_of::<usize>()], + packet_len: 0, + idx: idx.try_into().unwrap(), + } + } + + /// Set the network packet length. + fn set_packet_len(&mut self, packet_len: usize) { + self.packet_len = packet_len + } + + /// Returns the network packet length (without header). + pub const fn packet_len(&self) -> usize { + self.packet_len + } + + /// Returns all data in the buffer, including both the header and the packet. + pub fn as_bytes(&self) -> &[u8] { + self.buf.as_bytes() + } + + /// Returns all data in the buffer with the mutable reference, + /// including both the header and the packet. + pub fn as_bytes_mut(&mut self) -> &mut [u8] { + self.buf.as_bytes_mut() + } + + /// Returns the reference of the header. + pub fn header(&self) -> &VirtioNetHdr { + unsafe { &*(self.buf.as_ptr() as *const VirtioNetHdr) } + } + + /// Returns the network packet as a slice. + pub fn packet(&self) -> &[u8] { + &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len] + } + + /// Returns the network packet as a mutable slice. + pub fn packet_mut(&mut self) -> &mut [u8] { + &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len] + } +} /// The virtio network device is a virtual ethernet card.
/// @@ -19,16 +101,17 @@ const QUEUE_SIZE: u16 = 2; /// Empty buffers are placed in one virtqueue for receiving packets, and /// outgoing packets are enqueued into another for transmission in that order. /// A third command queue is used to control advanced filtering features. -pub struct VirtIONet<H: Hal, T: Transport> { +pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> { transport: T, mac: EthernetAddress, - recv_queue: VirtQueue<H, { QUEUE_SIZE as usize }>, - send_queue: VirtQueue<H, { QUEUE_SIZE as usize }>, + recv_queue: VirtQueue<H, QUEUE_SIZE>, + send_queue: VirtQueue<H, QUEUE_SIZE>, + rx_buffers: [Option<RxBuffer>; QUEUE_SIZE], } -impl<H: Hal, T: Transport> VirtIONet<H, T> { +impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> { /// Create a new VirtIO-Net driver. - pub fn new(mut transport: T) -> Result<Self> { + pub fn new(mut transport: T, buf_len: usize) -> Result<Self> { transport.begin_init(|features| { let features = Features::from_bits_truncate(features); info!("Device features {:?}", features); @@ -41,11 +124,37 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> { // Safe because config points to a valid MMIO region for the config space. 
unsafe { mac = volread!(config, mac); - debug!("Got MAC={:?}, status={:?}", mac, volread!(config, status)); + debug!( + "Got MAC={:02x?}, status={:?}", + mac, + volread!(config, status) + ); + } + + if !(MIN_BUFFER_LEN..=MAX_BUFFER_LEN).contains(&buf_len) { + warn!( + "Receive buffer len {} is not in range [{}, {}]", + buf_len, MIN_BUFFER_LEN, MAX_BUFFER_LEN + ); + return Err(Error::InvalidParam); } - let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?; let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?; + let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?; + + const NONE_BUF: Option<RxBuffer> = None; + let mut rx_buffers = [NONE_BUF; QUEUE_SIZE]; + for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() { + let mut rx_buf = RxBuffer::new(i, buf_len); + // Safe because the buffer lives as long as the queue. + let token = unsafe { recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()])? }; + assert_eq!(token, i as u16); + *rx_buf_place = Some(rx_buf); + } + + if recv_queue.should_notify() { + transport.notify(QUEUE_RECEIVE); + } transport.finish_init(); @@ -54,6 +163,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> { mac, recv_queue, send_queue, + rx_buffers, }) } @@ -63,7 +173,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> { } /// Get MAC address. - pub fn mac(&self) -> EthernetAddress { + pub fn mac_address(&self) -> EthernetAddress { self.mac } @@ -77,27 +187,72 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> { self.recv_queue.can_pop() } - /// Receive a packet. - pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> { - let mut header = MaybeUninit::<Header>::uninit(); - let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() }; - let len = - self.recv_queue - .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?; - // let header = unsafe { header.assume_init() }; - Ok(len as usize - size_of::<Header>()) + /// Receives a [`RxBuffer`] from network. 
If currently no data, returns an + /// error with type [`Error::NotReady`]. + /// + /// It will try to pop a buffer that completed data reception in the + /// NIC queue. + pub fn receive(&mut self) -> Result<RxBuffer> { + if let Some(token) = self.recv_queue.peek_used() { + let mut rx_buf = self.rx_buffers[token as usize] + .take() + .ok_or(Error::WrongToken)?; + if token != rx_buf.idx { + return Err(Error::WrongToken); + } + + // Safe because `token` == `rx_buf.idx`, we are passing the same + // buffer as we passed to `VirtQueue::add` and it is still valid. + let len = unsafe { + self.recv_queue + .pop_used(token, &[], &mut [rx_buf.as_bytes_mut()])? + } as usize; + rx_buf.set_packet_len(len.checked_sub(NET_HDR_SIZE).ok_or(Error::IoError)?); + Ok(rx_buf) + } else { + Err(Error::NotReady) + } } - /// Send a packet. - pub fn send(&mut self, buf: &[u8]) -> Result { - let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() }; - self.send_queue - .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?; + /// Gives back the ownership of `rx_buf`, and recycles it for next use. + /// + /// It will add the buffer back to the NIC queue. + pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result { + // Safe because we take the ownership of `rx_buf` back to `rx_buffers`, + // it lives as long as the queue. + let new_token = unsafe { self.recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()]) }?; + // `rx_buffers[new_token]` is expected to be `None` since it was taken + // away at `Self::receive()` and has not been added back. + if self.rx_buffers[new_token as usize].is_some() { + return Err(Error::WrongToken); + } + rx_buf.idx = new_token; + self.rx_buffers[new_token as usize] = Some(rx_buf); + if self.recv_queue.should_notify() { + self.transport.notify(QUEUE_RECEIVE); + } + Ok(()) + } + + /// Allocate a new buffer for transmitting. 
+ pub fn new_tx_buffer(&self, buf_len: usize) -> TxBuffer { + TxBuffer(vec![0; buf_len]) + } + + /// Sends a [`TxBuffer`] to the network, and blocks until the request + /// completed. + pub fn send(&mut self, tx_buf: TxBuffer) -> Result { + let header = VirtioNetHdr::default(); + self.send_queue.add_notify_wait_pop( + &[header.as_bytes(), tx_buf.packet()], + &mut [], + &mut self.transport, + )?; Ok(()) } } -impl<H: Hal, T: Transport> Drop for VirtIONet<H, T> { +impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUEUE_SIZE> { fn drop(&mut self) { // Clear any pointers pointing to DMA regions, so the device doesn't try to access them // after they have been freed. @@ -185,26 +340,33 @@ bitflags! { struct Config { mac: ReadOnly<EthernetAddress>, status: ReadOnly<Status>, + max_virtqueue_pairs: ReadOnly<u16>, + mtu: ReadOnly<u16>, } type EthernetAddress = [u8; 6]; -// virtio 5.1.6 Device Operation +/// VirtIO 5.1.6 Device Operation: +/// +/// Packets are transmitted by placing them in the transmitq1. . .transmitqN, +/// and buffers for incoming packets are placed in the receiveq1. . .receiveqN. +/// In each case, the packet itself is preceded by a header. #[repr(C)] -#[derive(AsBytes, Debug, FromBytes)] -struct Header { +#[derive(AsBytes, Debug, Default, FromBytes)] +pub struct VirtioNetHdr { flags: Flags, gso_type: GsoType, hdr_len: u16, // cannot rely on this gso_size: u16, csum_start: u16, csum_offset: u16, + // num_buffers: u16, // only available when the feature MRG_RXBUF is negotiated. // payload starts from here } bitflags! { #[repr(transparent)] - #[derive(AsBytes, FromBytes)] + #[derive(AsBytes, Default, FromBytes)] struct Flags: u8 { const NEEDS_CSUM = 1; const DATA_VALID = 2; @@ -213,7 +375,7 @@ bitflags! 
{ } #[repr(transparent)] -#[derive(AsBytes, Debug, Copy, Clone, Eq, FromBytes, PartialEq)] +#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, PartialEq)] struct GsoType(u8); impl GsoType { diff --git a/src/device/socket/error.rs b/src/device/socket/error.rs new file mode 100644 index 0000000..4beec38 --- /dev/null +++ b/src/device/socket/error.rs @@ -0,0 +1,69 @@ +//! This module contain the error from the VirtIO socket driver. + +use core::{fmt, result}; + +/// The error type of VirtIO socket driver. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum SocketError { + /// There is an existing connection. + ConnectionExists, + /// Failed to establish the connection. + ConnectionFailed, + /// The device is not connected to any peer. + NotConnected, + /// Peer socket is shutdown. + PeerSocketShutdown, + /// No response received. + NoResponseReceived, + /// The given buffer is shorter than expected. + BufferTooShort, + /// The given buffer for output is shorter than expected. + OutputBufferTooShort(usize), + /// The given buffer has exceeded the maximum buffer size. + BufferTooLong(usize, usize), + /// Unknown operation. + UnknownOperation(u16), + /// Invalid operation, + InvalidOperation, + /// Invalid number. + InvalidNumber, + /// Unexpected data in packet. + UnexpectedDataInPacket, + /// Peer has insufficient buffer space, try again later. + InsufficientBufferSpaceInPeer, + /// Recycled a wrong buffer. + RecycledWrongBuffer, +} + +impl fmt::Display for SocketError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::ConnectionExists => write!( + f, + "There is an existing connection. Please close the current connection before attempting to connect again."), + Self::ConnectionFailed => write!( + f, "Failed to establish the connection. The packet sent may have an unknown type value" + ), + Self::NotConnected => write!(f, "The device is not connected to any peer. 
Please connect it to a peer first."), + Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."), + Self::NoResponseReceived => write!(f, "No response received"), + Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"), + Self::BufferTooLong(actual, max) => { + write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'") + } + Self::OutputBufferTooShort(expected) => { + write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.") + } + Self::UnknownOperation(op) => { + write!(f, "The operation code '{op}' is unknown") + } + Self::InvalidOperation => write!(f, "Invalid operation"), + Self::InvalidNumber => write!(f, "Invalid number"), + Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"), + Self::InsufficientBufferSpaceInPeer => write!(f, "Peer has insufficient buffer space, try again later"), + Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"), + } + } +} + +pub type Result<T> = result::Result<T, SocketError>; diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs new file mode 100644 index 0000000..65280aa --- /dev/null +++ b/src/device/socket/mod.rs @@ -0,0 +1,8 @@ +//! This module implements the virtio vsock device. + +mod error; +mod protocol; +mod vsock; + +pub use error::SocketError; +pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType}; diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs new file mode 100644 index 0000000..abc1702 --- /dev/null +++ b/src/device/socket/protocol.rs @@ -0,0 +1,184 @@ +//! 
This module defines the socket device protocol according to the virtio spec v1.1 5.10 Socket Device + +use super::error::{self, SocketError}; +use crate::volatile::ReadOnly; +use core::{ + convert::{TryFrom, TryInto}, + fmt, +}; +use zerocopy::{ + byteorder::{LittleEndian, U16, U32, U64}, + AsBytes, FromBytes, +}; + +/// Currently only stream sockets are supported. type is 1 for stream socket types. +#[derive(Copy, Clone, Debug)] +#[repr(u16)] +pub enum SocketType { + /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries. + Stream = 1, +} + +impl From<SocketType> for U16<LittleEndian> { + fn from(socket_type: SocketType) -> Self { + (socket_type as u16).into() + } +} + +/// VirtioVsockConfig is the vsock device configuration space. +#[repr(C)] +pub struct VirtioVsockConfig { + /// The guest_cid field contains the guest’s context ID, which uniquely identifies + /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + /// + /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space, + /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic. + /// So we need to split the u64 guest_cid into two parts. + pub guest_cid_low: ReadOnly<u32>, + pub guest_cid_high: ReadOnly<u32>, +} + +/// The message header for data packets sent on the tx/rx queues +#[repr(packed)] +#[derive(AsBytes, Clone, Copy, Debug, FromBytes)] +pub struct VirtioVsockHdr { + pub src_cid: U64<LittleEndian>, + pub dst_cid: U64<LittleEndian>, + pub src_port: U32<LittleEndian>, + pub dst_port: U32<LittleEndian>, + pub len: U32<LittleEndian>, + pub socket_type: U16<LittleEndian>, + pub op: U16<LittleEndian>, + pub flags: U32<LittleEndian>, + /// Total receive buffer space for this socket. This includes both free and in-use buffers. + pub buf_alloc: U32<LittleEndian>, + /// Free-running bytes received counter. 
+ pub fwd_cnt: U32<LittleEndian>, +} + +impl Default for VirtioVsockHdr { + fn default() -> Self { + Self { + src_cid: 0.into(), + dst_cid: 0.into(), + src_port: 0.into(), + dst_port: 0.into(), + len: 0.into(), + socket_type: SocketType::Stream.into(), + op: 0.into(), + flags: 0.into(), + buf_alloc: 0.into(), + fwd_cnt: 0.into(), + } + } +} + +impl VirtioVsockHdr { + /// Returns the length of the data. + pub fn len(&self) -> u32 { + u32::from(self.len) + } + + pub fn op(&self) -> error::Result<VirtioVsockOp> { + self.op.try_into() + } + + pub fn source(&self) -> VsockAddr { + VsockAddr { + cid: self.src_cid.get(), + port: self.src_port.get(), + } + } + + pub fn destination(&self) -> VsockAddr { + VsockAddr { + cid: self.dst_cid.get(), + port: self.dst_port.get(), + } + } + + pub fn check_data_is_empty(&self) -> error::Result<()> { + if self.len() == 0 { + Ok(()) + } else { + Err(SocketError::UnexpectedDataInPacket) + } + } +} + +/// Socket address. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] +pub struct VsockAddr { + /// Context Identifier. + pub cid: u64, + /// Port number. 
+ pub port: u32, +} + +/// An event sent to the event queue +#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)] +#[repr(C)] +pub struct VirtioVsockEvent { + // ID from the virtio_vsock_event_id struct in the virtio spec + pub id: U32<LittleEndian>, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +#[repr(u16)] +pub enum VirtioVsockOp { + Invalid = 0, + + /* Connect operations */ + Request = 1, + Response = 2, + Rst = 3, + Shutdown = 4, + + /* To send payload */ + Rw = 5, + + /* Tell the peer our credit info */ + CreditUpdate = 6, + /* Request the peer to send the credit info to us */ + CreditRequest = 7, +} + +impl From<VirtioVsockOp> for U16<LittleEndian> { + fn from(op: VirtioVsockOp) -> Self { + (op as u16).into() + } +} + +impl TryFrom<U16<LittleEndian>> for VirtioVsockOp { + type Error = SocketError; + + fn try_from(v: U16<LittleEndian>) -> Result<Self, Self::Error> { + let op = match u16::from(v) { + 0 => Self::Invalid, + 1 => Self::Request, + 2 => Self::Response, + 3 => Self::Rst, + 4 => Self::Shutdown, + 5 => Self::Rw, + 6 => Self::CreditUpdate, + 7 => Self::CreditRequest, + _ => return Err(SocketError::UnknownOperation(v.into())), + }; + Ok(op) + } +} + +impl fmt::Debug for VirtioVsockOp { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Invalid => write!(f, "VIRTIO_VSOCK_OP_INVALID"), + Self::Request => write!(f, "VIRTIO_VSOCK_OP_REQUEST"), + Self::Response => write!(f, "VIRTIO_VSOCK_OP_RESPONSE"), + Self::Rst => write!(f, "VIRTIO_VSOCK_OP_RST"), + Self::Shutdown => write!(f, "VIRTIO_VSOCK_OP_SHUTDOWN"), + Self::Rw => write!(f, "VIRTIO_VSOCK_OP_RW"), + Self::CreditUpdate => write!(f, "VIRTIO_VSOCK_OP_CREDIT_UPDATE"), + Self::CreditRequest => write!(f, "VIRTIO_VSOCK_OP_CREDIT_REQUEST"), + } + } +} diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs new file mode 100644 index 0000000..686d7a6 --- /dev/null +++ b/src/device/socket/vsock.rs @@ -0,0 +1,596 @@ +//! Driver for VirtIO socket devices. 
+#![deny(unsafe_op_in_unsafe_fn)] + +use super::error::SocketError; +use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr}; +use crate::device::common::Feature; +use crate::hal::{BufferDirection, Dma, Hal}; +use crate::queue::VirtQueue; +use crate::transport::Transport; +use crate::volatile::volread; +use crate::Result; +use core::hint::spin_loop; +use core::mem::size_of; +use core::ptr::NonNull; +use log::{debug, info}; +use zerocopy::{AsBytes, FromBytes}; + +const RX_QUEUE_IDX: u16 = 0; +const TX_QUEUE_IDX: u16 = 1; +const EVENT_QUEUE_IDX: u16 = 2; + +const QUEUE_SIZE: usize = 8; + +#[derive(Clone, Debug, Default, PartialEq, Eq)] +struct ConnectionInfo { + dst: VsockAddr, + src_port: u32, + /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in + /// bytes it has allocated for packet bodies. + peer_buf_alloc: u32, + /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it + /// has finished processing. + peer_fwd_cnt: u32, + /// The number of bytes of packet bodies which we have sent to the peer. + tx_cnt: u32, + /// The number of bytes of packet bodies which we have received from the peer and handled. + fwd_cnt: u32, + /// Whether we have recently requested credit from the peer. + /// + /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we + /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`. + has_pending_credit_request: bool, +} + +impl ConnectionInfo { + fn peer_free(&self) -> u32 { + self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt) + } + + fn new_header(&self, src_cid: u64) -> VirtioVsockHdr { + VirtioVsockHdr { + src_cid: src_cid.into(), + dst_cid: self.dst.cid.into(), + src_port: self.src_port.into(), + dst_port: self.dst.port.into(), + fwd_cnt: self.fwd_cnt.into(), + ..Default::default() + } + } +} + +/// An event received from a VirtIO socket device. 
+#[derive(Clone, Debug, Eq, PartialEq)] +pub struct VsockEvent { + /// The source of the event, i.e. the peer who sent it. + pub source: VsockAddr, + /// The destination of the event, i.e. the CID and port on our side. + pub destination: VsockAddr, + /// The type of event. + pub event_type: VsockEventType, +} + +/// The reason why a vsock connection was closed. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum DisconnectReason { + /// The peer has either closed the connection in response to our shutdown request, or forcibly + /// closed it of its own accord. + Reset, + /// The peer asked to shut down the connection. + Shutdown, +} + +/// Details of the type of an event received from a VirtIO socket. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum VsockEventType { + /// The connection was successfully established. + Connected, + /// The connection was closed. + Disconnected { + /// The reason for the disconnection. + reason: DisconnectReason, + }, + /// Data was received on the connection. + Received { + /// The length of the data in bytes. + length: usize, + }, +} + +/// Driver for a VirtIO socket device. +pub struct VirtIOSocket<H: Hal, T: Transport> { + transport: T, + /// Virtqueue to receive packets. + rx: VirtQueue<H, { QUEUE_SIZE }>, + tx: VirtQueue<H, { QUEUE_SIZE }>, + /// Virtqueue to receive events from the device. + event: VirtQueue<H, { QUEUE_SIZE }>, + /// The guest_cid field contains the guest’s context ID, which uniquely identifies + /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed. + guest_cid: u64, + rx_buf_dma: Dma<H>, + + /// Currently the device is only allowed to be connected to one destination at a time. + connection_info: Option<ConnectionInfo>, +} + +impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> { + fn drop(&mut self) { + // Clear any pointers pointing to DMA regions, so the device doesn't try to access them + // after they have been freed. 
+ self.transport.queue_unset(RX_QUEUE_IDX); + self.transport.queue_unset(TX_QUEUE_IDX); + self.transport.queue_unset(EVENT_QUEUE_IDX); + } +} + +impl<H: Hal, T: Transport> VirtIOSocket<H, T> { + /// Create a new VirtIO Vsock driver. + pub fn new(mut transport: T) -> Result<Self> { + transport.begin_init(|features| { + let features = Feature::from_bits_truncate(features); + info!("Device features: {:?}", features); + // negotiate these flags only + let supported_features = Feature::empty(); + (features & supported_features).bits() + }); + + let config = transport.config_space::<VirtioVsockConfig>()?; + info!("config: {:?}", config); + // Safe because config is a valid pointer to the device configuration space. + let guest_cid = unsafe { + volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32 + }; + info!("guest cid: {guest_cid:?}"); + + let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?; + let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?; + let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?; + + // Allocates 4 KiB memory as the rx buffer. + let rx_buf_dma = Dma::new( + 1, // pages + BufferDirection::DeviceToDriver, + )?; + let rx_buf = rx_buf_dma.raw_slice(); + // Safe because `rx_buf` lives as long as the `rx` queue. + unsafe { + Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?; + } + transport.finish_init(); + + Ok(Self { + transport, + rx, + tx, + event, + guest_cid, + rx_buf_dma, + connection_info: None, + }) + } + + /// Fills the `rx` queue with the buffer `rx_buf`. + /// + /// # Safety + /// + /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are + /// in the queue must not be used anywhere else at the same time. 
+ unsafe fn fill_rx_queue( + rx: &mut VirtQueue<H, { QUEUE_SIZE }>, + rx_buf: NonNull<[u8]>, + transport: &mut T, + ) -> Result { + if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE { + return Err(SocketError::BufferTooShort.into()); + } + for i in 0..QUEUE_SIZE { + // Safe because the buffer lives as long as the queue, as specified in the function + // safety requirement, and we don't access it until it is popped. + unsafe { + let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?; + let token = rx.add(&[], &mut [buffer])?; + assert_eq!(i, token.into()); + } + } + + if rx.should_notify() { + transport.notify(RX_QUEUE_IDX); + } + Ok(()) + } + + /// Sends a request to connect to the given destination. + /// + /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// `VsockEventType::Connected` event indicating that the peer has accepted the connection + /// before sending data. + pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result { + if self.connection_info.is_some() { + return Err(SocketError::ConnectionExists.into()); + } + let new_connection_info = ConnectionInfo { + dst: VsockAddr { + cid: dst_cid, + port: dst_port, + }, + src_port, + ..Default::default() + }; + let header = VirtioVsockHdr { + op: VirtioVsockOp::Request.into(), + ..new_connection_info.new_header(self.guest_cid) + }; + // Sends a header only packet to the tx queue to connect the device to the listening + // socket at the given destination. + self.send_packet_to_tx_queue(&header, &[])?; + + self.connection_info = Some(new_connection_info); + debug!("Connection requested: {:?}", self.connection_info); + Ok(()) + } + + /// Blocks until the peer either accepts our connection request (with a + /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a + /// `VIRTIO_VSOCK_OP_RST`). 
+ pub fn wait_for_connect(&mut self) -> Result { + match self.wait_for_recv(&mut [])?.event_type { + VsockEventType::Connected => Ok(()), + VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()), + VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()), + } + } + + /// Requests the peer to send us a credit update for the current connection. + fn request_credit(&mut self) -> Result { + let connection_info = self.connection_info()?; + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditRequest.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Sends the buffer to the destination. + pub fn send(&mut self, buffer: &[u8]) -> Result { + let mut connection_info = self.connection_info()?; + + let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len()); + self.connection_info = Some(connection_info.clone()); + result?; + + let len = buffer.len() as u32; + let header = VirtioVsockHdr { + op: VirtioVsockOp::Rw.into(), + len: len.into(), + buf_alloc: 0.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.connection_info.as_mut().unwrap().tx_cnt += len; + self.send_packet_to_tx_queue(&header, buffer) + } + + fn check_peer_buffer_is_sufficient( + &mut self, + connection_info: &mut ConnectionInfo, + buffer_len: usize, + ) -> Result { + if connection_info.peer_free() as usize >= buffer_len { + Ok(()) + } else { + // Request an update of the cached peer credit, if we haven't already done so, and tell + // the caller to try again later. + if !connection_info.has_pending_credit_request { + self.request_credit()?; + connection_info.has_pending_credit_request = true; + } + Err(SocketError::InsufficientBufferSpaceInPeer.into()) + } + } + + /// Polls the vsock device to receive data or other updates. + /// + /// A buffer must be provided to put the data in if there is some to + /// receive. 
+ pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> { + let connection_info = self.connection_info()?; + + // Tell the peer that we have space to receive some data. + let header = VirtioVsockHdr { + op: VirtioVsockOp::CreditUpdate.into(), + buf_alloc: (buffer.len() as u32).into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[])?; + + // Handle entries from the RX virtqueue until we find one that generates an event. + let event = self.poll_rx_queue(buffer)?; + + if self.rx.should_notify() { + self.transport.notify(RX_QUEUE_IDX); + } + + Ok(event) + } + + /// Blocks until we get some event from the vsock device. + /// + /// A buffer must be provided to put the data in if there is some to + /// receive. + pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> { + loop { + if let Some(event) = self.poll_recv(buffer)? { + return Ok(event); + } else { + spin_loop(); + } + } + } + + /// Request to shut down the connection cleanly. + /// + /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a + /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the + /// shutdown. + pub fn shutdown(&mut self) -> Result { + let connection_info = self.connection_info()?; + let header = VirtioVsockHdr { + op: VirtioVsockOp::Shutdown.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[]) + } + + /// Forcibly closes the connection without waiting for the peer. 
+ pub fn force_close(&mut self) -> Result { + let connection_info = self.connection_info()?; + let header = VirtioVsockHdr { + op: VirtioVsockOp::Rst.into(), + ..connection_info.new_header(self.guest_cid) + }; + self.send_packet_to_tx_queue(&header, &[])?; + self.connection_info = None; + Ok(()) + } + + fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result { + let _len = self.tx.add_notify_wait_pop( + &[header.as_bytes(), buffer], + &mut [], + &mut self.transport, + )?; + Ok(()) + } + + /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet + /// which generates a `VsockEvent`. + /// + /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which + /// don't result in any events to return. + fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> { + loop { + let mut connection_info = self.connection_info.clone().unwrap_or_default(); + let Some(header) = self.pop_packet_from_rx_queue(body)? else{ + return Ok(None); + }; + + let op = header.op()?; + + // Skip packets which don't match our current connection. + if header.source() != connection_info.dst + || header.dst_cid.get() != self.guest_cid + || header.dst_port.get() != connection_info.src_port + { + debug!( + "Skipping {:?} as connection is {:?}", + header, connection_info + ); + continue; + } + + connection_info.peer_buf_alloc = header.buf_alloc.into(); + connection_info.peer_fwd_cnt = header.fwd_cnt.into(); + if self.connection_info.is_some() { + self.connection_info = Some(connection_info.clone()); + debug!("Connection info updated: {:?}", self.connection_info); + } + + match op { + VirtioVsockOp::Request => { + header.check_data_is_empty()?; + // TODO: Send a Rst, or support listening. 
+ } + VirtioVsockOp::Response => { + header.check_data_is_empty()?; + return Ok(Some(VsockEvent { + source: connection_info.dst, + destination: VsockAddr { + cid: self.guest_cid, + port: connection_info.src_port, + }, + event_type: VsockEventType::Connected, + })); + } + VirtioVsockOp::CreditUpdate => { + header.check_data_is_empty()?; + connection_info.has_pending_credit_request = false; + if self.connection_info.is_some() { + self.connection_info = Some(connection_info.clone()); + } + + // Virtio v1.1 5.10.6.3 + // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously + // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates + // any time a change in buffer space occurs. + continue; + } + VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => { + header.check_data_is_empty()?; + + self.connection_info = None; + info!("Disconnected from the peer"); + + let reason = if op == VirtioVsockOp::Rst { + DisconnectReason::Reset + } else { + DisconnectReason::Shutdown + }; + return Ok(Some(VsockEvent { + source: connection_info.dst, + destination: VsockAddr { + cid: self.guest_cid, + port: connection_info.src_port, + }, + event_type: VsockEventType::Disconnected { reason }, + })); + } + VirtioVsockOp::Rw => { + self.connection_info.as_mut().unwrap().fwd_cnt += header.len(); + return Ok(Some(VsockEvent { + source: connection_info.dst, + destination: VsockAddr { + cid: self.guest_cid, + port: connection_info.src_port, + }, + event_type: VsockEventType::Received { + length: header.len() as usize, + }, + })); + } + VirtioVsockOp::CreditRequest => { + header.check_data_is_empty()?; + // TODO: Send a credit update. + } + VirtioVsockOp::Invalid => { + return Err(SocketError::InvalidOperation.into()); + } + } + } + } + + /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies + /// the body into the given buffer. 
+ /// + /// Returns `None` if there is no pending packet, or an error if the body is bigger than the + /// buffer supplied. + fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> { + let Some(token) = self.rx.peek_used() else { + return Ok(None); + }; + + // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same + // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the + // buffer back to the RX queue then we don't access it again until next time it is popped. + let header = unsafe { + let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?; + let _len = self.rx.pop_used(token, &[], &mut [buffer])?; + + // Read the header and body from the buffer. Don't check the result yet, because we need + // to add the buffer back to the queue either way. + let header_result = read_header_and_body(buffer, body); + + // Add the buffer back to the RX queue. + let new_token = self.rx.add(&[], &mut [buffer])?; + // If the RX buffer somehow gets assigned a different token, then our safety assumptions + // are broken and we can't safely continue to do anything with the device. + assert_eq!(new_token, token); + + header_result + }?; + + debug!("Received packet {:?}. Op {:?}", header, header.op()); + Ok(Some(header)) + } + + fn connection_info(&self) -> Result<ConnectionInfo> { + self.connection_info + .clone() + .ok_or(SocketError::NotConnected.into()) + } + + /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX + /// virtqueue. + /// + /// # Safety + /// + /// `rx_buf` must be a valid dereferenceable pointer. + /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with + /// any other references to the same subslice of the RX buffer or outlive the buffer. 
+ unsafe fn as_mut_sub_rx_buffer<'a>( + mut rx_buf: NonNull<[u8]>, + i: usize, + ) -> Result<&'a mut [u8]> { + let buffer_size = rx_buf.len() / QUEUE_SIZE; + let start = buffer_size + .checked_mul(i) + .ok_or(SocketError::InvalidNumber)?; + let end = start + .checked_add(buffer_size) + .ok_or(SocketError::InvalidNumber)?; + // Safe because no alignment or initialisation is required for [u8], and our caller assures + // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating + // won't overlap with any other references to the same slice or outlive it. + unsafe { + rx_buf + .as_mut() + .get_mut(start..end) + .ok_or(SocketError::BufferTooShort.into()) + } + } +} + +fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> { + let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?; + let body_length = header.len() as usize; + let data_end = size_of::<VirtioVsockHdr>() + .checked_add(body_length) + .ok_or(SocketError::InvalidNumber)?; + let data = buffer + .get(size_of::<VirtioVsockHdr>()..data_end) + .ok_or(SocketError::BufferTooShort)?; + body.get_mut(0..body_length) + .ok_or(SocketError::OutputBufferTooShort(body_length))? 
+ .copy_from_slice(data); + Ok(header) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::volatile::ReadOnly; + use crate::{ + hal::fake::FakeHal, + transport::{ + fake::{FakeTransport, QueueStatus, State}, + DeviceStatus, DeviceType, + }, + }; + use alloc::{sync::Arc, vec}; + use core::ptr::NonNull; + use std::sync::Mutex; + + #[test] + fn config() { + let mut config_space = VirtioVsockConfig { + guest_cid_low: ReadOnly::new(66), + guest_cid_high: ReadOnly::new(0), + }; + let state = Arc::new(Mutex::new(State { + status: DeviceStatus::empty(), + driver_features: 0, + guest_page_size: 0, + interrupt_pending: false, + queues: vec![QueueStatus::default(); 3], + })); + let transport = FakeTransport { + device_type: DeviceType::Socket, + max_queue_size: 32, + device_features: 0, + config_space: NonNull::from(&mut config_space), + state: state.clone(), + }; + let socket = + VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(); + assert_eq!(socket.guest_cid, 0x00_0000_0042); + } +} @@ -53,35 +53,81 @@ impl<H: Hal> Dma<H> { impl<H: Hal> Drop for Dma<H> { fn drop(&mut self) { - let err = H::dma_dealloc(self.paddr, self.vaddr, self.pages); + // Safe because the memory was previously allocated by `dma_alloc` in `Dma::new`, not yet + // deallocated, and we are passing the values from then. + let err = unsafe { H::dma_dealloc(self.paddr, self.vaddr, self.pages) }; assert_eq!(err, 0, "failed to deallocate DMA"); } } /// The interface which a particular hardware implementation must implement. -pub trait Hal { +/// +/// # Safety +/// +/// Implementations of this trait must follow the "implementation safety" requirements documented +/// for each method. Callers must follow the safety requirements documented for the unsafe methods. +pub unsafe trait Hal { /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use. 
/// /// Returns both the physical address which the device can use to access the memory, and a /// pointer to the start of it which the driver can use to access it. + /// + /// # Implementation safety + /// + /// Implementations of this method must ensure that the `NonNull<u8>` returned is a + /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to + /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it + /// is deallocated by `dma_dealloc`. fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>); + /// Deallocates the given contiguous physical DMA memory pages. - fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32; + /// + /// # Safety + /// + /// The memory must have been allocated by `dma_alloc` on the same `Hal` implementation, and not + /// yet deallocated. `pages` must be the same number passed to `dma_alloc` originally, and both + /// `paddr` and `vaddr` must be the values returned by `dma_alloc`. + unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32; + /// Converts a physical address used for MMIO to a virtual address which the driver can access. /// /// This is only used for MMIO addresses within BARs read from the device, for the PCI /// transport. It may check that the address range up to the given size is within the region /// expected for MMIO. - fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>; + /// + /// # Implementation safety + /// + /// Implementations of this method must ensure that the `NonNull<u8>` returned is a + /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, and won't alias any + /// other allocations or references in the program. + /// + /// # Safety + /// + /// The `paddr` and `size` must describe a valid MMIO region. The implementation may validate it + /// in some way (and panic if it is invalid) but is not guaranteed to. 
+ unsafe fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>; + /// Shares the given memory range with the device, and returns the physical address that the /// device can use to access it. /// /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the /// memory, or copying it to a special region where it can be accessed. - fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr; + /// + /// # Safety + /// + /// The buffer must be a valid pointer to memory which will not be accessed by any other thread + /// for the duration of this method call. + unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr; + /// Unshares the given memory range from the device and (if necessary) copies it back to the /// original buffer. - fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection); + /// + /// # Safety + /// + /// The buffer must be a valid pointer to memory which will not be accessed by any other thread + /// for the duration of this method call. The `paddr` must be the value previously returned by + /// the corresponding `share` call. + unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection); } /// The direction in which a buffer is passed. diff --git a/src/hal/fake.rs b/src/hal/fake.rs index 2af60a9..5d46835 100644 --- a/src/hal/fake.rs +++ b/src/hal/fake.rs @@ -1,5 +1,7 @@ //! Fake HAL implementation for tests. +#![deny(unsafe_op_in_unsafe_fn)] + use crate::{BufferDirection, Hal, PhysAddr, PAGE_SIZE}; use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error}; use core::{alloc::Layout, ptr::NonNull}; @@ -8,7 +10,7 @@ use core::{alloc::Layout, ptr::NonNull}; pub struct FakeHal; /// Fake HAL implementation for use in unit tests. 
-impl Hal for FakeHal { +unsafe impl Hal for FakeHal { fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) { assert_ne!(pages, 0); let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap(); @@ -21,7 +23,7 @@ impl Hal for FakeHal { } } - fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 { + unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 { assert_ne!(pages, 0); let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap(); // Safe because the layout is the same as was used when the memory was allocated by @@ -32,17 +34,17 @@ impl Hal for FakeHal { 0 } - fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> { + unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> { NonNull::new(paddr as _).unwrap() } - fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr { + unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr { let vaddr = buffer.as_ptr() as *mut u8 as usize; // Nothing to do, as the host already has access to all memory. virt_to_phys(vaddr) } - fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) { + unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) { // Nothing to do, as the host already has access to all memory and we didn't copy the buffer // anywhere else. } @@ -89,6 +89,8 @@ pub enum Error { ConfigSpaceTooSmall, /// The device doesn't have any config space, but the driver expects some. ConfigSpaceMissing, + /// Error from the socket device. 
+ SocketDeviceError(device::socket::SocketError), } impl Display for Error { @@ -115,10 +117,17 @@ impl Display for Error { "The device doesn't have any config space, but the driver expects some" ) } + Self::SocketDeviceError(e) => write!(f, "Error from the socket device: {e:?}"), } } } +impl From<device::socket::SocketError> for Error { + fn from(e: device::socket::SocketError) -> Self { + Self::SocketDeviceError(e) + } +} + /// Align `size` up to a page. fn align_up(size: usize) -> usize { (size + PAGE_SIZE) & !(PAGE_SIZE - 1) diff --git a/src/queue.rs b/src/queue.rs index f45da11..d6baf17 100644 --- a/src/queue.rs +++ b/src/queue.rs @@ -1,3 +1,5 @@ +#![deny(unsafe_op_in_unsafe_fn)] + use crate::hal::{BufferDirection, Dma, Hal, PhysAddr}; use crate::transport::Transport; use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE}; @@ -5,7 +7,7 @@ use bitflags::bitflags; #[cfg(test)] use core::cmp::min; use core::hint::spin_loop; -use core::mem::size_of; +use core::mem::{size_of, take}; #[cfg(test)] use core::ptr; use core::ptr::NonNull; @@ -114,8 +116,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { /// /// # Safety /// - /// The input and output buffers must remain valid until the token is returned by `pop_used`. - pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> { + /// The input and output buffers must remain valid and not be accessed until a call to + /// `pop_used` with the returned token succeeds. 
+ pub unsafe fn add<'a, 'b>( + &mut self, + inputs: &'a [&'b [u8]], + outputs: &'a mut [&'b mut [u8]], + ) -> Result<u16> { if inputs.is_empty() && outputs.is_empty() { return Err(Error::InvalidParam); } @@ -127,10 +134,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { let head = self.free_head; let mut last = self.free_head; - for (buffer, direction) in input_output_iter(inputs, outputs) { + for (buffer, direction) in InputOutputIter::new(inputs, outputs) { // Write to desc_shadow then copy. let desc = &mut self.desc_shadow[usize::from(self.free_head)]; - desc.set_buf::<H>(buffer, direction, DescFlags::NEXT); + // Safe because our caller promises that the buffers live at least until `pop_used` + // returns them. + unsafe { + desc.set_buf::<H>(buffer, direction, DescFlags::NEXT); + } last = self.free_head; self.free_head = desc.next; @@ -172,14 +183,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { /// them, then pops them. /// /// This assumes that the device isn't processing any other buffers at the same time. - pub fn add_notify_wait_pop( + pub fn add_notify_wait_pop<'a>( &mut self, - inputs: &[*const [u8]], - outputs: &[*mut [u8]], + inputs: &'a [&'a [u8]], + outputs: &'a mut [&'a mut [u8]], transport: &mut impl Transport, ) -> Result<u32> { - // Safe because we don't return until the same token has been popped, so they remain valid - // until then. + // Safe because we don't return until the same token has been popped, so the buffers remain + // valid and are not otherwise accessed until then. let token = unsafe { self.add(inputs, outputs) }?; // Notify the queue. @@ -192,7 +203,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { spin_loop(); } - self.pop_used(token, inputs, outputs) + // Safe because these are the same buffers as we passed to `add` above and they are still + // valid. 
+ unsafe { self.pop_used(token, inputs, outputs) } } /// Returns whether the driver should notify the device after adding a new buffer to the @@ -252,12 +265,22 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { /// passed in too. /// /// This will push all linked descriptors at the front of the free list. - fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) { + /// + /// # Safety + /// + /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the + /// queue by `add`. + unsafe fn recycle_descriptors<'a>( + &mut self, + head: u16, + inputs: &'a [&'a [u8]], + outputs: &'a mut [&'a mut [u8]], + ) { let original_free_head = self.free_head; self.free_head = head; let mut next = Some(head); - for (buffer, direction) in input_output_iter(inputs, outputs) { + for (buffer, direction) in InputOutputIter::new(inputs, outputs) { let desc_index = next.expect("Descriptor chain was shorter than expected."); let desc = &mut self.desc_shadow[usize::from(desc_index)]; @@ -271,8 +294,12 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { self.write_desc(desc_index); - // Unshare the buffer (and perhaps copy its contents back to the original buffer). - H::unshare(paddr as usize, buffer, direction); + // Safe because the caller ensures that the buffer is valid and matches the descriptor + // from which we got `paddr`. + unsafe { + // Unshare the buffer (and perhaps copy its contents back to the original buffer). + H::unshare(paddr as usize, buffer, direction); + } } if next.is_some() { @@ -284,11 +311,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { /// length which was used (written) by the device. /// /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx - pub fn pop_used( + /// + /// # Safety + /// + /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the + /// queue by `add` when it returned the token being passed in here. 
+ pub unsafe fn pop_used<'a>( &mut self, token: u16, - inputs: &[*const [u8]], - outputs: &[*mut [u8]], + inputs: &'a [&'a [u8]], + outputs: &'a mut [&'a mut [u8]], ) -> Result<u32> { if !self.can_pop() { return Err(Error::NotReady); @@ -311,7 +343,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> { return Err(Error::WrongToken); } - self.recycle_descriptors(index, inputs, outputs); + // Safe because the caller ensures the buffers are valid and match the descriptor. + unsafe { + self.recycle_descriptors(index, inputs, outputs); + } self.last_used_idx = self.last_used_idx.wrapping_add(1); Ok(len) @@ -486,7 +521,10 @@ impl Descriptor { direction: BufferDirection, extra_flags: DescFlags, ) { - self.addr = H::share(buf, direction) as u64; + // Safe because our caller promises that the buffer is valid. + unsafe { + self.addr = H::share(buf, direction) as u64; + } self.len = buf.len() as u32; self.flags = extra_flags | match direction { @@ -558,6 +596,46 @@ struct UsedElem { len: u32, } +struct InputOutputIter<'a, 'b> { + inputs: &'a [&'b [u8]], + outputs: &'a mut [&'b mut [u8]], +} + +impl<'a, 'b> InputOutputIter<'a, 'b> { + fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self { + Self { inputs, outputs } + } +} + +impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> { + type Item = (NonNull<[u8]>, BufferDirection); + + fn next(&mut self) -> Option<Self::Item> { + if let Some(input) = take_first(&mut self.inputs) { + Some(((*input).into(), BufferDirection::DriverToDevice)) + } else { + let output = take_first_mut(&mut self.outputs)?; + Some(((*output).into(), BufferDirection::DeviceToDriver)) + } + } +} + +// TODO: Use `slice::take_first` once it is stable +// (https://github.com/rust-lang/rust/issues/62280). 
+fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> { + let (first, rem) = slice.split_first()?; + *slice = rem; + Some(first) +} + +// TODO: Use `slice::take_first_mut` once it is stable +// (https://github.com/rust-lang/rust/issues/62280). +fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> { + let (first, rem) = take(slice).split_first_mut()?; + *slice = rem; + Some(first) +} + /// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests. /// /// The fake device always uses descriptors in order. @@ -680,7 +758,7 @@ mod tests { let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap(); let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); assert_eq!( - unsafe { queue.add(&[], &[]) }.unwrap_err(), + unsafe { queue.add(&[], &mut []) }.unwrap_err(), Error::InvalidParam ); } @@ -692,7 +770,7 @@ mod tests { let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap(); assert_eq!(queue.available_desc(), 4); assert_eq!( - unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(), + unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(), Error::QueueFull ); } @@ -706,7 +784,7 @@ mod tests { // Add a buffer chain consisting of two device-readable parts followed by two // device-writable parts. - let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap(); + let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap(); assert_eq!(queue.available_desc(), 0); assert!(!queue.can_pop()); @@ -757,27 +835,3 @@ mod tests { } } } - -/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the -/// corresponding `BufferDirection`. -/// -/// Panics if any of the buffer pointers is null. 
-fn input_output_iter<'a>( - inputs: &'a [*const [u8]], - outputs: &'a [*mut [u8]], -) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a { - inputs - .iter() - .map(|input| { - ( - NonNull::new(*input as *mut [u8]).unwrap(), - BufferDirection::DriverToDevice, - ) - }) - .chain(outputs.iter().map(|output| { - ( - NonNull::new(*output).unwrap(), - BufferDirection::DeviceToDriver, - ) - })) -} diff --git a/src/transport/fake.rs b/src/transport/fake.rs index 1d599d8..a578db2 100644 --- a/src/transport/fake.rs +++ b/src/transport/fake.rs @@ -38,6 +38,10 @@ impl<C> Transport for FakeTransport<C> { self.state.lock().unwrap().queues[queue as usize].notified = true; } + fn get_status(&self) -> DeviceStatus { + self.state.lock().unwrap().status + } + fn set_status(&mut self, status: DeviceStatus) { self.state.lock().unwrap().status = status; } diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs index a6d421e..026646b 100644 --- a/src/transport/mmio.rs +++ b/src/transport/mmio.rs @@ -350,6 +350,11 @@ impl Transport for MmioTransport { } } + fn get_status(&self) -> DeviceStatus { + // Safe because self.header points to a valid VirtIO MMIO region. + unsafe { volread!(self.header, status) } + } + fn set_status(&mut self, status: DeviceStatus) { // Safe because self.header points to a valid VirtIO MMIO region. unsafe { @@ -442,7 +447,11 @@ impl Transport for MmioTransport { // Safe because self.header points to a valid VirtIO MMIO region. unsafe { volwrite!(self.header, queue_sel, queue.into()); + volwrite!(self.header, queue_ready, 0); + // Wait until we read the same value back, to ensure synchronisation (see 4.2.2.2). 
+ while volread!(self.header, queue_ready) != 0 {} + volwrite!(self.header, queue_num, 0); volwrite!(self.header, queue_desc_low, 0); volwrite!(self.header, queue_desc_high, 0); diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 013fa27..f88293c 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -26,6 +26,9 @@ pub trait Transport { /// Notifies the given queue on the device. fn notify(&mut self, queue: u16); + /// Gets the device status. + fn get_status(&self) -> DeviceStatus; + /// Sets the device status. fn set_status(&mut self, status: DeviceStatus); diff --git a/src/transport/pci.rs b/src/transport/pci.rs index f6473f8..b8bcb15 100644 --- a/src/transport/pci.rs +++ b/src/transport/pci.rs @@ -251,6 +251,13 @@ impl Transport for PciTransport { } } + fn get_status(&self) -> DeviceStatus { + // Safe because the common config pointer is valid and we checked in get_bar_region that it + // was aligned. + let status = unsafe { volread!(self.common_cfg, device_status) }; + DeviceStatus::from_bits_truncate(status.into()) + } + fn set_status(&mut self, status: DeviceStatus) { // Safe because the common config pointer is valid and we checked in get_bar_region that it // was aligned. @@ -287,16 +294,9 @@ impl Transport for PciTransport { } } - fn queue_unset(&mut self, queue: u16) { - // Safe because the common config pointer is valid and we checked in get_bar_region that it - // was aligned. - unsafe { - volwrite!(self.common_cfg, queue_select, queue); - volwrite!(self.common_cfg, queue_size, 0); - volwrite!(self.common_cfg, queue_desc, 0); - volwrite!(self.common_cfg, queue_driver, 0); - volwrite!(self.common_cfg, queue_device, 0); - } + fn queue_unset(&mut self, _queue: u16) { + // The VirtIO spec doesn't allow queues to be unset once they have been set up for the PCI + // transport, so this is a no-op. 
} fn queue_used(&mut self, queue: u16) -> bool { @@ -341,7 +341,8 @@ impl Transport for PciTransport { impl Drop for PciTransport { fn drop(&mut self) { // Reset the device when the transport is dropped. - self.set_status(DeviceStatus::empty()) + self.set_status(DeviceStatus::empty()); + while self.get_status() != DeviceStatus::empty() {} } } @@ -395,7 +396,9 @@ fn get_bar_region<H: Hal, T>( return Err(VirtioPciError::BarOffsetOutOfRange); } let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr; - let vaddr = H::mmio_phys_to_virt(paddr, struct_info.length as usize); + // Safe because the paddr and size describe a valid MMIO region, at least according to the PCI + // bus. + let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) }; if vaddr.as_ptr() as usize % align_of::<T>() != 0 { return Err(VirtioPciError::Misaligned { vaddr, |