summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiu Jiang <gerry@linux.alibaba.com>2021-02-19 22:51:08 +0800
committerSergio Lopez <slp@sinrega.org>2021-03-01 12:50:56 +0100
commit74c353bb5fc003d95f8099a84bfdb8aa1895e2b4 (patch)
treec40c6416553e7cc070070738ff6123a96017debe
parent21b89b2ff5c418144760d08c2000776d0f0792f0 (diff)
downloadvmm_vhost-74c353bb5fc003d95f8099a84bfdb8aa1895e2b4.tar.gz
vhost_user: add VhostUserSlaveReqHandlerMut trait
Rename the original VhostUserSlaveReqHandler trait as VhostUserSlaveReqHandlerMut, and add another VhostUserSlaveReqHandler trait with interior mutability. This also help to simplify caller implementations. Signed-off-by: Liu Jiang <gerry@linux.alibaba.com>
-rw-r--r--src/vhost_user/mod.rs19
-rw-r--r--src/vhost_user/slave.rs6
-rw-r--r--src/vhost_user/slave_req_handler.rs234
3 files changed, 195 insertions, 64 deletions
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
index 148a00e..91e4203 100644
--- a/src/vhost_user/mod.rs
+++ b/src/vhost_user/mod.rs
@@ -18,20 +18,23 @@
//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
//! equivalent ioctl to the kernel implementation.
-use libc;
use std::io::Error as IOError;
-mod connection;
pub mod message;
+
+mod connection;
pub use self::connection::Listener;
+
#[cfg(feature = "vhost-user-master")]
mod master;
#[cfg(feature = "vhost-user-master")]
pub use self::master::{Master, VhostUserMaster};
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(feature = "vhost-user")]
mod master_req_handler;
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
-pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler};
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave;
@@ -40,7 +43,9 @@ pub use self::slave::SlaveListener;
#[cfg(feature = "vhost-user-slave")]
mod slave_req_handler;
#[cfg(feature = "vhost-user-slave")]
-pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler};
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave_fs_cache;
#[cfg(feature = "vhost-user-slave")]
@@ -100,6 +105,8 @@ impl std::fmt::Display for Error {
}
}
+impl std::error::Error for Error {}
+
impl Error {
/// Determine whether to rebuild the underline communication channel.
pub fn should_reconnect(&self) -> bool {
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
index 5ac99af..c167dce 100644
--- a/src/vhost_user/slave.rs
+++ b/src/vhost_user/slave.rs
@@ -3,7 +3,7 @@
//! Traits and Structs for vhost-user slave.
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
use super::connection::{Endpoint, Listener};
use super::message::*;
@@ -12,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
/// Vhost-user slave side connection listener.
pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
listener: Listener,
- backend: Option<Arc<Mutex<S>>>,
+ backend: Option<Arc<S>>,
}
/// Sets up a listener for incoming master connections, and handles construction
/// of a Slave on success.
impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
/// Create a unix domain socket for incoming master connections.
- pub fn new(listener: Listener, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
Ok(SlaveListener {
listener,
backend: Some(backend),
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
index 985693f..190501b 100644
--- a/src/vhost_user/slave_req_handler.rs
+++ b/src/vhost_user/slave_req_handler.rs
@@ -1,8 +1,6 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-//! Traits and Structs to handle vhost-user requests from the master to the slave.
-
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -14,9 +12,63 @@ use super::message::*;
use super::slave_fs_cache::SlaveFsCacheReq;
use super::{Error, Result};
-/// Trait to handle vhost-user requests from the master to the slave.
+/// Services provided to the master by the slave with interior mutability.
+///
+/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
+/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
+/// but without interior mutability.
+/// The vhost-user specification defines a master communication channel, by which masters could
+/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
+/// slaves, and it's used both on the master side and slave side.
+///
+/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
+/// service requests to slaves.
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
+/// implementing [VhostUserSlaveReqHandler].
+///
+/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
+ fn set_owner(&self) -> Result<()>;
+ fn reset_owner(&self) -> Result<()>;
+ fn get_features(&self) -> Result<u64>;
+ fn set_features(&self, features: u64) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&self, features: u64) -> Result<()>;
+ fn get_queue_num(&self) -> Result<u64>;
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
+}
+
+/// Services provided to the master by the slave without interior mutability.
+///
+/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandlerMut {
fn set_owner(&mut self) -> Result<()>;
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
@@ -52,16 +104,110 @@ pub trait VhostUserSlaveReqHandler {
fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
}
-/// A vhost-user slave endpoint which relays all received requests from the
-/// master to the virtio backend device object.
+impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
+ fn set_owner(&self) -> Result<()> {
+ self.lock().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.lock().unwrap().reset_owner()
+ }
+
+ fn get_features(&self) -> Result<u64> {
+ self.lock().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_features(features)
+ }
+
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, fds)
+ }
+
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_num(index, num)
+ }
+
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()> {
+ self.lock()
+ .unwrap()
+ .set_vring_addr(index, flags, descriptor, used, available, log)
+ }
+
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_base(index, base)
+ }
+
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
+ self.lock().unwrap().get_vring_base(index)
+ }
+
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_kick(index, fd)
+ }
+
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_call(index, fd)
+ }
+
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_err(index, fd)
+ }
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
+ self.lock().unwrap().get_protocol_features()
+ }
+
+ fn set_protocol_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_protocol_features(features)
+ }
+
+ fn get_queue_num(&self) -> Result<u64> {
+ self.lock().unwrap().get_queue_num()
+ }
+
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
+ self.lock().unwrap().set_vring_enable(index, enable)
+ }
+
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
+ self.lock().unwrap().get_config(offset, size, flags)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf, flags)
+ }
+
+ fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
+ self.lock().unwrap().set_slave_req_fd(vu_req)
+ }
+}
+
+/// Server to handle service requests from masters from the master communication channel.
+///
+/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
+/// masters on the master communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
// underlying Unix domain socket for communication
main_sock: Endpoint<MasterReq>,
// the vhost-user backend device object
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
virtio_features: u64,
acked_virtio_features: u64,
@@ -76,7 +222,7 @@ pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// Create a vhost-user slave endpoint.
- pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self {
+ pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
SlaveReqHandler {
main_sock,
backend,
@@ -94,7 +240,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// # Arguments
/// * - `path` - path of Unix domain socket listener to connect to
/// * - `backend` - handler for requests from the master to the slave
- pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
}
@@ -103,11 +249,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.error = Some(error);
}
- /// Receive and handle one incoming request message from the master.
- /// The caller needs to:
- /// . serialize calls to this function
- /// . decide what to do when error happens
- /// . optional recover from failure
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
+ /// Receive and handle one incoming request message from the master. The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when error happens
+ /// - optional recover from failure
pub fn handle_request(&mut self) -> Result<()> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
@@ -137,15 +284,15 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
match hdr.get_code() {
MasterReq::SET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().set_owner()?;
+ self.backend.set_owner()?;
}
MasterReq::RESET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().reset_owner()?;
+ self.backend.reset_owner()?;
}
MasterReq::GET_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_features()?;
+ let features = self.backend.get_features()?;
let msg = VhostUserU64::new(features);
self.send_reply_message(&hdr, &msg)?;
self.virtio_features = features;
@@ -153,7 +300,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend.lock().unwrap().set_features(msg.value)?;
+ self.backend.set_features(msg.value)?;
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
}
@@ -163,11 +310,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_NUM => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_num(msg.index, msg.num);
+ let res = self.backend.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ADDR => {
@@ -176,7 +319,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self.backend.lock().unwrap().set_vring_addr(
+ let res = self.backend.set_vring_addr(
msg.index,
flags,
msg.descriptor,
@@ -188,39 +331,35 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_base(msg.index, msg.num);
+ let res = self.backend.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?;
+ let reply = self.backend.get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_call(index, rfds);
+ let res = self.backend.set_vring_call(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_kick(index, rfds);
+ let res = self.backend.set_vring_kick(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_err(index, rfds);
+ let res = self.backend.set_vring_err(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_protocol_features()?;
+ let features = self.backend.get_protocol_features()?;
let msg = VhostUserU64::new(features.bits());
self.send_reply_message(&hdr, &msg)?;
self.protocol_features = features;
@@ -228,10 +367,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_PROTOCOL_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend
- .lock()
- .unwrap()
- .set_protocol_features(msg.value)?;
+ self.backend.set_protocol_features(msg.value)?;
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
}
@@ -240,7 +376,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, 0)?;
- let num = self.backend.lock().unwrap().get_queue_num()?;
+ let num = self.backend.get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
@@ -257,11 +393,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
_ => return Err(Error::InvalidParam),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_enable(msg.index, enable);
+ let res = self.backend.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_CONFIG => {
@@ -341,7 +473,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
- self.backend.lock().unwrap().set_mem_table(&regions, &fds)
+ self.backend.set_mem_table(&regions, &fds)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
@@ -357,11 +489,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .get_config(msg.offset, msg.size, flags);
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
// vhost-user slave's payload size MUST match master's request
// on success, uses zero length of payload to indicate an error
@@ -405,11 +533,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
None => return Err(Error::InvalidMessage),
}
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_config(msg.offset, buf, flags);
+ let res = self.backend.set_config(msg.offset, buf, flags);
self.send_ack_message(&hdr, res)?;
Ok(())
}
@@ -423,7 +547,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
if fds.len() == 1 {
let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
let vu_req = SlaveFsCacheReq::from_stream(sock);
- self.backend.lock().unwrap().set_slave_req_fd(vu_req);
+ self.backend.set_slave_req_fd(vu_req);
self.send_ack_message(&hdr, Ok(()))
} else {
Err(Error::InvalidMessage)