aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/device/blk.rs10
-rw-r--r--src/device/console.rs18
-rw-r--r--src/device/gpu.rs12
-rw-r--r--src/device/input.rs18
-rw-r--r--src/device/mod.rs3
-rw-r--r--src/device/net.rs413
-rw-r--r--src/device/net/dev.rs125
-rw-r--r--src/device/net/dev_raw.rs281
-rw-r--r--src/device/net/mod.rs156
-rw-r--r--src/device/net/net_buf.rs83
-rw-r--r--src/device/socket/connectionmanager.rs57
-rw-r--r--src/device/socket/mod.rs4
-rw-r--r--src/device/socket/protocol.rs19
-rw-r--r--src/device/socket/vsock.rs65
-rw-r--r--src/hal.rs7
-rw-r--r--src/queue.rs149
-rw-r--r--src/transport/mmio.rs7
-rw-r--r--src/transport/pci.rs13
-rw-r--r--src/transport/pci/bus.rs7
19 files changed, 949 insertions, 498 deletions
diff --git a/src/device/blk.rs b/src/device/blk.rs
index bfdc5f8..cbd2dcc 100644
--- a/src/device/blk.rs
+++ b/src/device/blk.rs
@@ -96,6 +96,16 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
self.transport.ack_interrupt()
}
+ /// Enables interrupts from the device.
+ pub fn enable_interrupts(&mut self) {
+ self.queue.set_dev_notify(true);
+ }
+
+ /// Disables interrupts from the device.
+ pub fn disable_interrupts(&mut self) {
+ self.queue.set_dev_notify(false);
+ }
+
/// Sends the given request to the device and waits for a response, with no extra data.
fn request(&mut self, request: BlkReq) -> Result {
let mut resp = BlkResp::default();
diff --git a/src/device/console.rs b/src/device/console.rs
index 6528276..f73e0d1 100644
--- a/src/device/console.rs
+++ b/src/device/console.rs
@@ -12,7 +12,7 @@ use core::ptr::NonNull;
const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
const QUEUE_TRANSMITQ_PORT_0: u16 = 1;
const QUEUE_SIZE: usize = 2;
-const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX;
+const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX.union(Features::RING_INDIRECT_DESC);
/// Driver for a VirtIO console device.
///
@@ -51,6 +51,18 @@ pub struct VirtIOConsole<H: Hal, T: Transport> {
receive_token: Option<u16>,
}
+// SAFETY: The config space can be accessed from any thread.
+unsafe impl<H: Hal, T: Transport + Send> Send for VirtIOConsole<H, T> where
+ VirtQueue<H, QUEUE_SIZE>: Send
+{
+}
+
+// SAFETY: A `&VirtIOConsole` only allows reading the config space.
+unsafe impl<H: Hal, T: Transport + Sync> Sync for VirtIOConsole<H, T> where
+ VirtQueue<H, QUEUE_SIZE>: Sync
+{
+}
+
/// Information about a console device, read from its configuration space.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ConsoleInfo {
@@ -70,13 +82,13 @@ impl<H: Hal, T: Transport> VirtIOConsole<H, T> {
let receiveq = VirtQueue::new(
&mut transport,
QUEUE_RECEIVEQ_PORT_0,
- false,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
negotiated_features.contains(Features::RING_EVENT_IDX),
)?;
let transmitq = VirtQueue::new(
&mut transport,
QUEUE_TRANSMITQ_PORT_0,
- false,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
negotiated_features.contains(Features::RING_EVENT_IDX),
)?;
diff --git a/src/device/gpu.rs b/src/device/gpu.rs
index 6a49298..0cef8b4 100644
--- a/src/device/gpu.rs
+++ b/src/device/gpu.rs
@@ -11,7 +11,7 @@ use log::info;
use zerocopy::{AsBytes, FromBytes, FromZeroes};
const QUEUE_SIZE: u16 = 2;
-const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX;
+const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX.union(Features::RING_INDIRECT_DESC);
/// A virtio based graphics adapter.
///
@@ -56,13 +56,13 @@ impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
let control_queue = VirtQueue::new(
&mut transport,
QUEUE_TRANSMIT,
- false,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
negotiated_features.contains(Features::RING_EVENT_IDX),
)?;
let cursor_queue = VirtQueue::new(
&mut transport,
QUEUE_CURSOR,
- false,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
negotiated_features.contains(Features::RING_EVENT_IDX),
)?;
@@ -174,18 +174,18 @@ impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
/// Send a request to the device and block for a response.
fn request<Req: AsBytes, Rsp: FromBytes>(&mut self, req: Req) -> Result<Rsp> {
- req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
+ req.write_to_prefix(&mut self.queue_buf_send).unwrap();
self.control_queue.add_notify_wait_pop(
&[&self.queue_buf_send],
&mut [&mut self.queue_buf_recv],
&mut self.transport,
)?;
- Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())
+ Ok(Rsp::read_from_prefix(&self.queue_buf_recv).unwrap())
}
/// Send a mouse cursor operation request to the device and block for a response.
fn cursor_request<Req: AsBytes>(&mut self, req: Req) -> Result {
- req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
+ req.write_to_prefix(&mut self.queue_buf_send).unwrap();
self.cursor_queue.add_notify_wait_pop(
&[&self.queue_buf_send],
&mut [],
diff --git a/src/device/input.rs b/src/device/input.rs
index 0d5799e..f8ee95a 100644
--- a/src/device/input.rs
+++ b/src/device/input.rs
@@ -35,13 +35,13 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
let mut event_queue = VirtQueue::new(
&mut transport,
QUEUE_EVENT,
- false,
+ negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let status_queue = VirtQueue::new(
&mut transport,
QUEUE_STATUS,
- false,
+ negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
@@ -120,6 +120,18 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
}
}
+// SAFETY: The config space can be accessed from any thread.
+unsafe impl<H: Hal, T: Transport + Send> Send for VirtIOInput<H, T> where
+ VirtQueue<H, QUEUE_SIZE>: Send
+{
+}
+
+// SAFETY: A `&VirtIOInput` can't do anything, all methods take `&mut self`.
+unsafe impl<H: Hal, T: Transport + Sync> Sync for VirtIOInput<H, T> where
+ VirtQueue<H, QUEUE_SIZE>: Sync
+{
+}
+
impl<H: Hal, T: Transport> Drop for VirtIOInput<H, T> {
fn drop(&mut self) {
// Clear any pointers pointing to DMA regions, so the device doesn't try to access them
@@ -197,7 +209,7 @@ pub struct InputEvent {
const QUEUE_EVENT: u16 = 0;
const QUEUE_STATUS: u16 = 1;
-const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX;
+const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC);
// a parameter that can change
const QUEUE_SIZE: usize = 32;
diff --git a/src/device/mod.rs b/src/device/mod.rs
index 00fa6fe..d8e6389 100644
--- a/src/device/mod.rs
+++ b/src/device/mod.rs
@@ -7,8 +7,9 @@ pub mod console;
pub mod gpu;
#[cfg(feature = "alloc")]
pub mod input;
-#[cfg(feature = "alloc")]
+
pub mod net;
+
pub mod socket;
pub(crate) mod common;
diff --git a/src/device/net.rs b/src/device/net.rs
deleted file mode 100644
index c7732a4..0000000
--- a/src/device/net.rs
+++ /dev/null
@@ -1,413 +0,0 @@
-//! Driver for VirtIO network devices.
-
-use crate::hal::Hal;
-use crate::queue::VirtQueue;
-use crate::transport::Transport;
-use crate::volatile::{volread, ReadOnly};
-use crate::{Error, Result};
-use alloc::{vec, vec::Vec};
-use bitflags::bitflags;
-use core::{convert::TryInto, mem::size_of};
-use log::{debug, warn};
-use zerocopy::{AsBytes, FromBytes, FromZeroes};
-
-const MAX_BUFFER_LEN: usize = 65535;
-const MIN_BUFFER_LEN: usize = 1526;
-const NET_HDR_SIZE: usize = size_of::<VirtioNetHdr>();
-
-/// A buffer used for transmitting.
-pub struct TxBuffer(Vec<u8>);
-
-/// A buffer used for receiving.
-pub struct RxBuffer {
- buf: Vec<usize>, // for alignment
- packet_len: usize,
- idx: u16,
-}
-
-impl TxBuffer {
- /// Constructs the buffer from the given slice.
- pub fn from(buf: &[u8]) -> Self {
- Self(Vec::from(buf))
- }
-
- /// Returns the network packet length.
- pub fn packet_len(&self) -> usize {
- self.0.len()
- }
-
- /// Returns the network packet as a slice.
- pub fn packet(&self) -> &[u8] {
- self.0.as_slice()
- }
-
- /// Returns the network packet as a mutable slice.
- pub fn packet_mut(&mut self) -> &mut [u8] {
- self.0.as_mut_slice()
- }
-}
-
-impl RxBuffer {
- /// Allocates a new buffer with length `buf_len`.
- fn new(idx: usize, buf_len: usize) -> Self {
- Self {
- buf: vec![0; buf_len / size_of::<usize>()],
- packet_len: 0,
- idx: idx.try_into().unwrap(),
- }
- }
-
- /// Set the network packet length.
- fn set_packet_len(&mut self, packet_len: usize) {
- self.packet_len = packet_len
- }
-
- /// Returns the network packet length (witout header).
- pub const fn packet_len(&self) -> usize {
- self.packet_len
- }
-
- /// Returns all data in the buffer, including both the header and the packet.
- pub fn as_bytes(&self) -> &[u8] {
- self.buf.as_bytes()
- }
-
- /// Returns all data in the buffer with the mutable reference,
- /// including both the header and the packet.
- pub fn as_bytes_mut(&mut self) -> &mut [u8] {
- self.buf.as_bytes_mut()
- }
-
- /// Returns the reference of the header.
- pub fn header(&self) -> &VirtioNetHdr {
- unsafe { &*(self.buf.as_ptr() as *const VirtioNetHdr) }
- }
-
- /// Returns the network packet as a slice.
- pub fn packet(&self) -> &[u8] {
- &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
- }
-
- /// Returns the network packet as a mutable slice.
- pub fn packet_mut(&mut self) -> &mut [u8] {
- &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
- }
-}
-
-/// The virtio network device is a virtual ethernet card.
-///
-/// It has enhanced rapidly and demonstrates clearly how support for new
-/// features are added to an existing device.
-/// Empty buffers are placed in one virtqueue for receiving packets, and
-/// outgoing packets are enqueued into another for transmission in that order.
-/// A third command queue is used to control advanced filtering features.
-pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
- transport: T,
- mac: EthernetAddress,
- recv_queue: VirtQueue<H, QUEUE_SIZE>,
- send_queue: VirtQueue<H, QUEUE_SIZE>,
- rx_buffers: [Option<RxBuffer>; QUEUE_SIZE],
-}
-
-impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
- /// Create a new VirtIO-Net driver.
- pub fn new(mut transport: T, buf_len: usize) -> Result<Self> {
- let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
- // read configuration space
- let config = transport.config_space::<Config>()?;
- let mac;
- // Safe because config points to a valid MMIO region for the config space.
- unsafe {
- mac = volread!(config, mac);
- debug!(
- "Got MAC={:02x?}, status={:?}",
- mac,
- volread!(config, status)
- );
- }
-
- if !(MIN_BUFFER_LEN..=MAX_BUFFER_LEN).contains(&buf_len) {
- warn!(
- "Receive buffer len {} is not in range [{}, {}]",
- buf_len, MIN_BUFFER_LEN, MAX_BUFFER_LEN
- );
- return Err(Error::InvalidParam);
- }
-
- let send_queue = VirtQueue::new(
- &mut transport,
- QUEUE_TRANSMIT,
- false,
- negotiated_features.contains(Features::RING_EVENT_IDX),
- )?;
- let mut recv_queue = VirtQueue::new(
- &mut transport,
- QUEUE_RECEIVE,
- false,
- negotiated_features.contains(Features::RING_EVENT_IDX),
- )?;
-
- const NONE_BUF: Option<RxBuffer> = None;
- let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
- for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() {
- let mut rx_buf = RxBuffer::new(i, buf_len);
- // Safe because the buffer lives as long as the queue.
- let token = unsafe { recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()])? };
- assert_eq!(token, i as u16);
- *rx_buf_place = Some(rx_buf);
- }
-
- if recv_queue.should_notify() {
- transport.notify(QUEUE_RECEIVE);
- }
-
- transport.finish_init();
-
- Ok(VirtIONet {
- transport,
- mac,
- recv_queue,
- send_queue,
- rx_buffers,
- })
- }
-
- /// Acknowledge interrupt.
- pub fn ack_interrupt(&mut self) -> bool {
- self.transport.ack_interrupt()
- }
-
- /// Get MAC address.
- pub fn mac_address(&self) -> EthernetAddress {
- self.mac
- }
-
- /// Whether can send packet.
- pub fn can_send(&self) -> bool {
- self.send_queue.available_desc() >= 2
- }
-
- /// Whether can receive packet.
- pub fn can_recv(&self) -> bool {
- self.recv_queue.can_pop()
- }
-
- /// Receives a [`RxBuffer`] from network. If currently no data, returns an
- /// error with type [`Error::NotReady`].
- ///
- /// It will try to pop a buffer that completed data reception in the
- /// NIC queue.
- pub fn receive(&mut self) -> Result<RxBuffer> {
- if let Some(token) = self.recv_queue.peek_used() {
- let mut rx_buf = self.rx_buffers[token as usize]
- .take()
- .ok_or(Error::WrongToken)?;
- if token != rx_buf.idx {
- return Err(Error::WrongToken);
- }
-
- // Safe because `token` == `rx_buf.idx`, we are passing the same
- // buffer as we passed to `VirtQueue::add` and it is still valid.
- let len = unsafe {
- self.recv_queue
- .pop_used(token, &[], &mut [rx_buf.as_bytes_mut()])?
- } as usize;
- rx_buf.set_packet_len(len.checked_sub(NET_HDR_SIZE).ok_or(Error::IoError)?);
- Ok(rx_buf)
- } else {
- Err(Error::NotReady)
- }
- }
-
- /// Gives back the ownership of `rx_buf`, and recycles it for next use.
- ///
- /// It will add the buffer back to the NIC queue.
- pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result {
- // Safe because we take the ownership of `rx_buf` back to `rx_buffers`,
- // it lives as long as the queue.
- let new_token = unsafe { self.recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()]) }?;
- // `rx_buffers[new_token]` is expected to be `None` since it was taken
- // away at `Self::receive()` and has not been added back.
- if self.rx_buffers[new_token as usize].is_some() {
- return Err(Error::WrongToken);
- }
- rx_buf.idx = new_token;
- self.rx_buffers[new_token as usize] = Some(rx_buf);
- if self.recv_queue.should_notify() {
- self.transport.notify(QUEUE_RECEIVE);
- }
- Ok(())
- }
-
- /// Allocate a new buffer for transmitting.
- pub fn new_tx_buffer(&self, buf_len: usize) -> TxBuffer {
- TxBuffer(vec![0; buf_len])
- }
-
- /// Sends a [`TxBuffer`] to the network, and blocks until the request
- /// completed.
- pub fn send(&mut self, tx_buf: TxBuffer) -> Result {
- let header = VirtioNetHdr::default();
- if tx_buf.packet_len() == 0 {
- // Special case sending an empty packet, to avoid adding an empty buffer to the
- // virtqueue.
- self.send_queue.add_notify_wait_pop(
- &[header.as_bytes()],
- &mut [],
- &mut self.transport,
- )?;
- } else {
- self.send_queue.add_notify_wait_pop(
- &[header.as_bytes(), tx_buf.packet()],
- &mut [],
- &mut self.transport,
- )?;
- }
- Ok(())
- }
-}
-
-impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONet<H, T, QUEUE_SIZE> {
- fn drop(&mut self) {
- // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
- // after they have been freed.
- self.transport.queue_unset(QUEUE_RECEIVE);
- self.transport.queue_unset(QUEUE_TRANSMIT);
- }
-}
-
-bitflags! {
- #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
- struct Features: u64 {
- /// Device handles packets with partial checksum.
- /// This "checksum offload" is a common feature on modern network cards.
- const CSUM = 1 << 0;
- /// Driver handles packets with partial checksum.
- const GUEST_CSUM = 1 << 1;
- /// Control channel offloads reconfiguration support.
- const CTRL_GUEST_OFFLOADS = 1 << 2;
- /// Device maximum MTU reporting is supported.
- ///
- /// If offered by the device, device advises driver about the value of
- /// its maximum MTU. If negotiated, the driver uses mtu as the maximum
- /// MTU value.
- const MTU = 1 << 3;
- /// Device has given MAC address.
- const MAC = 1 << 5;
- /// Device handles packets with any GSO type. (legacy)
- const GSO = 1 << 6;
- /// Driver can receive TSOv4.
- const GUEST_TSO4 = 1 << 7;
- /// Driver can receive TSOv6.
- const GUEST_TSO6 = 1 << 8;
- /// Driver can receive TSO with ECN.
- const GUEST_ECN = 1 << 9;
- /// Driver can receive UFO.
- const GUEST_UFO = 1 << 10;
- /// Device can receive TSOv4.
- const HOST_TSO4 = 1 << 11;
- /// Device can receive TSOv6.
- const HOST_TSO6 = 1 << 12;
- /// Device can receive TSO with ECN.
- const HOST_ECN = 1 << 13;
- /// Device can receive UFO.
- const HOST_UFO = 1 << 14;
- /// Driver can merge receive buffers.
- const MRG_RXBUF = 1 << 15;
- /// Configuration status field is available.
- const STATUS = 1 << 16;
- /// Control channel is available.
- const CTRL_VQ = 1 << 17;
- /// Control channel RX mode support.
- const CTRL_RX = 1 << 18;
- /// Control channel VLAN filtering.
- const CTRL_VLAN = 1 << 19;
- ///
- const CTRL_RX_EXTRA = 1 << 20;
- /// Driver can send gratuitous packets.
- const GUEST_ANNOUNCE = 1 << 21;
- /// Device supports multiqueue with automatic receive steering.
- const MQ = 1 << 22;
- /// Set MAC address through control channel.
- const CTL_MAC_ADDR = 1 << 23;
-
- // device independent
- const RING_INDIRECT_DESC = 1 << 28;
- const RING_EVENT_IDX = 1 << 29;
- const VERSION_1 = 1 << 32; // legacy
- }
-}
-
-bitflags! {
- #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
- pub(crate) struct Status: u16 {
- const LINK_UP = 1;
- const ANNOUNCE = 2;
- }
-}
-
-bitflags! {
- #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
- struct InterruptStatus : u32 {
- const USED_RING_UPDATE = 1 << 0;
- const CONFIGURATION_CHANGE = 1 << 1;
- }
-}
-
-#[repr(C)]
-struct Config {
- mac: ReadOnly<EthernetAddress>,
- status: ReadOnly<Status>,
- max_virtqueue_pairs: ReadOnly<u16>,
- mtu: ReadOnly<u16>,
-}
-
-type EthernetAddress = [u8; 6];
-
-/// VirtIO 5.1.6 Device Operation:
-///
-/// Packets are transmitted by placing them in the transmitq1. . .transmitqN,
-/// and buffers for incoming packets are placed in the receiveq1. . .receiveqN.
-/// In each case, the packet itself is preceded by a header.
-#[repr(C)]
-#[derive(AsBytes, Debug, Default, FromBytes, FromZeroes)]
-pub struct VirtioNetHdr {
- flags: Flags,
- gso_type: GsoType,
- hdr_len: u16, // cannot rely on this
- gso_size: u16,
- csum_start: u16,
- csum_offset: u16,
- // num_buffers: u16, // only available when the feature MRG_RXBUF is negotiated.
- // payload starts from here
-}
-
-#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, FromZeroes, PartialEq)]
-#[repr(transparent)]
-struct Flags(u8);
-
-bitflags! {
- impl Flags: u8 {
- const NEEDS_CSUM = 1;
- const DATA_VALID = 2;
- const RSC_INFO = 4;
- }
-}
-
-#[repr(transparent)]
-#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, FromZeroes, PartialEq)]
-struct GsoType(u8);
-
-impl GsoType {
- const NONE: GsoType = GsoType(0);
- const TCPV4: GsoType = GsoType(1);
- const UDP: GsoType = GsoType(3);
- const TCPV6: GsoType = GsoType(4);
- const ECN: GsoType = GsoType(0x80);
-}
-
-const QUEUE_RECEIVE: u16 = 0;
-const QUEUE_TRANSMIT: u16 = 1;
-const SUPPORTED_FEATURES: Features = Features::MAC
- .union(Features::STATUS)
- .union(Features::RING_EVENT_IDX);
diff --git a/src/device/net/dev.rs b/src/device/net/dev.rs
new file mode 100644
index 0000000..6bab13c
--- /dev/null
+++ b/src/device/net/dev.rs
@@ -0,0 +1,125 @@
+use alloc::vec;
+
+use super::net_buf::{RxBuffer, TxBuffer};
+use super::{EthernetAddress, VirtIONetRaw};
+use crate::{hal::Hal, transport::Transport, Error, Result};
+
+/// Driver for a VirtIO network device.
+///
+/// Unlike [`VirtIONetRaw`], it uses [`RxBuffer`]s for transmission and
+/// reception rather than the raw slices. On initialization, it pre-allocates
+/// all receive buffers and puts them all in the receive queue.
+///
+/// The virtio network device is a virtual ethernet card.
+///
+/// It has enhanced rapidly and demonstrates clearly how support for new
+/// features are added to an existing device.
+/// Empty buffers are placed in one virtqueue for receiving packets, and
+/// outgoing packets are enqueued into another for transmission in that order.
+/// A third command queue is used to control advanced filtering features.
+pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
+ inner: VirtIONetRaw<H, T, QUEUE_SIZE>,
+ rx_buffers: [Option<RxBuffer>; QUEUE_SIZE],
+}
+
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
+ /// Create a new VirtIO-Net driver.
+ pub fn new(transport: T, buf_len: usize) -> Result<Self> {
+ let mut inner = VirtIONetRaw::new(transport)?;
+
+ const NONE_BUF: Option<RxBuffer> = None;
+ let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
+ for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() {
+ let mut rx_buf = RxBuffer::new(i, buf_len);
+ // Safe because the buffer lives as long as the queue.
+ let token = unsafe { inner.receive_begin(rx_buf.as_bytes_mut())? };
+ assert_eq!(token, i as u16);
+ *rx_buf_place = Some(rx_buf);
+ }
+
+ Ok(VirtIONet { inner, rx_buffers })
+ }
+
+ /// Acknowledge interrupt.
+ pub fn ack_interrupt(&mut self) -> bool {
+ self.inner.ack_interrupt()
+ }
+
+ /// Disable interrupts.
+ pub fn disable_interrupts(&mut self) {
+ self.inner.disable_interrupts()
+ }
+
+ /// Enable interrupts.
+ pub fn enable_interrupts(&mut self) {
+ self.inner.enable_interrupts()
+ }
+
+ /// Get MAC address.
+ pub fn mac_address(&self) -> EthernetAddress {
+ self.inner.mac_address()
+ }
+
+ /// Whether can send packet.
+ pub fn can_send(&self) -> bool {
+ self.inner.can_send()
+ }
+
+ /// Whether can receive packet.
+ pub fn can_recv(&self) -> bool {
+ self.inner.poll_receive().is_some()
+ }
+
+ /// Receives a [`RxBuffer`] from network. If currently no data, returns an
+ /// error with type [`Error::NotReady`].
+ ///
+ /// It will try to pop a buffer that completed data reception in the
+ /// NIC queue.
+ pub fn receive(&mut self) -> Result<RxBuffer> {
+ if let Some(token) = self.inner.poll_receive() {
+ let mut rx_buf = self.rx_buffers[token as usize]
+ .take()
+ .ok_or(Error::WrongToken)?;
+ if token != rx_buf.idx {
+ return Err(Error::WrongToken);
+ }
+
+ // Safe because `token` == `rx_buf.idx`, we are passing the same
+ // buffer as we passed to `VirtQueue::add` and it is still valid.
+ let (_hdr_len, pkt_len) =
+ unsafe { self.inner.receive_complete(token, rx_buf.as_bytes_mut())? };
+ rx_buf.set_packet_len(pkt_len);
+ Ok(rx_buf)
+ } else {
+ Err(Error::NotReady)
+ }
+ }
+
+ /// Gives back the ownership of `rx_buf`, and recycles it for next use.
+ ///
+ /// It will add the buffer back to the NIC queue.
+ pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result {
+ // Safe because we take the ownership of `rx_buf` back to `rx_buffers`,
+ // it lives as long as the queue.
+ let new_token = unsafe { self.inner.receive_begin(rx_buf.as_bytes_mut()) }?;
+ // `rx_buffers[new_token]` is expected to be `None` since it was taken
+ // away at `Self::receive()` and has not been added back.
+ if self.rx_buffers[new_token as usize].is_some() {
+ return Err(Error::WrongToken);
+ }
+ rx_buf.idx = new_token;
+ self.rx_buffers[new_token as usize] = Some(rx_buf);
+ Ok(())
+ }
+
+ /// Allocate a new buffer for transmitting.
+ pub fn new_tx_buffer(&self, buf_len: usize) -> TxBuffer {
+ TxBuffer(vec![0; buf_len])
+ }
+
+ /// Sends a [`TxBuffer`] to the network, and blocks until the request
+ /// completed.
+ pub fn send(&mut self, tx_buf: TxBuffer) -> Result {
+ self.inner.send(tx_buf.packet())
+ }
+}
diff --git a/src/device/net/dev_raw.rs b/src/device/net/dev_raw.rs
new file mode 100644
index 0000000..e9834e0
--- /dev/null
+++ b/src/device/net/dev_raw.rs
@@ -0,0 +1,281 @@
+use super::{Config, EthernetAddress, Features, VirtioNetHdr};
+use super::{MIN_BUFFER_LEN, NET_HDR_SIZE, QUEUE_RECEIVE, QUEUE_TRANSMIT, SUPPORTED_FEATURES};
+use crate::hal::Hal;
+use crate::queue::VirtQueue;
+use crate::transport::Transport;
+use crate::volatile::volread;
+use crate::{Error, Result};
+use log::{debug, info, warn};
+use zerocopy::AsBytes;
+
+/// Raw driver for a VirtIO network device.
+///
+/// This is a raw version of the VirtIONet driver. It provides non-blocking
+/// methods for transmitting and receiving raw slices, without the buffer
+/// management. For higher-level functions such as receive buffer backing,
+/// see [`VirtIONet`].
+///
+/// [`VirtIONet`]: super::VirtIONet
+pub struct VirtIONetRaw<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
+ transport: T,
+ mac: EthernetAddress,
+ recv_queue: VirtQueue<H, QUEUE_SIZE>,
+ send_queue: VirtQueue<H, QUEUE_SIZE>,
+}
+
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONetRaw<H, T, QUEUE_SIZE> {
+ /// Create a new VirtIO-Net driver.
+ pub fn new(mut transport: T) -> Result<Self> {
+ let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
+ info!("negotiated_features {:?}", negotiated_features);
+ // read configuration space
+ let config = transport.config_space::<Config>()?;
+ let mac;
+ // Safe because config points to a valid MMIO region for the config space.
+ unsafe {
+ mac = volread!(config, mac);
+ debug!(
+ "Got MAC={:02x?}, status={:?}",
+ mac,
+ volread!(config, status)
+ );
+ }
+ let send_queue = VirtQueue::new(
+ &mut transport,
+ QUEUE_TRANSMIT,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
+ negotiated_features.contains(Features::RING_EVENT_IDX),
+ )?;
+ let recv_queue = VirtQueue::new(
+ &mut transport,
+ QUEUE_RECEIVE,
+ negotiated_features.contains(Features::RING_INDIRECT_DESC),
+ negotiated_features.contains(Features::RING_EVENT_IDX),
+ )?;
+
+ transport.finish_init();
+
+ Ok(VirtIONetRaw {
+ transport,
+ mac,
+ recv_queue,
+ send_queue,
+ })
+ }
+
+ /// Acknowledge interrupt.
+ pub fn ack_interrupt(&mut self) -> bool {
+ self.transport.ack_interrupt()
+ }
+
+ /// Disable interrupts.
+ pub fn disable_interrupts(&mut self) {
+ self.send_queue.set_dev_notify(false);
+ self.recv_queue.set_dev_notify(false);
+ }
+
+ /// Enable interrupts.
+ pub fn enable_interrupts(&mut self) {
+ self.send_queue.set_dev_notify(true);
+ self.recv_queue.set_dev_notify(true);
+ }
+
+ /// Get MAC address.
+ pub fn mac_address(&self) -> EthernetAddress {
+ self.mac
+ }
+
+ /// Whether can send packet.
+ pub fn can_send(&self) -> bool {
+ self.send_queue.available_desc() >= 2
+ }
+
+ /// Whether the length of the receive buffer is valid.
+ fn check_rx_buf_len(rx_buf: &[u8]) -> Result<()> {
+ if rx_buf.len() < MIN_BUFFER_LEN {
+ warn!("Receive buffer len {} is too small", rx_buf.len());
+ Err(Error::InvalidParam)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Whether the length of the transmit buffer is valid.
+ fn check_tx_buf_len(tx_buf: &[u8]) -> Result<()> {
+ if tx_buf.len() < NET_HDR_SIZE {
+ warn!("Transmit buffer len {} is too small", tx_buf.len());
+ Err(Error::InvalidParam)
+ } else {
+ Ok(())
+ }
+ }
+
+ /// Fill the header of the `buffer` with [`VirtioNetHdr`].
+ ///
+ /// If the `buffer` is not large enough, it returns [`Error::InvalidParam`].
+ pub fn fill_buffer_header(&self, buffer: &mut [u8]) -> Result<usize> {
+ if buffer.len() < NET_HDR_SIZE {
+ return Err(Error::InvalidParam);
+ }
+ let header = VirtioNetHdr::default();
+ buffer[..NET_HDR_SIZE].copy_from_slice(header.as_bytes());
+ Ok(NET_HDR_SIZE)
+ }
+
+ /// Submits a request to transmit a buffer immediately without waiting for
+ /// the transmission to complete.
+ ///
+ /// It will submit request to the VirtIO net device and return a token
+ /// identifying the position of the first descriptor in the chain. If there
+ /// are not enough descriptors to allocate, then it returns
+ /// [`Error::QueueFull`].
+ ///
+ /// The caller needs to fill the `tx_buf` with a header by calling
+ /// [`fill_buffer_header`] before transmission. Then it calls [`poll_transmit`]
+ /// with the returned token to check whether the device has finished handling
+ /// the request. Once it has, the caller must call [`transmit_complete`] with
+ /// the same buffer before reading the result (transmitted length).
+ ///
+ /// # Safety
+ ///
+ /// `tx_buf` is still borrowed by the underlying VirtIO net device even after
+ /// this method returns. Thus, it is the caller's responsibility to guarantee
+ /// that they are not accessed before the request is completed in order to
+ /// avoid data races.
+ ///
+ /// [`fill_buffer_header`]: Self::fill_buffer_header
+ /// [`poll_transmit`]: Self::poll_transmit
+ /// [`transmit_complete`]: Self::transmit_complete
+ pub unsafe fn transmit_begin(&mut self, tx_buf: &[u8]) -> Result<u16> {
+ Self::check_tx_buf_len(tx_buf)?;
+ let token = self.send_queue.add(&[tx_buf], &mut [])?;
+ if self.send_queue.should_notify() {
+ self.transport.notify(QUEUE_TRANSMIT);
+ }
+ Ok(token)
+ }
+
+ /// Fetches the token of the next completed transmission request from the
+ /// used ring and returns it, without removing it from the used ring. If
+ /// there are no pending completed requests it returns [`None`].
+ pub fn poll_transmit(&mut self) -> Option<u16> {
+ self.send_queue.peek_used()
+ }
+
+ /// Completes a transmission operation which was started by [`transmit_begin`].
+ /// Returns number of bytes transmitted.
+ ///
+ /// # Safety
+ ///
+ /// The same buffer must be passed in again as was passed to
+ /// [`transmit_begin`] when it returned the token.
+ ///
+ /// [`transmit_begin`]: Self::transmit_begin
+ pub unsafe fn transmit_complete(&mut self, token: u16, tx_buf: &[u8]) -> Result<usize> {
+ let len = self.send_queue.pop_used(token, &[tx_buf], &mut [])?;
+ Ok(len as usize)
+ }
+
+ /// Submits a request to receive a buffer immediately without waiting for
+ /// the reception to complete.
+ ///
+ /// It will submit request to the VirtIO net device and return a token
+ /// identifying the position of the first descriptor in the chain. If there
+ /// are not enough descriptors to allocate, then it returns
+ /// [`Error::QueueFull`].
+ ///
+ /// The caller can then call [`poll_receive`] with the returned token to
+ /// check whether the device has finished handling the request. Once it has,
+ /// the caller must call [`receive_complete`] with the same buffer before
+ /// reading the response.
+ ///
+ /// # Safety
+ ///
+ /// `rx_buf` is still borrowed by the underlying VirtIO net device even after
+ /// this method returns. Thus, it is the caller's responsibility to guarantee
+ /// that they are not accessed before the request is completed in order to
+ /// avoid data races.
+ ///
+ /// [`poll_receive`]: Self::poll_receive
+ /// [`receive_complete`]: Self::receive_complete
+ pub unsafe fn receive_begin(&mut self, rx_buf: &mut [u8]) -> Result<u16> {
+ Self::check_rx_buf_len(rx_buf)?;
+ let token = self.recv_queue.add(&[], &mut [rx_buf])?;
+ if self.recv_queue.should_notify() {
+ self.transport.notify(QUEUE_RECEIVE);
+ }
+ Ok(token)
+ }
+
+ /// Fetches the token of the next completed reception request from the
+ /// used ring and returns it, without removing it from the used ring. If
+ /// there are no pending completed requests it returns [`None`].
+ pub fn poll_receive(&self) -> Option<u16> {
+ self.recv_queue.peek_used()
+ }
+
+ /// Completes a reception operation which was started by [`receive_begin`].
+ ///
+ /// After completion, the `rx_buf` will contain a header followed by the
+ /// received packet. It returns the length of the header and the length of
+ /// the packet.
+ ///
+ /// # Safety
+ ///
+ /// The same buffer must be passed in again as was passed to
+ /// [`receive_begin`] when it returned the token.
+ ///
+ /// [`receive_begin`]: Self::receive_begin
+ pub unsafe fn receive_complete(
+ &mut self,
+ token: u16,
+ rx_buf: &mut [u8],
+ ) -> Result<(usize, usize)> {
+ let len = self.recv_queue.pop_used(token, &[], &mut [rx_buf])? as usize;
+ let packet_len = len.checked_sub(NET_HDR_SIZE).ok_or(Error::IoError)?;
+ Ok((NET_HDR_SIZE, packet_len))
+ }
+
+ /// Sends a packet to the network, and blocks until the request completed.
+ pub fn send(&mut self, tx_buf: &[u8]) -> Result {
+ let header = VirtioNetHdr::default();
+ if tx_buf.is_empty() {
+ // Special case sending an empty packet, to avoid adding an empty buffer to the
+ // virtqueue.
+ self.send_queue.add_notify_wait_pop(
+ &[header.as_bytes()],
+ &mut [],
+ &mut self.transport,
+ )?;
+ } else {
+ self.send_queue.add_notify_wait_pop(
+ &[header.as_bytes(), tx_buf],
+ &mut [],
+ &mut self.transport,
+ )?;
+ }
+ Ok(())
+ }
+
+ /// Blocks and waits for a packet to be received.
+ ///
+ /// After completion, the `rx_buf` will contain a header followed by the
+ /// received packet. It returns the length of the header and the length of
+ /// the packet.
+ pub fn receive_wait(&mut self, rx_buf: &mut [u8]) -> Result<(usize, usize)> {
+ let token = unsafe { self.receive_begin(rx_buf)? };
+ while self.poll_receive().is_none() {
+ core::hint::spin_loop();
+ }
+ unsafe { self.receive_complete(token, rx_buf) }
+ }
+}
+
+impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> Drop for VirtIONetRaw<H, T, QUEUE_SIZE> {
+ fn drop(&mut self) {
+ // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+ // after they have been freed.
+ self.transport.queue_unset(QUEUE_RECEIVE);
+ self.transport.queue_unset(QUEUE_TRANSMIT);
+ }
+}
diff --git a/src/device/net/mod.rs b/src/device/net/mod.rs
new file mode 100644
index 0000000..8375946
--- /dev/null
+++ b/src/device/net/mod.rs
@@ -0,0 +1,156 @@
+//! Driver for VirtIO network devices.
+
+#[cfg(feature = "alloc")]
+mod dev;
+mod dev_raw;
+#[cfg(feature = "alloc")]
+mod net_buf;
+
+pub use self::dev_raw::VirtIONetRaw;
+#[cfg(feature = "alloc")]
+pub use self::{dev::VirtIONet, net_buf::RxBuffer, net_buf::TxBuffer};
+
+use crate::volatile::ReadOnly;
+use bitflags::bitflags;
+use zerocopy::{AsBytes, FromBytes, FromZeroes};
+
+const MAX_BUFFER_LEN: usize = 65535;
+const MIN_BUFFER_LEN: usize = 1526;
+const NET_HDR_SIZE: usize = core::mem::size_of::<VirtioNetHdr>();
+
+bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+ struct Features: u64 {
+ /// Device handles packets with partial checksum.
+ /// This "checksum offload" is a common feature on modern network cards.
+ const CSUM = 1 << 0;
+ /// Driver handles packets with partial checksum.
+ const GUEST_CSUM = 1 << 1;
+ /// Control channel offloads reconfiguration support.
+ const CTRL_GUEST_OFFLOADS = 1 << 2;
+ /// Device maximum MTU reporting is supported.
+ ///
+ /// If offered by the device, device advises driver about the value of
+ /// its maximum MTU. If negotiated, the driver uses mtu as the maximum
+ /// MTU value.
+ const MTU = 1 << 3;
+ /// Device has given MAC address.
+ const MAC = 1 << 5;
+ /// Device handles packets with any GSO type. (legacy)
+ const GSO = 1 << 6;
+ /// Driver can receive TSOv4.
+ const GUEST_TSO4 = 1 << 7;
+ /// Driver can receive TSOv6.
+ const GUEST_TSO6 = 1 << 8;
+ /// Driver can receive TSO with ECN.
+ const GUEST_ECN = 1 << 9;
+ /// Driver can receive UFO.
+ const GUEST_UFO = 1 << 10;
+ /// Device can receive TSOv4.
+ const HOST_TSO4 = 1 << 11;
+ /// Device can receive TSOv6.
+ const HOST_TSO6 = 1 << 12;
+ /// Device can receive TSO with ECN.
+ const HOST_ECN = 1 << 13;
+ /// Device can receive UFO.
+ const HOST_UFO = 1 << 14;
+ /// Driver can merge receive buffers.
+ const MRG_RXBUF = 1 << 15;
+ /// Configuration status field is available.
+ const STATUS = 1 << 16;
+ /// Control channel is available.
+ const CTRL_VQ = 1 << 17;
+ /// Control channel RX mode support.
+ const CTRL_RX = 1 << 18;
+ /// Control channel VLAN filtering.
+ const CTRL_VLAN = 1 << 19;
+ /// Control channel RX mode extra support (VIRTIO_NET_F_CTRL_RX_EXTRA).
+ const CTRL_RX_EXTRA = 1 << 20;
+ /// Driver can send gratuitous packets.
+ const GUEST_ANNOUNCE = 1 << 21;
+ /// Device supports multiqueue with automatic receive steering.
+ const MQ = 1 << 22;
+ /// Set MAC address through control channel.
+ const CTL_MAC_ADDR = 1 << 23;
+
+ // device independent
+ const RING_INDIRECT_DESC = 1 << 28;
+ const RING_EVENT_IDX = 1 << 29;
+ const VERSION_1 = 1 << 32; // set on VirtIO 1.0+ compliant devices; absent on legacy devices
+ }
+}
+
+bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+ pub(crate) struct Status: u16 {
+ const LINK_UP = 1;
+ const ANNOUNCE = 2;
+ }
+}
+
+bitflags! {
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+ struct InterruptStatus : u32 {
+ const USED_RING_UPDATE = 1 << 0;
+ const CONFIGURATION_CHANGE = 1 << 1;
+ }
+}
+
+#[repr(C)]
+struct Config {
+ mac: ReadOnly<EthernetAddress>,
+ status: ReadOnly<Status>,
+ max_virtqueue_pairs: ReadOnly<u16>,
+ mtu: ReadOnly<u16>,
+}
+
+type EthernetAddress = [u8; 6];
+
+/// VirtIO 5.1.6 Device Operation:
+///
+/// Packets are transmitted by placing them in the transmitq1…transmitqN,
+/// and buffers for incoming packets are placed in the receiveq1…receiveqN.
+/// In each case, the packet itself is preceded by a header.
+#[repr(C)]
+#[derive(AsBytes, Debug, Default, FromBytes, FromZeroes)]
+pub struct VirtioNetHdr {
+ flags: Flags,
+ gso_type: GsoType,
+ hdr_len: u16, // cannot rely on this
+ gso_size: u16,
+ csum_start: u16,
+ csum_offset: u16,
+ // num_buffers: u16, // only available when the feature MRG_RXBUF is negotiated.
+ // payload starts from here
+}
+
+#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, FromZeroes, PartialEq)]
+#[repr(transparent)]
+struct Flags(u8);
+
+bitflags! {
+ impl Flags: u8 {
+ const NEEDS_CSUM = 1;
+ const DATA_VALID = 2;
+ const RSC_INFO = 4;
+ }
+}
+
+#[repr(transparent)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, FromZeroes, PartialEq)]
+struct GsoType(u8);
+
+impl GsoType {
+ const NONE: GsoType = GsoType(0);
+ const TCPV4: GsoType = GsoType(1);
+ const UDP: GsoType = GsoType(3);
+ const TCPV6: GsoType = GsoType(4);
+ const ECN: GsoType = GsoType(0x80);
+}
+
+const QUEUE_RECEIVE: u16 = 0;
+const QUEUE_TRANSMIT: u16 = 1;
+const SUPPORTED_FEATURES: Features = Features::MAC
+ .union(Features::STATUS)
+ .union(Features::RING_EVENT_IDX)
+ .union(Features::RING_INDIRECT_DESC);
diff --git a/src/device/net/net_buf.rs b/src/device/net/net_buf.rs
new file mode 100644
index 0000000..8b4947b
--- /dev/null
+++ b/src/device/net/net_buf.rs
@@ -0,0 +1,83 @@
+use super::{VirtioNetHdr, NET_HDR_SIZE};
+use alloc::{vec, vec::Vec};
+use core::{convert::TryInto, mem::size_of};
+use zerocopy::AsBytes;
+
+/// A buffer used for transmitting.
+pub struct TxBuffer(pub(crate) Vec<u8>);
+
+/// A buffer used for receiving.
+pub struct RxBuffer {
+ pub(crate) buf: Vec<usize>, // for alignment
+ pub(crate) packet_len: usize,
+ pub(crate) idx: u16,
+}
+
+impl TxBuffer {
+ /// Constructs the buffer from the given slice.
+ pub fn from(buf: &[u8]) -> Self {
+ Self(Vec::from(buf))
+ }
+
+ /// Returns the network packet length.
+ pub fn packet_len(&self) -> usize {
+ self.0.len()
+ }
+
+ /// Returns the network packet as a slice.
+ pub fn packet(&self) -> &[u8] {
+ self.0.as_slice()
+ }
+
+ /// Returns the network packet as a mutable slice.
+ pub fn packet_mut(&mut self) -> &mut [u8] {
+ self.0.as_mut_slice()
+ }
+}
+
+impl RxBuffer {
+ /// Allocates a new buffer with length `buf_len`.
+ pub(crate) fn new(idx: usize, buf_len: usize) -> Self {
+ Self {
+ buf: vec![0; buf_len / size_of::<usize>()],
+ packet_len: 0,
+ idx: idx.try_into().unwrap(),
+ }
+ }
+
+ /// Set the network packet length.
+ pub(crate) fn set_packet_len(&mut self, packet_len: usize) {
+ self.packet_len = packet_len
+ }
+
+ /// Returns the network packet length (without header).
+ pub const fn packet_len(&self) -> usize {
+ self.packet_len
+ }
+
+ /// Returns all data in the buffer, including both the header and the packet.
+ pub fn as_bytes(&self) -> &[u8] {
+ self.buf.as_bytes()
+ }
+
+ /// Returns all data in the buffer with the mutable reference,
+ /// including both the header and the packet.
+ pub fn as_bytes_mut(&mut self) -> &mut [u8] {
+ self.buf.as_bytes_mut()
+ }
+
+ /// Returns the reference of the header.
+ pub fn header(&self) -> &VirtioNetHdr {
+ unsafe { &*(self.buf.as_ptr() as *const VirtioNetHdr) }
+ }
+
+ /// Returns the network packet as a slice.
+ pub fn packet(&self) -> &[u8] {
+ &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+ }
+
+ /// Returns the network packet as a mutable slice.
+ pub fn packet_mut(&mut self) -> &mut [u8] {
+ &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+ }
+}
diff --git a/src/device/socket/connectionmanager.rs b/src/device/socket/connectionmanager.rs
index 8690ca3..0c9e4e8 100644
--- a/src/device/socket/connectionmanager.rs
+++ b/src/device/socket/connectionmanager.rs
@@ -1,6 +1,6 @@
use super::{
protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
- VsockEvent, VsockEventType,
+ VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
};
use crate::{transport::Transport, Hal, Result};
use alloc::{boxed::Box, vec::Vec};
@@ -10,12 +10,15 @@ use core::hint::spin_loop;
use log::debug;
use zerocopy::FromZeroes;
-const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
+const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1024;
/// A higher level interface for VirtIO socket (vsock) devices.
///
/// This keeps track of multiple vsock connections.
///
+/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
+/// bigger than `size_of::<VirtioVsockHdr>()`.
+///
/// # Example
///
/// ```
@@ -40,8 +43,13 @@ const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// # Ok(())
/// # }
/// ```
-pub struct VsockConnectionManager<H: Hal, T: Transport> {
- driver: VirtIOSocket<H, T>,
+pub struct VsockConnectionManager<
+ H: Hal,
+ T: Transport,
+ const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
+> {
+ driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
+ per_connection_buffer_capacity: u32,
connections: Vec<Connection>,
listening_ports: Vec<u32>,
}
@@ -56,24 +64,36 @@ struct Connection {
}
impl Connection {
- fn new(peer: VsockAddr, local_port: u32) -> Self {
+ fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
let mut info = ConnectionInfo::new(peer, local_port);
- info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
+ info.buf_alloc = buffer_capacity;
Self {
info,
- buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
+ buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
peer_requested_shutdown: false,
}
}
}
-impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
+impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
+ VsockConnectionManager<H, T, RX_BUFFER_SIZE>
+{
/// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
- pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+ pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
+ Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
+ }
+
+ /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
+ /// the given per-connection buffer capacity.
+ pub fn new_with_capacity(
+ driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
+ per_connection_buffer_capacity: u32,
+ ) -> Self {
Self {
driver,
connections: Vec::new(),
listening_ports: Vec::new(),
+ per_connection_buffer_capacity,
}
}
@@ -106,7 +126,8 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
return Err(SocketError::ConnectionExists.into());
}
- let new_connection = Connection::new(destination, src_port);
+ let new_connection =
+ Connection::new(destination, src_port, self.per_connection_buffer_capacity);
self.driver.connect(&new_connection.info)?;
debug!("Connection requested: {:?}", new_connection.info);
@@ -125,6 +146,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
let guest_cid = self.driver.guest_cid();
let connections = &mut self.connections;
+ let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
let result = self.driver.poll(|event, body| {
let connection = get_connection_for_event(connections, &event, guest_cid);
@@ -140,7 +162,11 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
}
// Add the new connection to our list, at least for now. It will be removed again
// below if we weren't listening on the port.
- connections.push(Connection::new(event.source, event.destination.port));
+ connections.push(Connection::new(
+ event.source,
+ event.destination.port,
+ per_connection_buffer_capacity,
+ ));
connections.last_mut().unwrap()
} else {
return Ok(None);
@@ -252,7 +278,8 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
}
}
- /// Requests to shut down the connection cleanly.
+ /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
+ /// any more data.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
@@ -389,7 +416,9 @@ mod tests {
use super::*;
use crate::{
device::socket::{
- protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+ protocol::{
+ SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
+ },
vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
},
hal::fake::FakeHal,
@@ -557,7 +586,7 @@ mod tests {
dst_port: host_port.into(),
len: 0.into(),
socket_type: SocketType::Stream.into(),
- flags: 0.into(),
+ flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
buf_alloc: 1024.into(),
fwd_cnt: (hello_from_host.len() as u32).into(),
}
diff --git a/src/device/socket/mod.rs b/src/device/socket/mod.rs
index 8d2de2b..3b59d65 100644
--- a/src/device/socket/mod.rs
+++ b/src/device/socket/mod.rs
@@ -20,3 +20,7 @@ pub use error::SocketError;
pub use protocol::{VsockAddr, VMADDR_CID_HOST};
#[cfg(feature = "alloc")]
pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};
+
+/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than
+/// `size_of::<VirtioVsockHdr>()`.
+const DEFAULT_RX_BUFFER_SIZE: usize = 512;
diff --git a/src/device/socket/protocol.rs b/src/device/socket/protocol.rs
index 00ca7c0..4eac324 100644
--- a/src/device/socket/protocol.rs
+++ b/src/device/socket/protocol.rs
@@ -45,7 +45,7 @@ pub struct VirtioVsockConfig {
}
/// The message header for data packets sent on the tx/rx queues
-#[repr(packed)]
+#[repr(C, packed)]
#[derive(AsBytes, Clone, Copy, Debug, Eq, FromBytes, FromZeroes, PartialEq)]
pub struct VirtioVsockHdr {
pub src_cid: U64<LittleEndian>,
@@ -214,3 +214,20 @@ bitflags! {
const NOTIFICATION_DATA = 1 << 38;
}
}
+
+bitflags! {
+ /// Flags sent with a shutdown request to hint that the peer won't send or receive more data.
+ #[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
+ pub struct StreamShutdown: u32 {
+ /// The peer will not receive any more data.
+ const RECEIVE = 1 << 0;
+ /// The peer will not send any more data.
+ const SEND = 1 << 1;
+ }
+}
+
+impl From<StreamShutdown> for U32<LittleEndian> {
+ fn from(flags: StreamShutdown) -> Self {
+ flags.bits().into()
+ }
+}
diff --git a/src/device/socket/vsock.rs b/src/device/socket/vsock.rs
index 4578056..2103753 100644
--- a/src/device/socket/vsock.rs
+++ b/src/device/socket/vsock.rs
@@ -2,7 +2,10 @@
#![deny(unsafe_op_in_unsafe_fn)]
use super::error::SocketError;
-use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
+use super::protocol::{
+ Feature, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr,
+};
+use super::DEFAULT_RX_BUFFER_SIZE;
use crate::hal::Hal;
use crate::queue::VirtQueue;
use crate::transport::Transport;
@@ -19,10 +22,7 @@ pub(crate) const TX_QUEUE_IDX: u16 = 1;
const EVENT_QUEUE_IDX: u16 = 2;
pub(crate) const QUEUE_SIZE: usize = 8;
-const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX;
-
-/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
-const RX_BUFFER_SIZE: usize = 512;
+const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX.union(Feature::RING_INDIRECT_DESC);
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
@@ -210,7 +210,11 @@ pub enum VsockEventType {
///
/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
/// using this directly.
-pub struct VirtIOSocket<H: Hal, T: Transport> {
+///
+/// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
+/// bigger than `size_of::<VirtioVsockHdr>()`.
+pub struct VirtIOSocket<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE>
+{
transport: T,
/// Virtqueue to receive packets.
rx: VirtQueue<H, { QUEUE_SIZE }>,
@@ -223,7 +227,21 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
}
-impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
+// SAFETY: The `rx_queue_buffers` can be accessed from any thread.
+unsafe impl<H: Hal, T: Transport + Send, const RX_BUFFER_SIZE: usize> Send
+ for VirtIOSocket<H, T, RX_BUFFER_SIZE> where VirtQueue<H, QUEUE_SIZE>: Send
+{
+}
+
+// SAFETY: A `&VirtIOSocket` only allows reading the guest CID from a field.
+unsafe impl<H: Hal, T: Transport + Sync, const RX_BUFFER_SIZE: usize> Sync
+ for VirtIOSocket<H, T, RX_BUFFER_SIZE> where VirtQueue<H, QUEUE_SIZE>: Sync
+{
+}
+
+impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> Drop
+ for VirtIOSocket<H, T, RX_BUFFER_SIZE>
+{
fn drop(&mut self) {
// Clear any pointers pointing to DMA regions, so the device doesn't try to access them
// after they have been freed.
@@ -239,9 +257,11 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
}
}
-impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
+impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize> VirtIOSocket<H, T, RX_BUFFER_SIZE> {
/// Create a new VirtIO Vsock driver.
pub fn new(mut transport: T) -> Result<Self> {
+ assert!(RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>());
+
let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
let config = transport.config_space::<VirtioVsockConfig>()?;
@@ -255,19 +275,19 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
let mut rx = VirtQueue::new(
&mut transport,
RX_QUEUE_IDX,
- false,
+ negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let tx = VirtQueue::new(
&mut transport,
TX_QUEUE_IDX,
- false,
+ negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
let event = VirtQueue::new(
&mut transport,
EVENT_QUEUE_IDX,
- false,
+ negotiated_features.contains(Feature::RING_INDIRECT_DESC),
negotiated_features.contains(Feature::RING_EVENT_IDX),
)?;
@@ -397,19 +417,38 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
result
}
- /// Requests to shut down the connection cleanly.
+ /// Requests to shut down the connection cleanly, sending hints about whether we will send or
+ /// receive more data.
///
/// This returns as soon as the request is sent; you should wait until `poll` returns a
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
/// shutdown.
- pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
+ pub fn shutdown_with_hints(
+ &mut self,
+ connection_info: &ConnectionInfo,
+ hints: StreamShutdown,
+ ) -> Result {
let header = VirtioVsockHdr {
op: VirtioVsockOp::Shutdown.into(),
+ flags: hints.into(),
..connection_info.new_header(self.guest_cid)
};
self.send_packet_to_tx_queue(&header, &[])
}
+ /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
+ /// any more data.
+ ///
+ /// This returns as soon as the request is sent; you should wait until `poll` returns a
+ /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+ /// shutdown.
+ pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
+ self.shutdown_with_hints(
+ connection_info,
+ StreamShutdown::SEND | StreamShutdown::RECEIVE,
+ )
+ }
+
/// Forcibly closes the connection without waiting for the peer.
pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
let header = VirtioVsockHdr {
diff --git a/src/hal.rs b/src/hal.rs
index 04b87fc..64e12e7 100644
--- a/src/hal.rs
+++ b/src/hal.rs
@@ -16,6 +16,13 @@ pub struct Dma<H: Hal> {
_hal: PhantomData<H>,
}
+// SAFETY: DMA memory can be accessed from any thread.
+unsafe impl<H: Hal> Send for Dma<H> {}
+
+// SAFETY: `&Dma` only allows pointers and physical addresses to be returned. Any actual access to
+// the memory requires unsafe code, which is responsible for avoiding data races.
+unsafe impl<H: Hal> Sync for Dma<H> {}
+
impl<H: Hal> Dma<H> {
/// Allocates the given number of pages of physically contiguous memory to be used for DMA in
/// the given direction.
diff --git a/src/queue.rs b/src/queue.rs
index cc10325..3573a39 100644
--- a/src/queue.rs
+++ b/src/queue.rs
@@ -8,12 +8,13 @@ use alloc::boxed::Box;
use bitflags::bitflags;
#[cfg(test)]
use core::cmp::min;
+use core::convert::TryInto;
use core::hint::spin_loop;
use core::mem::{size_of, take};
#[cfg(test)]
use core::ptr;
use core::ptr::NonNull;
-use core::sync::atomic::{fence, Ordering};
+use core::sync::atomic::{fence, AtomicU16, Ordering};
use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// The mechanism for bulk data transport on virtio devices.
@@ -21,7 +22,7 @@ use zerocopy::{AsBytes, FromBytes, FromZeroes};
/// Each device can have zero or more virtqueues.
///
/// * `SIZE`: The size of the queue. This is both the number of descriptors, and the number of slots
-/// in the available and used rings.
+/// in the available and used rings. It must be a power of 2 and fit in a [`u16`].
#[derive(Debug)]
pub struct VirtQueue<H: Hal, const SIZE: usize> {
/// DMA guard
@@ -61,6 +62,8 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
}
impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
+ const SIZE_OK: () = assert!(SIZE.is_power_of_two() && SIZE <= u16::MAX as usize);
+
/// Creates a new VirtQueue.
///
/// * `indirect`: Whether to use indirect descriptors. This should be set if the
@@ -74,13 +77,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
indirect: bool,
event_idx: bool,
) -> Result<Self> {
+ #[allow(clippy::let_unit_value)]
+ let _ = Self::SIZE_OK;
+
if transport.queue_used(idx) {
return Err(Error::AlreadyUsed);
}
- if !SIZE.is_power_of_two()
- || SIZE > u16::MAX.into()
- || transport.max_queue_size(idx) < SIZE as u32
- {
+ if transport.max_queue_size(idx) < SIZE as u32 {
return Err(Error::InvalidParam);
}
let size = SIZE as u16;
@@ -192,12 +195,11 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
self.avail_idx = self.avail_idx.wrapping_add(1);
// Safe because self.avail is properly aligned, dereferenceable and initialised.
unsafe {
- (*self.avail.as_ptr()).idx = self.avail_idx;
+ (*self.avail.as_ptr())
+ .idx
+ .store(self.avail_idx, Ordering::Release);
}
- // Write barrier so that device can see change to available index after this method returns.
- fence(Ordering::SeqCst);
-
Ok(head)
}
@@ -316,23 +318,36 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
unsafe { self.pop_used(token, inputs, outputs) }
}
+ /// Advise the device whether used buffer notifications are needed.
+ ///
+ /// See Virtio v1.1 2.6.7 Used Buffer Notification Suppression
+ pub fn set_dev_notify(&mut self, enable: bool) {
+ let avail_ring_flags = if enable { 0x0000 } else { 0x0001 };
+ if !self.event_idx {
+ // Safe because self.avail points to a valid, aligned, initialised, dereferenceable, readable
+ // instance of AvailRing.
+ unsafe {
+ (*self.avail.as_ptr())
+ .flags
+ .store(avail_ring_flags, Ordering::Release)
+ }
+ }
+ }
+
/// Returns whether the driver should notify the device after adding a new buffer to the
/// virtqueue.
///
/// This will be false if the device has supressed notifications.
pub fn should_notify(&self) -> bool {
- // Read barrier, so we read a fresh value from the device.
- fence(Ordering::SeqCst);
-
if self.event_idx {
// Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
// instance of UsedRing.
- let avail_event = unsafe { (*self.used.as_ptr()).avail_event };
+ let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
self.avail_idx >= avail_event.wrapping_add(1)
} else {
// Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
// instance of UsedRing.
- unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 }
+ unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
}
}
@@ -349,12 +364,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
/// Returns whether there is a used element that can be popped.
pub fn can_pop(&self) -> bool {
- // Read barrier, so we read a fresh value from the device.
- fence(Ordering::SeqCst);
-
// Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
// instance of UsedRing.
- self.last_used_idx != unsafe { (*self.used.as_ptr()).idx }
+ self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
}
/// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
@@ -492,7 +504,6 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
if !self.can_pop() {
return Err(Error::NotReady);
}
- // Read barrier not necessary, as can_pop already has one.
// Get the index of the start of the descriptor chain for the next element in the used ring.
let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
@@ -516,10 +527,25 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
}
self.last_used_idx = self.last_used_idx.wrapping_add(1);
+ if self.event_idx {
+ unsafe {
+ (*self.avail.as_ptr())
+ .used_event
+ .store(self.last_used_idx, Ordering::Release);
+ }
+ }
+
Ok(len)
}
}
+// SAFETY: None of the virt queue resources are tied to a particular thread.
+unsafe impl<H: Hal, const SIZE: usize> Send for VirtQueue<H, SIZE> {}
+
+// SAFETY: A `&VirtQueue` only allows reading from the various pointers it contains, so there is no
+// data race.
+unsafe impl<H: Hal, const SIZE: usize> Sync for VirtQueue<H, SIZE> {}
+
/// The inner layout of a VirtQueue.
///
/// Ref: 2.6 Split Virtqueues
@@ -692,7 +718,7 @@ impl Descriptor {
unsafe {
self.addr = H::share(buf, direction) as u64;
}
- self.len = buf.len() as u32;
+ self.len = buf.len().try_into().unwrap();
self.flags = extra_flags
| match direction {
BufferDirection::DeviceToDriver => DescFlags::WRITE,
@@ -741,11 +767,12 @@ bitflags! {
#[repr(C)]
#[derive(Debug)]
struct AvailRing<const SIZE: usize> {
- flags: u16,
+ flags: AtomicU16,
/// A driver MUST NOT decrement the idx.
- idx: u16,
+ idx: AtomicU16,
ring: [u16; SIZE],
- used_event: u16, // unused
+ /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
+ used_event: AtomicU16,
}
/// The used ring is where the device returns buffers once it is done with them:
@@ -753,11 +780,11 @@ struct AvailRing<const SIZE: usize> {
#[repr(C)]
#[derive(Debug)]
struct UsedRing<const SIZE: usize> {
- flags: u16,
- idx: u16,
+ flags: AtomicU16,
+ idx: AtomicU16,
ring: [UsedElem; SIZE],
/// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
- avail_event: u16,
+ avail_event: AtomicU16,
}
#[repr(C)]
@@ -826,10 +853,13 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
// nothing else accesses them during this block.
unsafe {
// Make sure there is actually at least one descriptor available to read from.
- assert_ne!((*available_ring).idx, (*used_ring).idx);
+ assert_ne!(
+ (*available_ring).idx.load(Ordering::Acquire),
+ (*used_ring).idx.load(Ordering::Acquire)
+ );
// The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
// `used_ring.idx` marks the next descriptor we should take from the available ring.
- let next_slot = (*used_ring).idx & (QUEUE_SIZE as u16 - 1);
+ let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
let head_descriptor_index = (*available_ring).ring[next_slot as usize];
let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
@@ -928,9 +958,9 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
}
// Mark the buffer as used.
- (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32;
+ (*used_ring).ring[next_slot as usize].id = head_descriptor_index.into();
(*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
- (*used_ring).idx += 1;
+ (*used_ring).idx.fetch_add(1, Ordering::AcqRel);
}
}
@@ -950,17 +980,6 @@ mod tests {
use std::sync::{Arc, Mutex};
#[test]
- fn invalid_queue_size() {
- let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
- let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
- // Size not a power of 2.
- assert_eq!(
- VirtQueue::<FakeHal, 3>::new(&mut transport, 0, false, false).unwrap_err(),
- Error::InvalidParam
- );
- }
-
- #[test]
fn queue_too_big() {
let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
@@ -1117,6 +1136,46 @@ mod tests {
}
}
+ /// Tests that the queue advises the device that notifications are needed.
+ #[test]
+ fn set_dev_notify() {
+ let mut config_space = ();
+ let state = Arc::new(Mutex::new(State {
+ queues: vec![QueueStatus::default()],
+ ..Default::default()
+ }));
+ let mut transport = FakeTransport {
+ device_type: DeviceType::Block,
+ max_queue_size: 4,
+ device_features: 0,
+ config_space: NonNull::from(&mut config_space),
+ state: state.clone(),
+ };
+ let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
+
+ // Check that the avail ring's flag is zero by default.
+ assert_eq!(
+ unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+ 0x0
+ );
+
+ queue.set_dev_notify(false);
+
+ // Check that the avail ring's flag is 1 after `disable_dev_notify`.
+ assert_eq!(
+ unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+ 0x1
+ );
+
+ queue.set_dev_notify(true);
+
+ // Check that the avail ring's flag is 0 after `enable_dev_notify`.
+ assert_eq!(
+ unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+ 0x0
+ );
+ }
+
/// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
/// notifications.
#[test]
@@ -1145,7 +1204,7 @@ mod tests {
// initialised, and nothing else is accessing them at the same time.
unsafe {
// Suppress notifications.
- (*queue.used.as_ptr()).flags = 0x01;
+ (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
}
// Check that the transport would not be notified.
@@ -1180,7 +1239,9 @@ mod tests {
// initialised, and nothing else is accessing them at the same time.
unsafe {
// Suppress notifications.
- (*queue.used.as_ptr()).avail_event = 1;
+ (*queue.used.as_ptr())
+ .avail_event
+ .store(1, Ordering::Release);
}
// Check that the transport would not be notified.
diff --git a/src/transport/mmio.rs b/src/transport/mmio.rs
index d938a97..9c5bb4d 100644
--- a/src/transport/mmio.rs
+++ b/src/transport/mmio.rs
@@ -310,6 +310,13 @@ impl MmioTransport {
}
}
+// SAFETY: `header` is only used for MMIO, which can happen from any thread or CPU core.
+unsafe impl Send for MmioTransport {}
+
+// SAFETY: `&MmioTransport` only allows MMIO reads or getting the config space, both of which are
+// fine to happen concurrently on different CPU cores.
+unsafe impl Sync for MmioTransport {}
+
impl Transport for MmioTransport {
fn device_type(&self) -> DeviceType {
// Safe because self.header points to a valid VirtIO MMIO region.
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
index 27401fe..a987a9c 100644
--- a/src/transport/pci.rs
+++ b/src/transport/pci.rs
@@ -341,6 +341,13 @@ impl Transport for PciTransport {
}
}
+// SAFETY: MMIO can be done from any thread or CPU core.
+unsafe impl Send for PciTransport {}
+
+// SAFETY: `&PciTransport` only allows MMIO reads or getting the config space, both of which are
+// fine to happen concurrently on different CPU cores.
+unsafe impl Sync for PciTransport {}
+
impl Drop for PciTransport {
fn drop(&mut self) {
// Reset the device when the transport is dropped.
@@ -499,6 +506,12 @@ impl From<PciError> for VirtioPciError {
}
}
+// SAFETY: The `vaddr` field of `VirtioPciError::Misaligned` is only used for debug output.
+unsafe impl Send for VirtioPciError {}
+
+// SAFETY: The `vaddr` field of `VirtioPciError::Misaligned` is only used for debug output.
+unsafe impl Sync for VirtioPciError {}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/src/transport/pci/bus.rs b/src/transport/pci/bus.rs
index 52f861e..146a7d6 100644
--- a/src/transport/pci/bus.rs
+++ b/src/transport/pci/bus.rs
@@ -328,6 +328,13 @@ impl PciRoot {
}
}
+// SAFETY: `mmio_base` is only used for MMIO, which can happen from any thread or CPU core.
+unsafe impl Send for PciRoot {}
+
+// SAFETY: `&PciRoot` only allows MMIO reads, which are fine to happen concurrently on different CPU
+// cores.
+unsafe impl Sync for PciRoot {}
+
/// Information about a PCI Base Address Register.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BarInfo {