summaryrefslogtreecommitdiff
path: root/src/vhost_user
diff options
context:
space:
mode:
Diffstat (limited to 'src/vhost_user')
-rw-r--r--src/vhost_user/connection.rs56
-rw-r--r--src/vhost_user/dummy_slave.rs58
-rw-r--r--src/vhost_user/master.rs352
-rw-r--r--src/vhost_user/master_req_handler.rs285
-rw-r--r--src/vhost_user/message.rs162
-rw-r--r--src/vhost_user/mod.rs200
-rw-r--r--src/vhost_user/slave.rs46
-rw-r--r--src/vhost_user/slave_fs_cache.rs210
-rw-r--r--src/vhost_user/slave_req_handler.rs303
-rw-r--r--src/vhost_user/sock_ctrl_msg.rs499
10 files changed, 1363 insertions, 808 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index deafdeb..01bf124 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,15 +5,16 @@
#![allow(dead_code)]
-use libc::{c_void, iovec};
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::{mem, slice};
+use libc::{c_void, iovec};
+use vmm_sys_util::sock_ctrl_msg::ScmSocket;
+
use super::message::*;
-use super::sock_ctrl_msg::ScmSocket;
use super::{Error, Result};
/// Unix domain socket listener for accepting incoming connections.
@@ -215,6 +216,9 @@ impl<R: Req> Endpoint<R> {
body: &T,
fds: Option<&[RawFd]>,
) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
// Safe because there can't be other mutable referance to hdr and body.
let iovs = unsafe {
[
@@ -243,14 +247,17 @@ impl<R: Req> Endpoint<R> {
/// * - OversizedMsg: message size is too big.
/// * - PartialMessage: received a partial message.
/// * - IncorrectFds: wrong number of attached fds.
- pub fn send_message_with_payload<T: Sized, P: Sized>(
+ pub fn send_message_with_payload<T: Sized>(
&mut self,
hdr: &VhostUserMsgHeader<R>,
body: &T,
- payload: &[P],
+ payload: &[u8],
fds: Option<&[RawFd]>,
) -> Result<()> {
- let len = payload.len() * mem::size_of::<P>();
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
if len > MAX_MSG_SIZE - mem::size_of::<T>() {
return Err(Error::OversizedMsg);
}
@@ -599,27 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
-
use super::*;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::io::FromRawFd;
+ use vmm_sys_util::rand::rand_alphanumerics;
use vmm_sys_util::tempfile::TempFile;
- const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener";
- const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection";
- const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data";
- const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd";
- const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send";
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
#[test]
fn create_listener() {
- let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
}
#[test]
fn accept_connection() {
- let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
// accept on a fd without incoming connection
@@ -628,11 +640,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_data() {
- let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -654,11 +666,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_fd() {
- let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -808,11 +820,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_recv() {
- let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
index 53887e2..9eedcbb 100644
--- a/src/vhost_user/dummy_slave.rs
+++ b/src/vhost_user/dummy_slave.rs
@@ -1,9 +1,10 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
+use std::os::unix::io::RawFd;
+
use super::message::*;
use super::*;
-use std::os::unix::io::RawFd;
pub const MAX_QUEUE_NUM: usize = 2;
pub const MAX_VRING_NUM: usize = 256;
@@ -34,7 +35,7 @@ impl DummySlaveReqHandler {
}
}
-impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
+impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
fn set_owner(&mut self) -> Result<()> {
if self.owned {
return Err(Error::InvalidOperation);
@@ -56,9 +57,7 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
}
fn set_features(&mut self, features: u64) -> Result<()> {
- if !self.owned {
- return Err(Error::InvalidOperation);
- } else if self.features_acked {
+ if !self.owned || self.features_acked {
return Err(Error::InvalidOperation);
} else if (features & !VIRTIO_FEATURES) != 0 {
return Err(Error::InvalidParam);
@@ -83,30 +82,10 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
Ok(())
}
- fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
- Ok(VhostUserProtocolFeatures::all())
- }
-
- fn set_protocol_features(&mut self, features: u64) -> Result<()> {
- // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
- // support this message even before VHOST_USER_SET_FEATURES was
- // called.
- // What happens if the master calls set_features() with
- // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
- // interface?
- self.acked_protocol_features = features;
- Ok(())
- }
-
fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
- // TODO
Ok(())
}
- fn get_queue_num(&mut self) -> Result<u64> {
- Ok(MAX_QUEUE_NUM as u64)
- }
-
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
return Err(Error::InvalidParam);
@@ -199,6 +178,25 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
Ok(())
}
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ Ok(VhostUserProtocolFeatures::all())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> Result<()> {
+ // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ // What happens if the master calls set_features() with
+ // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
+ // interface?
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ Ok(MAX_QUEUE_NUM as u64)
+ }
+
fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
// has been negotiated.
@@ -222,10 +220,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
size: u32,
_flags: VhostUserConfigFlags,
) -> Result<Vec<u8>> {
- if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
- } else if offset < VHOST_USER_CONFIG_OFFSET
- || offset >= VHOST_USER_CONFIG_SIZE
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|| size + offset > VHOST_USER_CONFIG_SIZE
{
@@ -236,10 +233,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
let size = buf.len() as u32;
- if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
- } else if offset < VHOST_USER_CONFIG_OFFSET
- || offset >= VHOST_USER_CONFIG_SIZE
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|| size + offset > VHOST_USER_CONFIG_SIZE
{
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
index ffed909..35ca471 100644
--- a/src/vhost_user/master.rs
+++ b/src/vhost_user/master.rs
@@ -6,7 +6,7 @@
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, MutexGuard};
use vmm_sys_util::eventfd::EventFd;
@@ -78,6 +78,10 @@ impl Master {
}
}
+ fn node(&self) -> MutexGuard<MasterInternal> {
+ self.node.lock().unwrap()
+ }
+
/// Create a new instance from a Unix stream socket.
pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
@@ -115,8 +119,8 @@ impl Master {
impl VhostBackend for Master {
/// Get from the underlying vhost implementation the feature bitmask.
- fn get_features(&mut self) -> Result<u64> {
- let mut node = self.node.lock().unwrap();
+ fn get_features(&self) -> Result<u64> {
+ let mut node = self.node();
let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
let val = node.recv_reply::<VhostUserU64>(&hdr)?;
node.virtio_features = val.value;
@@ -124,8 +128,8 @@ impl VhostBackend for Master {
}
/// Enable features in the underlying vhost implementation using a bitmask.
- fn set_features(&mut self, features: u64) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_features(&self, features: u64) -> Result<()> {
+ let mut node = self.node();
let val = VhostUserU64::new(features);
let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
@@ -135,18 +139,18 @@ impl VhostBackend for Master {
}
/// Set the current Master as an owner of the session.
- fn set_owner(&mut self) -> Result<()> {
+ fn set_owner(&self) -> Result<()> {
// We unwrap() the return value to assert that we are not expecting threads to ever fail
// while holding the lock.
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
// completed yet.
Ok(())
}
- fn reset_owner(&mut self) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn reset_owner(&self) -> Result<()> {
+ let mut node = self.node();
let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
// completed yet.
@@ -155,7 +159,7 @@ impl VhostBackend for Master {
/// Set the memory map regions on the slave so it can translate the vring
/// addresses. In the ancillary data there is an array of file descriptors
- fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
return error_code(VhostUserError::InvalidParam);
}
@@ -174,12 +178,13 @@ impl VhostBackend for Master {
ctx.append(&reg, region.mmap_handle);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let body = VhostUserMemory::new(ctx.regions.len() as u32);
+ let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
let hdr = node.send_request_with_payload(
MasterReq::SET_MEM_TABLE,
&body,
- ctx.regions.as_slice(),
+ payload,
Some(ctx.fds.as_slice()),
)?;
node.wait_for_ack(&hdr).map_err(|e| e.into())
@@ -187,8 +192,8 @@ impl VhostBackend for Master {
// Clippy doesn't seem to know that if let with && is still experimental
#[allow(clippy::unnecessary_unwrap)]
- fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ let mut node = self.node();
let val = VhostUserU64::new(base);
if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
@@ -202,16 +207,16 @@ impl VhostBackend for Master {
Ok(())
}
- fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
let fds = [fd];
node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
Ok(())
}
/// Set the size of the queue.
- fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -222,8 +227,8 @@ impl VhostBackend for Master {
}
/// Sets the addresses of the different aspects of the vring.
- fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num
|| config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
{
@@ -236,8 +241,8 @@ impl VhostBackend for Master {
}
/// Sets the base offset in the available vring.
- fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -247,8 +252,8 @@ impl VhostBackend for Master {
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
- fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
- let mut node = self.node.lock().unwrap();
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -263,8 +268,8 @@ impl VhostBackend for Master {
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data. This signals that polling
/// will be used instead of waiting for the call.
- fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -276,8 +281,8 @@ impl VhostBackend for Master {
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data. This signals that polling
/// should be used instead of waiting for a kick.
- fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -288,8 +293,8 @@ impl VhostBackend for Master {
/// Set the event file descriptor to signal when error occurs.
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data.
- fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -300,7 +305,7 @@ impl VhostBackend for Master {
impl VhostUserMaster for Master {
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -317,7 +322,7 @@ impl VhostUserMaster for Master {
}
fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -332,7 +337,7 @@ impl VhostUserMaster for Master {
}
fn get_queue_num(&mut self) -> Result<u64> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
if !node.is_feature_mq_available() {
return error_code(VhostUserError::InvalidOperation);
}
@@ -347,7 +352,7 @@ impl VhostUserMaster for Master {
}
fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -373,7 +378,7 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidParam);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// depends on VhostUserProtocolFeatures::CONFIG
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -390,9 +395,13 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidMessage);
} else if body_reply.size == 0 {
return error_code(VhostUserError::SlaveInternalError);
- } else if body_reply.size != body.size || body_reply.size as usize != buf.len() {
+ } else if body_reply.size != body.size
+ || body_reply.size as usize != buf.len()
+ || body_reply.offset != body.offset
+ {
return error_code(VhostUserError::InvalidMessage);
}
+
Ok((body_reply, buf_reply))
}
@@ -405,7 +414,7 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidParam);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// depends on VhostUserProtocolFeatures::CONFIG
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -416,7 +425,7 @@ impl VhostUserMaster for Master {
}
fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
}
@@ -429,7 +438,7 @@ impl VhostUserMaster for Master {
impl AsRawFd for Master {
fn as_raw_fd(&self) -> RawFd {
- let node = self.node.lock().unwrap();
+ let node = self.node();
node.main_sock.as_raw_fd()
}
}
@@ -503,14 +512,14 @@ impl MasterInternal {
Ok(hdr)
}
- fn send_request_with_payload<T: Sized, P: Sized>(
+ fn send_request_with_payload<T: Sized>(
&mut self,
code: MasterReq,
msg: &T,
- payload: &[P],
+ payload: &[u8],
fds: Option<&[RawFd]>,
) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
- let len = mem::size_of::<T>() + payload.len() * mem::size_of::<P>();
+ let len = mem::size_of::<T>() + payload.len();
if len > MAX_MSG_SIZE {
return Err(VhostUserError::InvalidParam);
}
@@ -568,7 +577,11 @@ impl MasterInternal {
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
- if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || hdr.get_size() as usize <= mem::size_of::<T>()
+ || hdr.get_size() as usize > MAX_MSG_SIZE
+ || hdr.is_reply()
+ {
return Err(VhostUserError::InvalidParam);
}
self.check_state()?;
@@ -582,11 +595,8 @@ impl MasterInternal {
{
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
- } else if bytes > MAX_MSG_SIZE - mem::size_of::<T>() {
+ } else if bytes != buf.len() {
return Err(VhostUserError::InvalidMessage);
- } else if bytes < buf.len() {
- // It's safe because we have checked the buffer size
- unsafe { buf.set_len(bytes) };
}
Ok((body, buf, rfds))
}
@@ -634,11 +644,14 @@ impl MasterInternal {
mod tests {
use super::super::connection::Listener;
use super::*;
+ use vmm_sys_util::rand::rand_alphanumerics;
- const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master";
- const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2";
- const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3";
- const UNIX_SOCKET_MASTER4: &'static str = "/tmp/vhost_user_test_rust_master4";
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
fn create_pair(path: &str) -> (Master, Endpoint<MasterReq>) {
let listener = Listener::new(path, true).unwrap();
@@ -649,14 +662,15 @@ mod tests {
}
#[test]
- #[ignore]
fn create_master() {
- let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap();
+ let master = Master::connect(&path, 1).unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
+ assert!(master.as_raw_fd() > 0);
// Send two messages continuously
master.set_owner().unwrap();
master.reset_owner().unwrap();
@@ -675,24 +689,24 @@ mod tests {
}
#[test]
- #[ignore]
fn test_create_failure() {
- let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
- let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err();
- assert!(Master::connect(UNIX_SOCKET_MASTER2, 1).is_err());
+ let path = temp_path();
+ let _ = Listener::new(&path, true).unwrap();
+ let _ = Listener::new(&path, false).is_err();
+ assert!(Master::connect(&path, 1).is_err());
- let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
- assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err());
+ let listener = Listener::new(&path, true).unwrap();
+ assert!(Listener::new(&path, false).is_err());
listener.set_nonblocking(true).unwrap();
- let _master = Master::connect(UNIX_SOCKET_MASTER2, 1).unwrap();
+ let _master = Master::connect(&path, 1).unwrap();
let _slave = listener.accept().unwrap().unwrap();
}
#[test]
- #[ignore]
fn test_features() {
- let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3);
+ let path = temp_path();
+ let (master, mut peer) = create_pair(&path);
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
@@ -709,6 +723,9 @@ mod tests {
let (_hdr, rfds) = peer.recv_header().unwrap();
assert!(rfds.is_none());
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
master.set_features(0x15).unwrap();
let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
assert!(rfds.is_none());
@@ -722,9 +739,9 @@ mod tests {
}
#[test]
- #[ignore]
fn test_protocol_features() {
- let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4);
+ let path = temp_path();
+ let (mut master, mut peer) = create_pair(&path);
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
@@ -773,12 +790,209 @@ mod tests {
}
#[test]
- fn test_set_mem_table() {
- // TODO
+ fn test_master_set_config_negative() {
+ let path = temp_path();
+ let (mut master, _peer) = create_pair(&path);
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap();
+ master
+ .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(
+ 0x100,
+ unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
+ &buf[0..4],
+ )
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf)
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[])
+ .unwrap_err();
+ }
+
+ fn create_pair2() -> (Master, Endpoint<MasterReq>) {
+ let path = temp_path();
+ let (master, peer) = create_pair(&path);
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ (master, peer)
+ }
+
+ #[test]
+ fn test_master_get_config_negative0() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_code(MasterReq::GET_FEATURES);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ hdr.set_code(MasterReq::GET_CONFIG);
+ }
+
+ #[test]
+ fn test_master_get_config_negative1() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_reply(false);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
}
#[test]
- fn test_get_ring_num() {
- // TODO
+ fn test_master_get_config_negative2() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+ }
+
+ #[test]
+ fn test_master_get_config_negative3() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative4() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0x101;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative5() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = (MAX_MSG_SIZE + 1) as u32;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative6() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.size = 6;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_maset_set_mem_table_failure() {
+ let (master, _peer) = create_pair2();
+
+ master.set_mem_table(&[]).unwrap_err();
+ let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
+ master.set_mem_table(&tables).unwrap_err();
}
}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
index aadfeee..8cba188 100644
--- a/src/vhost_user/master_req_handler.rs
+++ b/src/vhost_user/master_req_handler.rs
@@ -1,9 +1,6 @@
-// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-//! Traits and Structs to handle vhost-user requests from the slave to the master.
-
-use libc;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -13,83 +10,189 @@ use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result};
-/// Trait to handle vhost-user requests from the slave to the master.
+/// Define services provided by masters for the slave communication channel.
+///
+/// The vhost-user specification defines a slave communication channel, by which slaves could
+/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
+/// by masters, and it's used both on the master side and slave side.
+/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
+/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
+/// - on the master side, the [MasterReqHandler] will forward service requests to a handler
+/// implementing [VhostUserMasterReqHandler].
+///
+/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
pub trait VhostUserMasterReqHandler {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
// fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
- /// Handle device configuration change notifications from the slave.
+/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub trait VhostUserMasterReqHandlerMut {
+ /// Handle device configuration change notifications.
fn handle_config_change(&mut self) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs map file requests from the slave.
+ /// Handle virtio-fs map file requests.
fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
// Safe because we have just received the rawfd from kernel.
unsafe { libc::close(fd) };
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs unmap file requests from the slave.
+ /// Handle virtio-fs unmap file requests.
fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs sync file requests from the slave.
+ /// Handle virtio-fs sync file requests.
fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs file IO requests from the slave.
+ /// Handle virtio-fs file IO requests.
fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
// Safe because we have just received the rawfd from kernel.
unsafe { libc::close(fd) };
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ self.lock().unwrap().handle_config_change()
+ }
+
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_map(fs, fd)
+ }
+
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_unmap(fs)
+ }
+
+ fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_sync(fs)
+ }
+
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_io(fs, fd)
+ }
}
-/// A vhost-user master request endpoint which relays all received requests from the slave to the
-/// provided request handler.
+/// Server to handle service requests from slaves from the slave communication channel.
+///
+/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
+/// slaves on the slave communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserMasterReqHandler] to do the real work.
+///
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
// underlying Unix domain socket for communication
sub_sock: Endpoint<SlaveReq>,
tx_sock: UnixStream,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
// the VirtIO backend device object
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
// whether the endpoint has encountered any failure
error: Option<i32>,
}
impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
- /// Create a vhost-user slave request handler.
- /// This opens a pair of connected anonymous sockets.
- /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD.
- pub fn new(backend: Arc<Mutex<S>>) -> Result<Self> {
+ /// Create a server to handle service requests from slaves on the slave communication channel.
+ ///
+ /// This opens a pair of connected anonymous sockets to form the slave communication channel.
+ /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
+ /// [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn new(backend: Arc<S>) -> Result<Self> {
let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
Ok(MasterReqHandler {
sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
tx_sock: tx,
+ reply_ack_negotiated: false,
backend,
error: None,
})
}
- /// Get the raw fd to send to the slave as slave communication channel.
+ /// Get the socket fd for the slave to communication with the master.
+ ///
+ /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
pub fn get_tx_raw_fd(&self) -> RawFd {
self.tx_sock.as_raw_fd()
}
- /// Mark endpoint as failed or normal state.
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&mut self, enable: bool) {
+ self.reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed or in normal state.
pub fn set_failed(&mut self, error: i32) {
- self.error = Some(error);
+ if error == 0 {
+ self.error = None;
+ } else {
+ self.error = Some(error);
+ }
}
- /// Receive and handle one incoming request message from the slave.
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
/// The caller needs to:
- /// . serialize calls to this function
- /// . decide what to do when errer happens
- /// . optional recover from failure
+ /// - serialize calls to this function
+ /// - decide what to do when errer happens
+ /// - optional recover from failure
pub fn handle_request(&mut self) -> Result<u64> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
@@ -108,6 +211,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
+ if len as usize > MAX_MSG_SIZE {
+ return Err(Error::InvalidMessage);
+ }
let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
if size2 != len as usize {
return Err(Error::InvalidMessage);
@@ -120,41 +226,33 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
SlaveReq::CONFIG_CHANGE_MSG => {
self.check_msg_size(&hdr, size, 0)?;
self.backend
- .lock()
- .unwrap()
.handle_config_change()
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_MAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
self.backend
- .lock()
- .unwrap()
- .fs_slave_map(msg, rfds.unwrap()[0])
+ .fs_slave_map(&msg, rfds.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_UNMAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
- .lock()
- .unwrap()
- .fs_slave_unmap(msg)
+ .fs_slave_unmap(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_SYNC => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
- .lock()
- .unwrap()
- .fs_slave_sync(msg)
+ .fs_slave_sync(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_IO => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
self.backend
- .lock()
- .unwrap()
- .fs_slave_io(msg, rfds.unwrap()[0])
+ .fs_slave_io(&msg, rfds.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
_ => Err(Error::InvalidMessage),
@@ -211,7 +309,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
_ => {
if rfds.is_some() {
Endpoint::<SlaveReq>::close_rfds(rfds);
- return Err(Error::InvalidMessage);
+ Err(Error::InvalidMessage)
} else {
Ok(rfds)
}
@@ -219,14 +317,14 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
}
- fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>(
+ fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
&self,
hdr: &VhostUserMsgHeader<SlaveReq>,
size: usize,
- buf: &'a [u8],
- ) -> Result<&'a T> {
+ buf: &[u8],
+ ) -> Result<T> {
self.check_msg_size(hdr, size, mem::size_of::<T>())?;
- let msg = unsafe { &*(buf.as_ptr() as *const T) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -253,7 +351,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
req: &VhostUserMsgHeader<SlaveReq>,
res: &Result<u64>,
) -> Result<()> {
- if req.is_need_reply() {
+ if self.reply_ack_negotiated && req.is_need_reply() {
let hdr = self.new_reply_header::<VhostUserU64>(req)?;
let def_err = libc::EINVAL;
let val = match res {
@@ -278,3 +376,102 @@ impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
self.sub_sock.as_raw_fd()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "vhost-user-slave")]
+ use crate::vhost_user::SlaveFsCacheReq;
+ #[cfg(feature = "vhost-user-slave")]
+ use std::os::unix::io::FromRawFd;
+
+ struct MockMasterReqHandler {}
+
+ impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
+ /// Handle virtio-fs map file requests from the slave.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Ok(0)
+ }
+
+ /// Handle virtio-fs unmap file requests from the slave.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+ }
+
+ #[test]
+ fn test_new_master_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ assert!(handler.get_tx_raw_fd() >= 0);
+ assert!(handler.as_raw_fd() >= 0);
+ handler.check_state().unwrap();
+
+ assert_eq!(handler.error, None);
+ handler.set_failed(libc::EAGAIN);
+ assert_eq!(handler.error, Some(libc::EAGAIN));
+ handler.check_state().unwrap_err();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
+ // slave side.
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler_with_ack() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+ handler.set_reply_ack_flag(true);
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+}
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
index 4109b61..8600410 100644
--- a/src/vhost_user/message.rs
+++ b/src/vhost_user/message.rs
@@ -562,9 +562,9 @@ bitflags! {
/// Flags for the device configuration message.
pub struct VhostUserConfigFlags: u32 {
/// Vhost master messages used for writeable fields.
- const WRITABLE = 0x0;
+ const WRITABLE = 0x1;
/// Vhost master messages used for live migration.
- const LIVE_MIGRATION = 0x1;
+ const LIVE_MIGRATION = 0x2;
}
}
@@ -596,9 +596,11 @@ impl VhostUserMsgValidator for VhostUserConfig {
fn is_valid(&self) -> bool {
if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
return false;
+ } else if self.offset < 0x100 {
+ return false;
} else if self.size == 0
|| self.size > VHOST_USER_CONFIG_SIZE
- || self.size + self.offset >= VHOST_USER_CONFIG_SIZE
+ || self.size + self.offset > VHOST_USER_CONFIG_SIZE
{
return false;
}
@@ -656,9 +658,9 @@ pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserFSSlaveMsg {
- /// TODO:
+ /// File offset.
pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
- /// TODO:
+ /// Offset into the DAX window.
pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
/// Size of region to map.
pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
@@ -686,13 +688,31 @@ mod tests {
use std::mem;
#[test]
- fn check_request_code() {
+ fn check_master_request_code() {
let code = MasterReq::NOOP;
assert!(!code.is_valid());
let code = MasterReq::MAX_CMD;
assert!(!code.is_valid());
+ assert!(code > MasterReq::NOOP);
let code = MasterReq::GET_FEATURES;
assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn check_slave_request_code() {
+ let code = SlaveReq::NOOP;
+ assert!(!code.is_valid());
+ let code = SlaveReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > SlaveReq::NOOP);
+ let code = SlaveReq::CONFIG_CHANGE_MSG;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) };
+ assert!(!code.is_valid());
}
#[test]
@@ -741,6 +761,20 @@ mod tests {
assert!(!hdr.is_valid());
hdr.set_version(0x1);
assert!(hdr.is_valid());
+
+ assert_eq!(hdr, hdr.clone());
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
}
#[test]
@@ -775,6 +809,104 @@ mod tests {
msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
msg.memory_size = 0;
assert!(!msg.is_valid());
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert_eq!(state.is_valid(), true);
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = state.num;
+ assert_eq!(a, 0);
+ assert_eq!(state.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_addr() {
+ let mut addr = VhostUserVringAddr::new(
+ 2,
+ VhostUserVringAddrFlags::VHOST_VRING_F_LOG,
+ 0x1000,
+ 0x2000,
+ 0x3000,
+ 0x4000,
+ );
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+
+ addr.descriptor = 0x1001;
+ assert_eq!(addr.is_valid(), false);
+ addr.descriptor = 0x1000;
+
+ addr.available = 0x3001;
+ assert_eq!(addr.is_valid(), false);
+ addr.available = 0x3000;
+
+ addr.used = 0x2001;
+ assert_eq!(addr.is_valid(), false);
+ addr.used = 0x2000;
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_state_from_config() {
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ let addr = VhostUserVringAddr::from_config_data(2, &config);
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
}
#[test]
@@ -801,7 +933,6 @@ mod tests {
}
#[test]
- #[ignore]
fn check_user_config_msg() {
let mut msg = VhostUserConfig::new(
VHOST_USER_CONFIG_OFFSET,
@@ -828,4 +959,21 @@ mod tests {
msg.flags |= 0x4;
assert!(!msg.is_valid());
}
+
+ #[test]
+ fn test_vhost_user_fs_slave() {
+ let mut fs_slave = VhostUserFSSlaveMsg::default();
+
+ assert_eq!(fs_slave.is_valid(), true);
+
+ fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff;
+ fs_slave.len[0] = 0x1;
+ assert_eq!(fs_slave.is_valid(), false);
+
+ assert_ne!(
+ VhostUserFSSlaveMsgFlags::MAP_R,
+ VhostUserFSSlaveMsgFlags::MAP_W
+ );
+ assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0);
+ }
}
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
index 48a93ff..6a5b6a1 100644
--- a/src/vhost_user/mod.rs
+++ b/src/vhost_user/mod.rs
@@ -18,20 +18,23 @@
//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
//! equivalent ioctl to the kernel implementation.
-use libc;
use std::io::Error as IOError;
-mod connection;
pub mod message;
+
+mod connection;
pub use self::connection::Listener;
+
#[cfg(feature = "vhost-user-master")]
mod master;
#[cfg(feature = "vhost-user-master")]
pub use self::master::{Master, VhostUserMaster};
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(feature = "vhost-user")]
mod master_req_handler;
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
-pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler};
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave;
@@ -40,14 +43,14 @@ pub use self::slave::SlaveListener;
#[cfg(feature = "vhost-user-slave")]
mod slave_req_handler;
#[cfg(feature = "vhost-user-slave")]
-pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler};
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave_fs_cache;
#[cfg(feature = "vhost-user-slave")]
pub use self::slave_fs_cache::SlaveFsCacheReq;
-pub mod sock_ctrl_msg;
-
/// Errors for vhost-user operations
#[derive(Debug)]
pub enum Error {
@@ -102,6 +105,8 @@ impl std::fmt::Display for Error {
}
}
+impl std::error::Error for Error {}
+
impl Error {
/// Determine whether to rebuild the underline communication channel.
pub fn should_reconnect(&self) -> bool {
@@ -170,21 +175,32 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Result of request handler.
pub type HandlerResult<T> = std::result::Result<T, IOError>;
-#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(all(test, feature = "vhost-user-slave"))]
mod dummy_slave;
#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
mod tests {
+ use std::os::unix::io::AsRawFd;
+ use std::sync::{Arc, Barrier, Mutex};
+ use std::thread;
+ use vmm_sys_util::rand::rand_alphanumerics;
+
use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
use super::message::*;
use super::*;
use crate::backend::VhostBackend;
- use std::sync::{Arc, Barrier, Mutex};
- use std::thread;
+ use crate::{VhostUserMemoryRegionInfo, VringConfigData};
+
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
fn create_slave<S: VhostUserSlaveReqHandler>(
path: &str,
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
) -> (Master, SlaveReqHandler<S>) {
let listener = Listener::new(path, true).unwrap();
let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
@@ -194,7 +210,7 @@ mod tests {
#[test]
fn create_dummy_slave() {
- let mut slave = DummySlaveReqHandler::new();
+ let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
slave.set_owner().unwrap();
assert!(slave.set_owner().is_err());
@@ -203,8 +219,8 @@ mod tests {
#[test]
fn test_set_owner() {
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
- let (mut master, mut slave) =
- create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone());
+ let path = temp_path();
+ let (master, mut slave) = create_slave(&path, slave_be.clone());
assert_eq!(slave_be.lock().unwrap().owned, false);
master.set_owner().unwrap();
@@ -219,14 +235,60 @@ mod tests {
fn test_set_features() {
let mbar = Arc::new(Barrier::new(2));
let sbar = mbar.clone();
+ let path = temp_path();
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ thread::spawn(move || {
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_master_slave_process() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let path = temp_path();
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
- let (mut master, mut slave) =
- create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone());
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
thread::spawn(move || {
+ // set_own()
slave.handle_request().unwrap();
assert_eq!(slave_be.lock().unwrap().owned, true);
+ // get/set_features()
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
@@ -241,6 +303,34 @@ mod tests {
VhostUserProtocolFeatures::all().bits()
);
+ // get_queue_num()
+ slave.handle_request().unwrap();
+
+ // set_mem_table()
+ slave.handle_request().unwrap();
+
+ // get/set_config()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // set_slave_request_fd
+ slave.handle_request().unwrap();
+
+ // set_vring_enable
+ slave.handle_request().unwrap();
+
+ // set_log_base,set_log_fd()
+ slave.handle_request().unwrap_err();
+ slave.handle_request().unwrap_err();
+
+ // set_vring_xxx
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
sbar.wait();
});
@@ -256,6 +346,82 @@ mod tests {
assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
master.set_protocol_features(features).unwrap();
+ let num = master.get_queue_num().unwrap();
+ assert_eq!(num, 2);
+
+ let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap();
+ let mem = [VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: eventfd.as_raw_fd(),
+ }];
+ master.set_mem_table(&mem).unwrap();
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
+ .unwrap();
+ let buf = [0x0u8; 4];
+ let (reply_body, reply_payload) = master
+ .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ let offset = reply_body.offset;
+ assert_eq!(offset, 0x100);
+ assert_eq!(reply_payload[0], 0xa5);
+
+ master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+ // unimplemented yet
+ master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap();
+ master.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+ master.set_vring_base(0, 0).unwrap();
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ master.set_vring_addr(0, &config).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+
mbar.wait();
}
+
+ #[test]
+ fn test_error_display() {
+ assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
+ assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
+ }
+
+ #[test]
+ fn test_should_reconnect() {
+ assert_eq!(Error::PartialMessage.should_reconnect(), true);
+ assert_eq!(Error::SlaveInternalError.should_reconnect(), true);
+ assert_eq!(Error::MasterInternalError.should_reconnect(), true);
+ assert_eq!(Error::InvalidParam.should_reconnect(), false);
+ assert_eq!(Error::InvalidOperation.should_reconnect(), false);
+ assert_eq!(Error::InvalidMessage.should_reconnect(), false);
+ assert_eq!(Error::IncorrectFds.should_reconnect(), false);
+ assert_eq!(Error::OversizedMsg.should_reconnect(), false);
+ assert_eq!(Error::FeatureMismatch.should_reconnect(), false);
+ }
+
+ #[test]
+ fn test_error_from_sys_util_error() {
+ let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into();
+ if let Error::SocketRetry(e1) = e {
+ assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
+ } else {
+ panic!("invalid error code conversion!");
+ }
+ }
}
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
index 5ac99af..fb65c41 100644
--- a/src/vhost_user/slave.rs
+++ b/src/vhost_user/slave.rs
@@ -3,7 +3,7 @@
//! Traits and Structs for vhost-user slave.
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
use super::connection::{Endpoint, Listener};
use super::message::*;
@@ -12,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
/// Vhost-user slave side connection listener.
pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
listener: Listener,
- backend: Option<Arc<Mutex<S>>>,
+ backend: Option<Arc<S>>,
}
/// Sets up a listener for incoming master connections, and handles construction
/// of a Slave on success.
impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
/// Create a unix domain socket for incoming master connections.
- pub fn new(listener: Listener, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
Ok(SlaveListener {
listener,
backend: Some(backend),
@@ -44,3 +44,43 @@ impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
self.listener.set_nonblocking(block)
}
}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Mutex;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_listener_set_nonblocking() {
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener =
+ Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
+ let slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-master")]
+ #[test]
+ fn test_slave_listener_accept() {
+ use super::super::Master;
+
+ let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener = Listener::new(path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ assert!(slave_listener.accept().unwrap().is_none());
+ assert!(slave_listener.accept().unwrap().is_none());
+
+ let _master = Master::connect(path, 1).unwrap();
+ let _slave = slave_listener.accept().unwrap().unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
index 1804c7a..a9c4ed2 100644
--- a/src/vhost_user/slave_fs_cache.rs
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -1,61 +1,59 @@
-// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2020 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-use super::connection::Endpoint;
-use super::message::*;
-use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
use std::io;
use std::mem;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixStream;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
struct SlaveFsCacheReqInternal {
sock: Endpoint<SlaveReq>,
-}
-/// A vhost-user slave endpoint which sends fs cache requests to the master
-#[derive(Clone)]
-pub struct SlaveFsCacheReq {
- // underlying Unix domain socket for communication
- node: Arc<Mutex<SlaveFsCacheReqInternal>>,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
// whether the endpoint has encountered any failure
error: Option<i32>,
}
-impl SlaveFsCacheReq {
- fn new(ep: Endpoint<SlaveReq>) -> Self {
- SlaveFsCacheReq {
- node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })),
- error: None,
+impl SlaveFsCacheReqInternal {
+ fn check_state(&self) -> Result<u64> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(0),
}
}
- /// Create a new instance.
- pub fn from_stream(sock: UnixStream) -> Self {
- Self::new(Endpoint::<SlaveReq>::from_stream(sock))
- }
-
fn send_message(
&mut self,
- flags: SlaveReq,
+ request: SlaveReq,
fs: &VhostUserFSSlaveMsg,
fds: Option<&[RawFd]>,
) -> Result<u64> {
self.check_state()?;
let len = mem::size_of::<VhostUserFSSlaveMsg>();
- let mut hdr = VhostUserMsgHeader::new(flags, 0, len as u32);
- hdr.set_need_reply(true);
- self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?;
+ let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
+ if self.reply_ack_negotiated {
+ hdr.set_need_reply(true);
+ }
+ self.sock.send_message(&hdr, fs, fds)?;
self.wait_for_ack(&hdr)
}
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
self.check_state()?;
- let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::<VhostUserU64>()?;
+ if !self.reply_ack_negotiated {
+ return Ok(0);
+ }
+
+ let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
Endpoint::<SlaveReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
@@ -63,32 +61,166 @@ impl SlaveFsCacheReq {
if body.value != 0 {
return Err(Error::MasterInternalError);
}
- Ok(0)
+
+ Ok(body.value)
}
+}
- fn check_state(&self) -> Result<u64> {
- match self.error {
- Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
- None => Ok(0),
+/// Request proxy to send vhost-user-fs slave requests to the master through the slave
+/// communication channel.
+///
+/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the
+/// master through the vhost-user slave communication channel. The forwarded messages will be
+/// handled by the [MasterReqHandler] server.
+///
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+#[derive(Clone)]
+pub struct SlaveFsCacheReq {
+ // underlying Unix domain socket for communication
+ node: Arc<Mutex<SlaveFsCacheReqInternal>>,
+}
+
+impl SlaveFsCacheReq {
+ fn new(ep: Endpoint<SlaveReq>) -> Self {
+ SlaveFsCacheReq {
+ node: Arc::new(Mutex::new(SlaveFsCacheReqInternal {
+ sock: ep,
+ reply_ack_negotiated: false,
+ error: None,
+ })),
}
}
+ fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> {
+ self.node.lock().unwrap()
+ }
+
+ fn send_message(
+ &self,
+ request: SlaveReq,
+ fs: &VhostUserFSSlaveMsg,
+ fds: Option<&[RawFd]>,
+ ) -> io::Result<u64> {
+ self.node()
+ .send_message(request, fs, fds)
+ .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
+ }
+
+ /// Create a new instance from a `UnixStream` object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Self::new(Endpoint::<SlaveReq>::from_stream(sock))
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&self, enable: bool) {
+ self.node().reply_ack_negotiated = enable;
+ }
+
/// Mark endpoint as failed with specified error code.
- pub fn set_failed(&mut self, error: i32) {
- self.error = Some(error);
+ pub fn set_failed(&self, error: i32) {
+ self.node().error = Some(error);
}
}
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
- /// Handle virtio-fs map file requests from the slave.
- fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ /// Forward vhost-user-fs map file requests to the slave.
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
- .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
- /// Handle virtio-fs unmap file requests from the slave.
- fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ /// Forward vhost-user-fs unmap file requests to the master.
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
self.send_message(SlaveReq::FS_UNMAP, fs, None)
- .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+
+ #[test]
+ fn test_slave_fs_cache_req_set_failed() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ assert!(fs_cache.node().error.is_none());
+ fs_cache.set_failed(libc::EAGAIN);
+ assert_eq!(fs_cache.node().error, Some(libc::EAGAIN));
+ }
+
+ #[test]
+ fn test_slave_fs_cache_send_failure() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ fs_cache.set_failed(libc::ECONNRESET);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ fs_cache.node().error = None;
+
+ drop(p2);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_slave_fs_cache_recv_negative() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+ let mut master = Endpoint::<SlaveReq>::from_stream(p2);
+
+ let len = mem::size_of::<VhostUserFSSlaveMsg>();
+ let mut hdr = VhostUserMsgHeader::new(
+ SlaveReq::FS_MAP,
+ VhostUserHeaderFlag::REPLY.bits(),
+ len as u32,
+ );
+ let body = VhostUserU64::new(0);
+
+ master.send_message(&hdr, &body, Some(&[fd])).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ hdr.set_code(SlaveReq::FS_UNMAP);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ hdr.set_code(SlaveReq::FS_MAP);
+
+ let body = VhostUserU64::new(1);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ let body = VhostUserU64::new(0);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
}
}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
index f3b0770..3b44e4c 100644
--- a/src/vhost_user/slave_req_handler.rs
+++ b/src/vhost_user/slave_req_handler.rs
@@ -1,8 +1,6 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-//! Traits and Structs to handle vhost-user requests from the master to the slave.
-
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -14,9 +12,63 @@ use super::message::*;
use super::slave_fs_cache::SlaveFsCacheReq;
use super::{Error, Result};
-/// Trait to handle vhost-user requests from the master to the slave.
+/// Services provided to the master by the slave with interior mutability.
+///
+/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
+/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
+/// but without interior mutability.
+/// The vhost-user specification defines a master communication channel, by which masters could
+/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
+/// slaves, and it's used both on the master side and slave side.
+///
+/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
+/// service requests to slaves.
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
+/// implementing [VhostUserSlaveReqHandler].
+///
+/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
+ fn set_owner(&self) -> Result<()>;
+ fn reset_owner(&self) -> Result<()>;
+ fn get_features(&self) -> Result<u64>;
+ fn set_features(&self, features: u64) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&self, features: u64) -> Result<()>;
+ fn get_queue_num(&self) -> Result<u64>;
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
+}
+
+/// Services provided to the master by the slave without interior mutability.
+///
+/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandlerMut {
fn set_owner(&mut self) -> Result<()>;
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
@@ -52,16 +104,110 @@ pub trait VhostUserSlaveReqHandler {
fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
}
-/// A vhost-user slave endpoint which relays all received requests from the
-/// master to the virtio backend device object.
+impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
+ fn set_owner(&self) -> Result<()> {
+ self.lock().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.lock().unwrap().reset_owner()
+ }
+
+ fn get_features(&self) -> Result<u64> {
+ self.lock().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_features(features)
+ }
+
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, fds)
+ }
+
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_num(index, num)
+ }
+
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()> {
+ self.lock()
+ .unwrap()
+ .set_vring_addr(index, flags, descriptor, used, available, log)
+ }
+
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_base(index, base)
+ }
+
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
+ self.lock().unwrap().get_vring_base(index)
+ }
+
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_kick(index, fd)
+ }
+
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_call(index, fd)
+ }
+
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_err(index, fd)
+ }
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
+ self.lock().unwrap().get_protocol_features()
+ }
+
+ fn set_protocol_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_protocol_features(features)
+ }
+
+ fn get_queue_num(&self) -> Result<u64> {
+ self.lock().unwrap().get_queue_num()
+ }
+
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
+ self.lock().unwrap().set_vring_enable(index, enable)
+ }
+
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
+ self.lock().unwrap().get_config(offset, size, flags)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf, flags)
+ }
+
+ fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
+ self.lock().unwrap().set_slave_req_fd(vu_req)
+ }
+}
+
+/// Server to handle service requests from masters from the master communication channel.
+///
+/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
+/// masters on the master communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
// underlying Unix domain socket for communication
main_sock: Endpoint<MasterReq>,
// the vhost-user backend device object
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
virtio_features: u64,
acked_virtio_features: u64,
@@ -76,7 +222,7 @@ pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// Create a vhost-user slave endpoint.
- pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self {
+ pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
SlaveReqHandler {
main_sock,
backend,
@@ -94,7 +240,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// # Arguments
/// * - `path` - path of Unix domain socket listener to connect to
/// * - `backend` - handler for requests from the master to the slave
- pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
}
@@ -103,11 +249,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.error = Some(error);
}
- /// Receive and handle one incoming request message from the master.
- /// The caller needs to:
- /// . serialize calls to this function
- /// . decide what to do when error happens
- /// . optional recover from failure
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
+ /// Receive and handle one incoming request message from the master. The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when error happens
+ /// - optional recover from failure
pub fn handle_request(&mut self) -> Result<()> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
@@ -137,15 +284,15 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
match hdr.get_code() {
MasterReq::SET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().set_owner()?;
+ self.backend.set_owner()?;
}
MasterReq::RESET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().reset_owner()?;
+ self.backend.reset_owner()?;
}
MasterReq::GET_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_features()?;
+ let features = self.backend.get_features()?;
let msg = VhostUserU64::new(features);
self.send_reply_message(&hdr, &msg)?;
self.virtio_features = features;
@@ -153,7 +300,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend.lock().unwrap().set_features(msg.value)?;
+ self.backend.set_features(msg.value)?;
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
}
@@ -163,11 +310,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_NUM => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_num(msg.index, msg.num);
+ let res = self.backend.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ADDR => {
@@ -176,7 +319,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self.backend.lock().unwrap().set_vring_addr(
+ let res = self.backend.set_vring_addr(
msg.index,
flags,
msg.descriptor,
@@ -188,39 +331,35 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_base(msg.index, msg.num);
+ let res = self.backend.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?;
+ let reply = self.backend.get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_call(index, rfds);
+ let res = self.backend.set_vring_call(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_kick(index, rfds);
+ let res = self.backend.set_vring_kick(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_err(index, rfds);
+ let res = self.backend.set_vring_err(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_protocol_features()?;
+ let features = self.backend.get_protocol_features()?;
let msg = VhostUserU64::new(features.bits());
self.send_reply_message(&hdr, &msg)?;
self.protocol_features = features;
@@ -228,10 +367,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_PROTOCOL_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend
- .lock()
- .unwrap()
- .set_protocol_features(msg.value)?;
+ self.backend.set_protocol_features(msg.value)?;
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
}
@@ -240,7 +376,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, 0)?;
- let num = self.backend.lock().unwrap().get_queue_num()?;
+ let num = self.backend.get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
@@ -257,17 +393,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
_ => return Err(Error::InvalidParam),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_enable(msg.index, enable);
+ let res = self.backend.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_CONFIG => {
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
}
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.get_config(&hdr, &buf)?;
}
MasterReq::SET_CONFIG => {
@@ -281,6 +414,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return Err(Error::InvalidOperation);
}
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.set_slave_req_fd(&hdr, rfds)?;
}
_ => {
@@ -341,15 +475,18 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
- self.backend.lock().unwrap().set_mem_table(&regions, &fds)
+ self.backend.set_mem_table(&regions, &fds)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
+ let payload_offset = mem::size_of::<VhostUserConfig>();
+ if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
- let payload_offset = mem::size_of::<VhostUserConfig>();
if buf.len() - payload_offset != msg.size as usize {
return Err(Error::InvalidMessage);
}
@@ -357,11 +494,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .get_config(msg.offset, msg.size, flags);
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
// vhost-user slave's payload size MUST match master's request
// on success, uses zero length of payload to indicate an error
@@ -389,10 +522,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
size: usize,
buf: &[u8],
) -> Result<()> {
- if size < mem::size_of::<VhostUserConfig>() {
+ if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
return Err(Error::InvalidMessage);
}
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -405,11 +538,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
None => return Err(Error::InvalidMessage),
}
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_config(msg.offset, buf, flags);
+ let res = self.backend.set_config(msg.offset, buf, flags);
self.send_ack_message(&hdr, res)?;
Ok(())
}
@@ -423,7 +552,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
if fds.len() == 1 {
let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
let vu_req = SlaveFsCacheReq::from_stream(sock);
- self.backend.lock().unwrap().set_slave_req_fd(vu_req);
+ self.backend.set_slave_req_fd(vu_req);
self.send_ack_message(&hdr, Ok(()))
} else {
Err(Error::InvalidMessage)
@@ -438,7 +567,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
buf: &[u8],
rfds: Option<Vec<RawFd>>,
) -> Result<(u8, Option<RawFd>)> {
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) };
+ if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -447,10 +579,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
// invalid FD flag. This flag is set when there is no file descriptor
// in the ancillary data. This signals that polling will be used
// instead of waiting for the call.
- let nofd = match msg.value & 0x100u64 {
- 0x100u64 => true,
- _ => false,
- };
+ let nofd = (msg.value & 0x100u64) == 0x100u64;
let mut rfd = None;
match rfds {
@@ -519,14 +648,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
- fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>(
+ fn extract_request_body<T: Sized + VhostUserMsgValidator>(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
- buf: &'a [u8],
- ) -> Result<&'a T> {
+ buf: &[u8],
+ ) -> Result<T> {
self.check_request_size(hdr, size, mem::size_of::<T>())?;
- let msg = unsafe { &*(buf.as_ptr() as *const T) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -552,7 +681,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
req: &VhostUserMsgHeader<MasterReq>,
payload_size: usize,
) -> Result<VhostUserMsgHeader<MasterReq>> {
- if mem::size_of::<T>() > MAX_MSG_SIZE {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || payload_size > MAX_MSG_SIZE
+ || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
+ {
return Err(Error::InvalidParam);
}
self.check_state()?;
@@ -568,7 +700,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
req: &VhostUserMsgHeader<MasterReq>,
res: Result<()>,
) -> Result<()> {
- if self.reply_ack_enabled {
+ if self.reply_ack_enabled && req.is_need_reply() {
let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
let val = match res {
Ok(_) => 0,
@@ -590,16 +722,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Ok(())
}
- fn send_reply_with_payload<T, P>(
+ fn send_reply_with_payload<T: Sized>(
&mut self,
req: &VhostUserMsgHeader<MasterReq>,
msg: &T,
- payload: &[P],
- ) -> Result<()>
- where
- T: Sized,
- P: Sized,
- {
+ payload: &[u8],
+ ) -> Result<()> {
let hdr = self.new_reply_header::<T>(req, payload.len())?;
self.main_sock
.send_message_with_payload(&hdr, msg, payload, None)?;
@@ -612,3 +740,24 @@ impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
self.main_sock.as_raw_fd()
}
}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_req_handler_new() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let endpoint = Endpoint::<MasterReq>::from_stream(p1);
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let mut handler = SlaveReqHandler::new(endpoint, backend);
+
+ handler.check_state().unwrap();
+ handler.set_failed(libc::EAGAIN);
+ handler.check_state().unwrap_err();
+ assert!(handler.as_raw_fd() >= 0);
+ }
+}
diff --git a/src/vhost_user/sock_ctrl_msg.rs b/src/vhost_user/sock_ctrl_msg.rs
deleted file mode 100644
index db3ec2e..0000000
--- a/src/vhost_user/sock_ctrl_msg.rs
+++ /dev/null
@@ -1,499 +0,0 @@
-// Copyright 2017 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-//! Used to send and receive messages with file descriptors on sockets that accept control messages
-//! (e.g. Unix domain sockets).
-
-// TODO: move this file into the vmm-sys-util crate
-
-use std::fs::File;
-use std::mem::size_of;
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
-use std::os::unix::net::{UnixDatagram, UnixStream};
-use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
-
-use libc::{
- c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
-};
-use vmm_sys_util::errno::{Error, Result};
-
-// Each of the following macros performs the same function as their C counterparts. They are each
-// macros because they are used to size statically allocated arrays.
-
-macro_rules! CMSG_ALIGN {
- ($len:expr) => {
- (($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1)
- };
-}
-
-macro_rules! CMSG_SPACE {
- ($len:expr) => {
- size_of::<cmsghdr>() + CMSG_ALIGN!($len)
- };
-}
-
-#[cfg(not(target_env = "musl"))]
-macro_rules! CMSG_LEN {
- ($len:expr) => {
- size_of::<cmsghdr>() + ($len)
- };
-}
-
-#[cfg(target_env = "musl")]
-macro_rules! CMSG_LEN {
- ($len:expr) => {{
- let sz = size_of::<cmsghdr>() + ($len);
- assert!(sz <= (std::u32::MAX as usize));
- sz as u32
- }};
-}
-
-#[cfg(not(target_env = "musl"))]
-fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
- msghdr {
- msg_name: null_mut(),
- msg_namelen: 0,
- msg_iov: iovecs.as_mut_ptr(),
- msg_iovlen: iovecs.len(),
- msg_control: null_mut(),
- msg_controllen: 0,
- msg_flags: 0,
- }
-}
-
-#[cfg(target_env = "musl")]
-fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
- assert!(iovecs.len() <= (std::i32::MAX as usize));
- let mut msg: msghdr = unsafe { std::mem::zeroed() };
- msg.msg_name = null_mut();
- msg.msg_iov = iovecs.as_mut_ptr();
- msg.msg_iovlen = iovecs.len() as i32;
- msg.msg_control = null_mut();
- msg
-}
-
-#[cfg(not(target_env = "musl"))]
-fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
- msg.msg_controllen = cmsg_capacity;
-}
-
-#[cfg(target_env = "musl")]
-fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
- assert!(cmsg_capacity <= (std::u32::MAX as usize));
- msg.msg_controllen = cmsg_capacity as u32;
-}
-
-// This function (macro in the C version) is not used in any compile time constant slots, so is just
-// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
-// module supports.
-#[allow(non_snake_case)]
-#[inline(always)]
-fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
- // Essentially returns a pointer to just past the header.
- cmsg_buffer.wrapping_offset(1) as *mut RawFd
-}
-
-// This function is like CMSG_NEXT, but safer because it reads only from references, although it
-// does some pointer arithmetic on cmsg_ptr.
-#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))]
-fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
- let next_cmsg =
- (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len as usize)) as *mut cmsghdr;
- if next_cmsg
- .wrapping_offset(1)
- .wrapping_sub(msghdr.msg_control as usize) as usize
- > msghdr.msg_controllen as usize
- {
- null_mut()
- } else {
- next_cmsg
- }
-}
-
-const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
-
-enum CmsgBuffer {
- Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
- Heap(Box<[cmsghdr]>),
-}
-
-impl CmsgBuffer {
- fn with_capacity(capacity: usize) -> CmsgBuffer {
- let cap_in_cmsghdr_units =
- (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
- if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
- CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
- } else {
- CmsgBuffer::Heap(
- vec![
- cmsghdr {
- cmsg_len: 0,
- cmsg_level: 0,
- cmsg_type: 0,
- #[cfg(target_env = "musl")]
- __pad1: 0,
- };
- cap_in_cmsghdr_units
- ]
- .into_boxed_slice(),
- )
- }
- }
-
- fn as_mut_ptr(&mut self) -> *mut cmsghdr {
- match self {
- CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
- CmsgBuffer::Heap(a) => a.as_mut_ptr(),
- }
- }
-}
-
-fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
- let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
- let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
-
- let mut iovecs = Vec::with_capacity(out_data.len());
- for data in out_data {
- iovecs.push(iovec {
- iov_base: data.as_ptr() as *mut c_void,
- iov_len: data.size(),
- });
- }
-
- let mut msg = new_msghdr(&mut iovecs);
-
- if !out_fds.is_empty() {
- let cmsg = cmsghdr {
- cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
- cmsg_level: SOL_SOCKET,
- cmsg_type: SCM_RIGHTS,
- #[cfg(target_env = "musl")]
- __pad1: 0,
- };
- unsafe {
- // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
- write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
- // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
- // file descriptors.
- copy_nonoverlapping(
- out_fds.as_ptr(),
- CMSG_DATA(cmsg_buffer.as_mut_ptr()),
- out_fds.len(),
- );
- }
-
- msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
- set_msg_controllen(&mut msg, cmsg_capacity);
- }
-
- // Safe because the msghdr was properly constructed from valid (or null) pointers of the
- // indicated length and we check the return value.
- let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
-
- if write_count == -1 {
- Err(Error::last())
- } else {
- Ok(write_count as usize)
- }
-}
-
-fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
- let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
- let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
- let mut msg = new_msghdr(iovecs);
-
- if !in_fds.is_empty() {
- msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
- set_msg_controllen(&mut msg, cmsg_capacity);
- }
-
- // Safe because the msghdr was properly constructed from valid (or null) pointers of the
- // indicated length and we check the return value.
- let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) };
-
- if total_read == -1 {
- return Err(Error::last());
- }
-
- // When the connection is closed recvmsg() doesn't give an explicit error
- if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
- return Err(Error::new(libc::ECONNRESET));
- }
-
- let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
- let mut in_fds_count = 0;
- while !cmsg_ptr.is_null() {
- // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
- // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
- let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
-
- if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
- let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) as usize / size_of::<RawFd>();
- unsafe {
- copy_nonoverlapping(
- CMSG_DATA(cmsg_ptr),
- in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
- fd_count,
- );
- }
- in_fds_count += fd_count;
- }
-
- cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
- }
-
- Ok((total_read as usize, in_fds_count))
-}
-
-/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
-/// `recvmsg`.
-pub trait ScmSocket {
- /// Gets the file descriptor of this socket.
- fn socket_fd(&self) -> RawFd;
-
- /// Sends the given data and file descriptor over the socket.
- ///
- /// On success, returns the number of bytes sent.
- ///
- /// # Arguments
- ///
- /// * `buf` - A buffer of data to send on the `socket`.
- /// * `fd` - A file descriptors to be sent.
- fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
- self.send_with_fds(&[buf], &[fd])
- }
-
- /// Sends the given data and file descriptors over the socket.
- ///
- /// On success, returns the number of bytes sent.
- ///
- /// # Arguments
- ///
- /// * `bufs` - A list of data buffer to send on the `socket`.
- /// * `fds` - A list of file descriptors to be sent.
- fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
- raw_sendmsg(self.socket_fd(), bufs, fds)
- }
-
- /// Receives data and potentially a file descriptor from the socket.
- ///
- /// On success, returns the number of bytes and an optional file descriptor.
- ///
- /// # Arguments
- ///
- /// * `buf` - A buffer to receive data from the socket.
- fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
- let mut fd = [0];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
-
- let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?;
- let file = if fd_count == 0 {
- None
- } else {
- // Safe because the first fd from recv_with_fds is owned by us and valid because this
- // branch was taken.
- Some(unsafe { File::from_raw_fd(fd[0]) })
- };
- Ok((read_count, file))
- }
-
- /// Receives data and file descriptors from the socket.
- ///
- /// On success, returns the number of bytes and file descriptors received as a tuple
- /// `(bytes count, files count)`.
- ///
- /// # Arguments
- ///
- /// * `iovecs` - A list of iovec to receive data from the socket.
- /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
- /// number of valid file descriptors is indicated by the second element of the
- /// returned tuple. The caller owns these file descriptors, but they will not be
- /// closed on drop like a `File`-like type would be. It is recommended that each valid
- /// file descriptor gets wrapped in a drop type that closes it after this returns.
- fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> {
- raw_recvmsg(self.socket_fd(), iovecs, fds)
- }
-}
-
-impl ScmSocket for UnixDatagram {
- fn socket_fd(&self) -> RawFd {
- self.as_raw_fd()
- }
-}
-
-impl ScmSocket for UnixStream {
- fn socket_fd(&self) -> RawFd {
- self.as_raw_fd()
- }
-}
-
-/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
-/// the lifetime of this object.
-///
-/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size
-/// being accurate.
-pub unsafe trait IntoIovec {
- /// Gets the base pointer of this `iovec`.
- fn as_ptr(&self) -> *const c_void;
-
- /// Gets the size in bytes of this `iovec`.
- fn size(&self) -> usize;
-}
-
-// Safe because this slice can not have another mutable reference and it's pointer and size are
-// guaranteed to be valid.
-unsafe impl<'a> IntoIovec for &'a [u8] {
- // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480
- #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))]
- fn as_ptr(&self) -> *const c_void {
- self.as_ref().as_ptr() as *const c_void
- }
-
- fn size(&self) -> usize {
- self.len()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use std::io::Write;
- use std::mem::size_of;
- use std::os::raw::c_long;
- use std::os::unix::net::UnixDatagram;
- use std::slice::from_raw_parts;
-
- use libc::cmsghdr;
-
- use vmm_sys_util::eventfd::EventFd;
-
- #[test]
- fn buffer_len() {
- assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>());
- assert_eq!(
- CMSG_SPACE!(1 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>()
- );
- if size_of::<RawFd>() == 4 {
- assert_eq!(
- CMSG_SPACE!(2 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>()
- );
- assert_eq!(
- CMSG_SPACE!(3 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- assert_eq!(
- CMSG_SPACE!(4 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- } else if size_of::<RawFd>() == 8 {
- assert_eq!(
- CMSG_SPACE!(2 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- assert_eq!(
- CMSG_SPACE!(3 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 3
- );
- assert_eq!(
- CMSG_SPACE!(4 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 4
- );
- }
- }
-
- #[test]
- fn send_recv_no_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let write_count = s1
- .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
- .expect("failed to send data");
-
- assert_eq!(write_count, 6);
-
- let mut buf = [0u8; 6];
- let mut files = [0; 1];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
- let (read_count, file_count) = s2
- .recv_with_fds(&mut iovecs[..], &mut files)
- .expect("failed to recv data");
-
- assert_eq!(read_count, 6);
- assert_eq!(file_count, 0);
- assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
- }
-
- #[test]
- fn send_recv_only_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let evt = EventFd::new(0).expect("failed to create eventfd");
- let write_count = s1
- .send_with_fd([].as_ref(), evt.as_raw_fd())
- .expect("failed to send fd");
-
- assert_eq!(write_count, 0);
-
- let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");
-
- let mut file = file_opt.unwrap();
-
- assert_eq!(read_count, 0);
- assert!(file.as_raw_fd() >= 0);
- assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
- assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
- assert_ne!(file.as_raw_fd(), evt.as_raw_fd());
-
- file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
- .expect("failed to write to sent fd");
-
- assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
- }
-
- #[test]
- fn send_recv_with_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let evt = EventFd::new(0).expect("failed to create eventfd");
- let write_count = s1
- .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()])
- .expect("failed to send fd");
-
- assert_eq!(write_count, 1);
-
- let mut files = [0; 2];
- let mut buf = [0u8];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
- let (read_count, file_count) = s2
- .recv_with_fds(&mut iovecs[..], &mut files)
- .expect("failed to recv fd");
-
- assert_eq!(read_count, 1);
- assert_eq!(buf[0], 237);
- assert_eq!(file_count, 1);
- assert!(files[0] >= 0);
- assert_ne!(files[0], s1.as_raw_fd());
- assert_ne!(files[0], s2.as_raw_fd());
- assert_ne!(files[0], evt.as_raw_fd());
-
- let mut file = unsafe { File::from_raw_fd(files[0]) };
-
- file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
- .expect("failed to write to sent fd");
-
- assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
- }
-}