Diffstat (limited to 'src/vhost_user')
-rw-r--r--   src/vhost_user/connection.rs          858
-rw-r--r--   src/vhost_user/dummy_slave.rs         259
-rw-r--r--   src/vhost_user/master.rs             1071
-rw-r--r--   src/vhost_user/master_req_handler.rs  477
-rw-r--r--   src/vhost_user/message.rs            1042
-rw-r--r--   src/vhost_user/mod.rs                 456
-rw-r--r--   src/vhost_user/slave.rs                86
-rw-r--r--   src/vhost_user/slave_fs_cache.rs      226
-rw-r--r--   src/vhost_user/slave_req_handler.rs   828
9 files changed, 5303 insertions(+), 0 deletions(-)
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
new file mode 100644
index 0000000..f92db45
--- /dev/null
+++ b/src/vhost_user/connection.rs
@@ -0,0 +1,858 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Structs for Unix Domain Socket listener and endpoint.
+
+#![allow(dead_code)]
+
+use std::io::ErrorKind;
+use std::marker::PhantomData;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::{UnixListener, UnixStream};
+use std::path::{Path, PathBuf};
+use std::{mem, slice};
+
+use libc::{c_void, iovec};
+use sys_util::ScmSocket;
+
+use super::message::*;
+use super::{Error, Result};
+
+/// Unix domain socket listener for accepting incoming connections.
+pub struct Listener {
+ fd: UnixListener,
+ path: PathBuf,
+}
+
+impl Listener {
+ /// Create a Unix domain socket listener.
+ ///
+ /// # Return:
+ /// * - the new Listener object on success.
+ /// * - SocketError: failed to create listener socket.
+ pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
+ if unlink {
+ let _ = std::fs::remove_file(&path);
+ }
+ let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
+ Ok(Listener {
+ fd,
+ path: path.as_ref().to_owned(),
+ })
+ }
+
+ /// Accept an incoming connection.
+ ///
+ /// # Return:
+ /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
+ /// * - None: no incoming connection available.
+ /// * - SocketError: errors from accept().
+ pub fn accept(&self) -> Result<Option<UnixStream>> {
+ loop {
+ match self.fd.accept() {
+ Ok((socket, _addr)) => return Ok(Some(socket)),
+ Err(e) => {
+ match e.kind() {
+ // No incoming connection available.
+ ErrorKind::WouldBlock => return Ok(None),
+ // New connection closed by peer.
+ ErrorKind::ConnectionAborted => return Ok(None),
+ // Interrupted by signals, retry
+ ErrorKind::Interrupted => continue,
+ _ => return Err(Error::SocketError(e)),
+ }
+ }
+ }
+ }
+ }
+
+ /// Change the blocking mode of the listener.
+ ///
+ /// # Return:
+ /// * - () on success.
+ /// * - SocketError: failure from set_nonblocking().
+ pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
+ self.fd.set_nonblocking(nonblocking).map_err(Error::SocketError)
+ }
+}
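+
+// A minimal usage sketch for `Listener` (illustrative only; the socket path
+// below is an assumption, not something defined by this crate):
+//
+//     let listener = Listener::new("/tmp/vhost-user.sock", true)?;
+//     listener.set_nonblocking(true)?;
+//     if let Some(stream) = listener.accept()? {
+//         let slave = Endpoint::<MasterReq>::from_stream(stream);
+//     }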
+
+impl AsRawFd for Listener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+impl Drop for Listener {
+ fn drop(&mut self) {
+ let _ = std::fs::remove_file(&self.path);
+ }
+}
+
+/// Unix domain socket endpoint for vhost-user connection.
+pub(super) struct Endpoint<R: Req> {
+ sock: UnixStream,
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> Endpoint<R> {
+ /// Create a new stream by connecting to the server at `path`.
+ ///
+ /// # Return:
+ /// * - the new Endpoint object on success.
+ /// * - SocketConnect: failed to connect to peer.
+ pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
+ let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
+ Ok(Self::from_stream(sock))
+ }
+
+ /// Create an endpoint from a stream object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Endpoint {
+ sock,
+ _r: PhantomData,
+ }
+ }
+
+ /// Sends bytes from scatter-gather vectors over the socket with optional attached file
+ /// descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let rfds = match fds {
+ Some(rfds) => rfds,
+ _ => &[],
+ };
+ self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into)
+ }
+
+ /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
+ /// descriptors. Will loop until all data has been transferred.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let mut data_sent = 0;
+ let mut data_total = 0;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_sent) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
+ let iov = &iovs[nr_skip][offset..];
+
+ let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
+ let sfds = if data_sent == 0 { fds } else { None };
+
+ let sent = self.send_iovec(data, sfds);
+ match sent {
+ Ok(0) => return Ok(data_sent),
+ Ok(n) => data_sent += n,
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok(data_sent)
+ }
+
+ /// Sends bytes from a slice over the socket with optional attached file descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
+ self.send_iovec(&[data], fds)
+ }
+
+ /// Sends a header-only message with optional attached file descriptors.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: only part of the message could be sent.
+ pub fn send_header(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ // Safe because there can't be any other mutable reference to hdr.
+ let iovs = unsafe {
+ [slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ )]
+ };
+ let bytes = self.send_iovec_all(&iovs[..], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header and body. Optional file descriptors may be attached to
+ /// the message.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: only part of the message could be sent.
+ pub fn send_message<T: Sized>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+ // Safe because there can't be any other mutable reference to hdr and body.
+ let iovs = unsafe {
+ [
+ slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ ),
+ slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
+ ]
+ };
+ let bytes = self.send_iovec_all(&iovs[..], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header, body and payload. Optional file descriptors
+ /// may also be attached to the message.
+ ///
+ /// # Return:
+ /// * - () on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - OversizedMsg: message size is too big.
+ /// * - PartialMessage: only part of the message could be sent.
+ /// * - IncorrectFds: wrong number of attached fds.
+ pub fn send_message_with_payload<T: Sized>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ payload: &[u8],
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+ if len > MAX_MSG_SIZE - mem::size_of::<T>() {
+ return Err(Error::OversizedMsg);
+ }
+ if let Some(fd_arr) = fds {
+ if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+ return Err(Error::IncorrectFds);
+ }
+ }
+
+ // Safe because there can't be other mutable reference to hdr, body and payload.
+ let iovs = unsafe {
+ [
+ slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ ),
+ slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
+ slice::from_raw_parts(payload.as_ptr() as *const u8, len),
+ ]
+ };
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
+ let len = self.send_iovec_all(&iovs, fds)?;
+ if len != total {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Reads bytes from the socket into a newly allocated buffer.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf) on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
+ let mut rbuf = vec![0u8; len];
+ let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?;
+ Ok((bytes, rbuf))
+ }
+
+ /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
+ /// file descriptors.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+ /// tricky to pass file descriptors through such a communication channel. Suppose that a
+ /// sender sends a message with some file descriptors attached. To successfully receive those
+ /// attached file descriptors, the receiver must obey the following rules:
+ ///   1) file descriptors are attached to a message.
+ ///   2) message (packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received fds]) on success
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
+ let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
+ let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
+ let rfds = match fds {
+ 0 => None,
+ n => {
+ let mut fds = Vec::with_capacity(n);
+ fds.extend_from_slice(&fd_array[0..n]);
+ Some(fds)
+ }
+ };
+
+ Ok((bytes, rfds))
+ }
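+
+ // Sketch of the packet-boundary rule described above (illustrative only):
+ // if a sender transmits one 12-byte message with fds attached, the receiver
+ // must pick the fds up with the first recvmsg() touching that message:
+ //
+ //     // Up to 12 bytes plus any attached fds arrive together here ...
+ //     let (_bytes, _buf, fds) = endpoint.recv_into_buf(12)?;
+ //     // ... a later receive on the same message would find no fds.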
+
+ /// Reads all bytes from the socket into the given scatter/gather vectors with optional
+ /// attached file descriptors. Will loop until all data has been transferred.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+ /// tricky to pass file descriptors through such a communication channel. Suppose that a
+ /// sender sends a message with some file descriptors attached. To successfully receive those
+ /// attached file descriptors, the receiver must obey the following rules:
+ ///   1) file descriptors are attached to a message.
+ ///   2) message (packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received fds]) on success
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_iovec_all(
+ &mut self,
+ iovs: &mut [iovec],
+ ) -> Result<(usize, Option<Vec<RawFd>>)> {
+ let mut data_read = 0;
+ let mut data_total = 0;
+ let mut rfds = None;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_read) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
+ let iov = &mut iovs[nr_skip];
+
+ let mut data = [
+ &[iovec {
+ iov_base: (iov.iov_base as usize + offset) as *mut c_void,
+ iov_len: iov.iov_len - offset,
+ }],
+ &iovs[(nr_skip + 1)..],
+ ]
+ .concat();
+
+ let res = self.recv_into_iovec(&mut data);
+ match res {
+ Ok((0, _)) => return Ok((data_read, rfds)),
+ Ok((n, fds)) => {
+ if data_read == 0 {
+ rfds = fds;
+ }
+ data_read += n;
+ }
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok((data_read, rfds))
+ }
+
+ /// Reads bytes from the socket into a new buffer with optional attached
+ /// file descriptors. Received file descriptors are set close-on-exec.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_buf(
+ &mut self,
+ buf_size: usize,
+ ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
+ let mut buf = vec![0u8; buf_size];
+ let (bytes, rfds) = {
+ let mut iovs = [iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf_size,
+ }];
+ self.recv_into_iovec(&mut iovs)?
+ };
+ Ok((bytes, buf, rfds))
+ }
+
+ /// Receive a header-only message with optional attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be silently discarded.
+ ///
+ /// # Return:
+ /// * - (message header, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ }];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, rfds))
+ }
+
+ /// Receive a message with optional attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be silently discarded.
+ ///
+ /// # Return:
+ /// * - (message header, message body, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes != total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, rfds))
+ }
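+
+ // A minimal typed request/reply sketch using the helpers above
+ // (illustrative only; `master` and `slave` are assumed to be two connected
+ // `Endpoint<MasterReq>` instances, as in the tests at the end of this file):
+ //
+ //     let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 8);
+ //     master.send_message(&hdr, &VhostUserU64::new(0x1), None)?;
+ //     // recv_body has already validated both header and body here.
+ //     let (rhdr, body, _fds) = slave.recv_body::<VhostUserU64>()?;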
+
+ /// Receive a message with header and optional content. Callers need to
+ /// pre-allocate a buffer big enough to receive the message body and
+ /// optional payload. If there are file descriptors attached to the
+ /// message, only the first MAX_ATTACHED_FD_ENTRIES of them will be
+ /// accepted and all other file descriptors will be silently discarded.
+ ///
+ /// # Return:
+ /// * - (message header, message size, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body_into_buf(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
+ }
+
+ /// Receive a message with optional payload and attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+ /// accepted and all other file descriptors will be silently discarded.
+ ///
+ /// # Return:
+ /// * - (message header, message body, size of payload, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or a resource shortage.
+ /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ /// * - InvalidMessage: received an invalid message.
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
+ pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes < total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, bytes - total, rfds))
+ }
+
+ /// Close all raw file descriptors.
+ pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
+ if let Some(fds) = rfds {
+ for fd in fds {
+ // Safe because the raw fds are valid and we don't care about the result.
+ let _ = unsafe { libc::close(fd) };
+ }
+ }
+ }
+}
+
+impl<T: Req> AsRawFd for Endpoint<T> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice.
+// For example:
+// let iov_lens = vec![4, 4, 5];
+// let size = 6;
+// assert_eq!(get_sub_iovs_offset(&iov_lens, size), (1, 2));
+fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
+ let mut size = skip_size;
+ let mut nr_skip = 0;
+
+ for len in iov_lens {
+ if size >= *len {
+ size -= *len;
+ nr_skip += 1;
+ } else {
+ break;
+ }
+ }
+ (nr_skip, size)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::fs::File;
+ use std::io::{Read, Seek, SeekFrom, Write};
+ use std::os::unix::io::FromRawFd;
+ use tempfile::{tempfile, Builder, TempDir};
+
+ fn temp_dir() -> TempDir {
+ Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+ }
+
+ #[test]
+ fn create_listener() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
+ }
+
+ #[test]
+ fn accept_connection() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+
+ // accept on a fd without incoming connection
+ let conn = listener.accept().unwrap();
+ assert!(conn.is_none());
+ }
+
+ #[test]
+ fn send_data() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let mut len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..bytes]);
+
+ len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ }
+
+ #[test]
+ fn send_fd() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut fd = tempfile().unwrap();
+ write!(fd, "test").unwrap();
+
+ // Normal case for sending/receiving file descriptors
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let len = master
+ .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 1);
+ let mut file = unsafe { File::from_raw_fd(fds[0]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+
+ // The following communication pattern should work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header) with fds, data(body)
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 3);
+ let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // The following communication pattern should not work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header), data(body) with fds
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf4) = slave.recv_data(2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf4[..]);
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // The following communication pattern should work:
+ // Sending side: data, data with fds
+ // Receiving side: data, data with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 3);
+ let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // The following communication pattern should not work:
+ // Sending side: data1, data2 with fds
+ // Receiving side: data1 + part of data2, rest of data2 with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _) = slave.recv_data(5).unwrap();
+ assert_eq!(bytes, 5);
+
+ let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 3);
+ assert!(rfds.is_none());
+
+ // If the target fd array is too small, extra file descriptors will get lost.
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert!(rfds.is_some());
+
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ Endpoint::<MasterReq>::close_rfds(None);
+ }
+
+ #[test]
+ fn send_recv() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut hdr1 =
+ VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
+ hdr1.set_need_reply(true);
+ let features1 = 0x1u64;
+ master.send_message(&hdr1, &features1, None).unwrap();
+
+ let mut features2 = 0u64;
+ let slice = unsafe {
+ slice::from_raw_parts_mut(
+ (&mut features2 as *mut u64) as *mut u8,
+ mem::size_of::<u64>(),
+ )
+ };
+ let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert_eq!(bytes, 8);
+ assert_eq!(features1, features2);
+ assert!(rfds.is_none());
+
+ master.send_header(&hdr1, None).unwrap();
+ let (hdr2, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert!(rfds.is_none());
+ }
+}
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
new file mode 100644
index 0000000..b2b83d2
--- /dev/null
+++ b/src/vhost_user/dummy_slave.rs
@@ -0,0 +1,259 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::os::unix::io::RawFd;
+
+use super::message::*;
+use super::*;
+
+pub const MAX_QUEUE_NUM: usize = 2;
+pub const MAX_VRING_NUM: usize = 256;
+pub const MAX_MEM_SLOTS: usize = 32;
+pub const VIRTIO_FEATURES: u64 = 0x40000003;
+
+#[derive(Default)]
+pub struct DummySlaveReqHandler {
+ pub owned: bool,
+ pub features_acked: bool,
+ pub acked_features: u64,
+ pub acked_protocol_features: u64,
+ pub queue_num: usize,
+ pub vring_num: [u32; MAX_QUEUE_NUM],
+ pub vring_base: [u32; MAX_QUEUE_NUM],
+ pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub vring_started: [bool; MAX_QUEUE_NUM],
+ pub vring_enabled: [bool; MAX_QUEUE_NUM],
+}
+
+impl DummySlaveReqHandler {
+ pub fn new() -> Self {
+ DummySlaveReqHandler {
+ queue_num: MAX_QUEUE_NUM,
+ ..Default::default()
+ }
+ }
+}
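+
+// A minimal sketch of driving the handler directly (illustrative only; in a
+// real deployment it would be wrapped by the slave request handling machinery
+// and driven by messages arriving from the master):
+//
+//     let mut handler = DummySlaveReqHandler::new();
+//     handler.set_owner().unwrap();
+//     handler.set_features(VIRTIO_FEATURES).unwrap();
+//     assert!(handler.features_acked);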
+
+impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
+ fn set_owner(&mut self) -> Result<()> {
+ if self.owned {
+ return Err(Error::InvalidOperation);
+ }
+ self.owned = true;
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> Result<()> {
+ self.owned = false;
+ self.features_acked = false;
+ self.acked_features = 0;
+ self.acked_protocol_features = 0;
+ Ok(())
+ }
+
+ fn get_features(&mut self) -> Result<u64> {
+ Ok(VIRTIO_FEATURES)
+ }
+
+ fn set_features(&mut self, features: u64) -> Result<()> {
+ if !self.owned || self.features_acked {
+ return Err(Error::InvalidOperation);
+ } else if (features & !VIRTIO_FEATURES) != 0 {
+ return Err(Error::InvalidParam);
+ }
+
+ self.acked_features = features;
+ self.features_acked = true;
+
+ // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated,
+ // the ring is initialized in an enabled state.
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated,
+ // the ring is initialized in a disabled state. Client must not
+ // pass data to/from the backend until ring is enabled by
+ // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has
+ // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ let vring_enabled =
+ self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
+ for enabled in &mut self.vring_enabled {
+ *enabled = vring_enabled;
+ }
+
+ Ok(())
+ }
+
+ fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
+ if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_num[index as usize] = num;
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ _flags: VhostUserVringAddrFlags,
+ _descriptor: u64,
+ _used: u64,
+ _available: u64,
+ _log: u64,
+ ) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ Ok(())
+ }
+
+ fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
+ if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_base[index as usize] = base;
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ self.vring_started[index as usize] = false;
+ Ok(VhostUserVringState::new(
+ index,
+ self.vring_base[index as usize],
+ ))
+ }
+
+ fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.kick_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) };
+ }
+ self.kick_fd[index as usize] = fd;
+
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ //
+ // So we should add fd to event monitor(select, poll, epoll) here.
+ self.vring_started[index as usize] = true;
+ Ok(())
+ }
+
+ fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.call_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) };
+ }
+ self.call_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.err_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) };
+ }
+ self.err_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ Ok(VhostUserProtocolFeatures::all())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> Result<()> {
+ // Note: a slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ // What happens if the master calls set_features() with
+ // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
+ // interface?
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ Ok(MAX_QUEUE_NUM as u64)
+ }
+
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
+ // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
+ // has been negotiated.
+ if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+
+ // Slave must not pass data to/from the backend until ring is
+ // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
+ // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
+ // with parameter 0.
+ self.vring_enabled[index as usize] = enable;
+ Ok(())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ _flags: VhostUserConfigFlags,
+ ) -> Result<Vec<u8>> {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ Ok(vec![0xa5; size as usize])
+ }
+
+ fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
+ let size = buf.len() as u32;
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ Ok(())
+ }
+
+ fn get_max_mem_slots(&mut self) -> Result<u64> {
+ Ok(MAX_MEM_SLOTS as u64)
+ }
+
+ fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> {
+ Ok(())
+ }
+
+ fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> {
+ Ok(())
+ }
+}
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
new file mode 100644
index 0000000..cc79871
--- /dev/null
+++ b/src/vhost_user/master.rs
@@ -0,0 +1,1071 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Struct for vhost-user master.
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use sys_util::EventFd;
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error as VhostUserError, Result as VhostUserResult};
+use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+use crate::{Error, Result};
+
+/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet.
+pub trait VhostUserMaster: VhostBackend {
+ /// Get the protocol feature bitmask from the underlying vhost implementation.
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
+
+ /// Enable protocol features in the underlying vhost implementation.
+ fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>;
+
+ /// Query how many queues the backend supports.
+ fn get_queue_num(&mut self) -> Result<u64>;
+
+ /// Signal slave to enable or disable corresponding vring.
+ ///
+ /// Slave must not pass data to/from the backend until ring is enabled by
+ /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been
+ /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>;
+
+ /// Fetch the contents of the virtio device configuration space.
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ buf: &[u8],
+ ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>;
+
+ /// Change the virtio device configuration space. It can also be used on the destination host
+ /// during live migration to set read-only configuration space fields.
+ fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
+
+ /// Set up the slave communication channel.
+ fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>;
+
+ /// Query the maximum amount of memory slots supported by the backend.
+ fn get_max_mem_slots(&mut self) -> Result<u64>;
+
+ /// Add a new guest memory mapping for vhost to use.
+ fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+
+ /// Remove a guest memory mapping from vhost.
+ fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+}
+
+fn error_code<T>(err: VhostUserError) -> Result<T> {
+ Err(Error::VhostUserProtocol(err))
+}
+
+/// Struct for the vhost-user master endpoint.
+#[derive(Clone)]
+pub struct Master {
+ node: Arc<Mutex<MasterInternal>>,
+}
+
+impl Master {
+ /// Create a new instance.
+ fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self {
+ Master {
+ node: Arc::new(Mutex::new(MasterInternal {
+ main_sock: ep,
+ virtio_features: 0,
+ acked_virtio_features: 0,
+ protocol_features: 0,
+ acked_protocol_features: 0,
+ protocol_features_ready: false,
+ max_queue_num,
+ error: None,
+ })),
+ }
+ }
+
+ fn node(&self) -> MutexGuard<MasterInternal> {
+ self.node.lock().unwrap()
+ }
+
+ /// Create a new instance from a Unix stream socket.
+ pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
+ Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
+ }
+
+ /// Create a new vhost-user master endpoint.
+ ///
+ /// Will retry as the backend may not be ready to accept the connection.
+ ///
+ /// # Arguments
+ /// * `path` - path of Unix domain socket listener to connect to
+ pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> {
+ let mut retry_count = 5;
+ let endpoint = loop {
+ match Endpoint::<MasterReq>::connect(&path) {
+ Ok(endpoint) => break Ok(endpoint),
+ Err(e) => match &e {
+ VhostUserError::SocketConnect(why) => {
+ if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 {
+ std::thread::sleep(std::time::Duration::from_millis(100));
+ retry_count -= 1;
+ continue;
+ } else {
+ break Err(e);
+ }
+ }
+ _ => break Err(e),
+ },
+ }
+ }?;
+
+ Ok(Self::new(endpoint, max_queue_num))
+ }
+}
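+
+// A typical master bootstrap sequence (illustrative sketch; the socket path
+// and queue count below are assumptions):
+//
+//     let mut master = Master::connect("/tmp/vhost-user.sock", 2)?;
+//     master.set_owner()?;
+//     let features = master.get_features()?;
+//     master.set_features(features)?;
+//     let protocol = master.get_protocol_features()?;
+//     master.set_protocol_features(protocol)?;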
+
+impl VhostBackend for Master {
+ /// Get the feature bitmask from the underlying vhost implementation.
+ fn get_features(&self) -> Result<u64> {
+ let mut node = self.node();
+ let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ node.virtio_features = val.value;
+ Ok(node.virtio_features)
+ }
+
+ /// Enable features in the underlying vhost implementation using a bitmask.
+ fn set_features(&self, features: u64) -> Result<()> {
+ let mut node = self.node();
+ let val = VhostUserU64::new(features);
+ let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
+ // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+ // completed yet.
+ node.acked_virtio_features = features & node.virtio_features;
+ Ok(())
+ }
+
+ /// Set the current Master as an owner of the session.
+ fn set_owner(&self) -> Result<()> {
+ // We unwrap() the return value to assert that we are not expecting threads to ever fail
+ // while holding the lock.
+ let mut node = self.node();
+ let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
+ // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+ // completed yet.
+ Ok(())
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ let mut node = self.node();
+ let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
+ // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+ // completed yet.
+ Ok(())
+ }
+
+ /// Set the memory map regions on the slave so it can translate the vring
+ /// addresses. The ancillary data carries an array of file descriptors, one per memory region.
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut ctx = VhostUserMemoryContext::new();
+ for region in regions.iter() {
+ if region.memory_size == 0 || region.mmap_handle < 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let reg = VhostUserMemoryRegion {
+ guest_phys_addr: region.guest_phys_addr,
+ memory_size: region.memory_size,
+ user_addr: region.userspace_addr,
+ mmap_offset: region.mmap_offset,
+ };
+ ctx.append(&reg, region.mmap_handle);
+ }
+
+ let mut node = self.node();
+ let body = VhostUserMemory::new(ctx.regions.len() as u32);
+ let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
+ let hdr = node.send_request_with_payload(
+ MasterReq::SET_MEM_TABLE,
+ &body,
+ payload,
+ Some(ctx.fds.as_slice()),
+ )?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
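+
+ // Sketch of the expected input (illustrative only; the addresses, sizes and
+ // the `mmap_addr`/`mmap_fd` values below are assumptions):
+ //
+ //     let region = VhostUserMemoryRegionInfo {
+ //         guest_phys_addr: 0x1_0000,
+ //         memory_size: 0x10_0000,
+ //         userspace_addr: mmap_addr,
+ //         mmap_offset: 0,
+ //         mmap_handle: mmap_fd,
+ //     };
+ //     master.set_mem_table(&[region])?;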
+
+ // Clippy doesn't seem to know that if let with && is still experimental
+ #[allow(clippy::unnecessary_unwrap)]
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ let mut node = self.node();
+ let val = VhostUserU64::new(base);
+
+ if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
+ && fd.is_some()
+ {
+ let fds = [fd.unwrap()];
+ let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?;
+ } else {
+ let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?;
+ }
+ Ok(())
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
+ let fds = [fd];
+ node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
+ Ok(())
+ }
+
+ /// Set the size of the queue.
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringState::new(queue_index as u32, num.into());
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Sets the addresses of the different aspects of the vring.
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num
+ || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
+ {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ /// Sets the base offset in the available vring.
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let val = VhostUserVringState::new(queue_index as u32, base.into());
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let req = VhostUserVringState::new(queue_index as u32, 0);
+ let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?;
+ let reply = node.recv_reply::<VhostUserVringState>(&hdr)?;
+ Ok(reply.num)
+ }
+
+ /// Set the event file descriptor to signal when buffers are used.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data. This signals that polling
+ /// will be used instead of waiting for the call.
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
+ Ok(())
+ }
+
+ /// Set the event file descriptor for adding buffers to the vring.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data. This signals that polling
+ /// should be used instead of waiting for a kick.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
+ Ok(())
+ }
+
+ /// Set the event file descriptor to signal when error occurs.
+ /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+ /// is set when there is no file descriptor in the ancillary data.
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
+ if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
+ Ok(())
+ }
+}
+
+impl VhostUserMaster for Master {
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ let mut node = self.node();
+ let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+ let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ node.protocol_features = val.value;
+ // Should we support forward compatibility?
+ // If so, just mask out unrecognized flags instead of returning errors.
+ match VhostUserProtocolFeatures::from_bits(node.protocol_features) {
+ Some(val) => Ok(val),
+ None => error_code(VhostUserError::InvalidMessage),
+ }
+ }
+
+ fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
+ let mut node = self.node();
+ let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+ let val = VhostUserU64::new(features.bits());
+ let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
+ // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+ // completed yet.
+ node.acked_protocol_features = features.bits();
+ node.protocol_features_ready = true;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ let mut node = self.node();
+ if !node.is_feature_mq_available() {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+ if val.value > VHOST_USER_MAX_VRINGS {
+ return error_code(VhostUserError::InvalidMessage);
+ }
+ node.max_queue_num = val.value;
+ Ok(node.max_queue_num)
+ }
+
+ fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
+ let mut node = self.node();
+ // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
+ if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ } else if queue_index as u64 >= node.max_queue_num {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let flag = if enable { 1 } else { 0 };
+ let val = VhostUserVringState::new(queue_index as u32, flag);
+ let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ buf: &[u8],
+ ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
+ let body = VhostUserConfig::new(offset, size, flags);
+ if !body.is_valid() {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut node = self.node();
+ // depends on VhostUserProtocolFeatures::CONFIG
+ if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ // vhost-user spec states that:
+ // "Master payload: virtio device config space"
+ // "Slave payload: virtio device config space"
+ let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?;
+ let (body_reply, buf_reply, rfds) =
+ node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
+ if rfds.is_some() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return error_code(VhostUserError::InvalidMessage);
+ } else if body_reply.size == 0 {
+ return error_code(VhostUserError::SlaveInternalError);
+ } else if body_reply.size != body.size
+ || body_reply.size as usize != buf.len()
+ || body_reply.offset != body.offset
+ {
+ return error_code(VhostUserError::InvalidMessage);
+ }
+
+ Ok((body_reply, buf_reply))
+ }
+
+ fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
+ if buf.len() > MAX_MSG_SIZE {
+ return error_code(VhostUserError::InvalidParam);
+ }
+ let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
+ if !body.is_valid() {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let mut node = self.node();
+ // depends on VhostUserProtocolFeatures::CONFIG
+ if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ let fds = [fd];
+ node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
+ Ok(())
+ }
+
+ fn get_max_mem_slots(&mut self) -> Result<u64> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+ {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
+ let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+
+ Ok(val.value)
+ }
+
+ fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+ {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+ if region.memory_size == 0 || region.mmap_handle < 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let body = VhostUserSingleMemoryRegion::new(
+ region.guest_phys_addr,
+ region.memory_size,
+ region.userspace_addr,
+ region.mmap_offset,
+ );
+ let fds = [region.mmap_handle];
+ let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+
+ fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+ {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+ if region.memory_size == 0 {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let body = VhostUserSingleMemoryRegion::new(
+ region.guest_phys_addr,
+ region.memory_size,
+ region.userspace_addr,
+ region.mmap_offset,
+ );
+ let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
+}
+
+impl AsRawFd for Master {
+ fn as_raw_fd(&self) -> RawFd {
+ let node = self.node();
+ node.main_sock.as_raw_fd()
+ }
+}
+
+/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
+struct VhostUserMemoryContext {
+ regions: VhostUserMemoryPayload,
+ fds: Vec<RawFd>,
+}
+
+impl VhostUserMemoryContext {
+ /// Create a context object.
+ pub fn new() -> Self {
+ VhostUserMemoryContext {
+ regions: VhostUserMemoryPayload::new(),
+ fds: Vec::new(),
+ }
+ }
+
+ /// Append a user memory region and corresponding RawFd into the context object.
+ pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
+ self.regions.push(*region);
+ self.fds.push(fd);
+ }
+}
+
+struct MasterInternal {
+ // Used to send requests to the slave.
+ main_sock: Endpoint<MasterReq>,
+ // Cached virtio features from the slave.
+ virtio_features: u64,
+ // Cached acked virtio features from the driver.
+ acked_virtio_features: u64,
+ // Cached vhost-user protocol features from the slave.
+ protocol_features: u64,
+ // Cached acked vhost-user protocol features.
+ acked_protocol_features: u64,
+ // Cached vhost-user protocol features are ready to use.
+ protocol_features_ready: bool,
+ // Cached maximum number of queues supported by the slave.
+ max_queue_num: u64,
+ // Internal flag to mark failure state.
+ error: Option<i32>,
+}
+
+impl MasterInternal {
+ fn send_request_header(
+ &mut self,
+ code: MasterReq,
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ self.check_state()?;
+ let hdr = Self::new_request_header(code, 0);
+ self.main_sock.send_header(&hdr, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_request_with_body<T: Sized>(
+ &mut self,
+ code: MasterReq,
+ msg: &T,
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32);
+ self.main_sock.send_message(&hdr, msg, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_request_with_payload<T: Sized>(
+ &mut self,
+ code: MasterReq,
+ msg: &T,
+ payload: &[u8],
+ fds: Option<&[RawFd]>,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ let len = mem::size_of::<T>() + payload.len();
+ if len > MAX_MSG_SIZE {
+ return Err(VhostUserError::InvalidParam);
+ }
+ if let Some(ref fd_arr) = fds {
+ if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+ return Err(VhostUserError::InvalidParam);
+ }
+ }
+ self.check_state()?;
+
+ let hdr = Self::new_request_header(code, len as u32);
+ self.main_sock
+ .send_message_with_payload(&hdr, msg, payload, fds)?;
+ Ok(hdr)
+ }
+
+ fn send_fd_for_vring(
+ &mut self,
+ code: MasterReq,
+ queue_index: usize,
+ fd: RawFd,
+ ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+ if queue_index as u64 >= self.max_queue_num {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
+ // This flag is set when there is no file descriptor in the ancillary data. This signals
+ // that polling will be used instead of waiting for the call.
+ let msg = VhostUserU64::new(queue_index as u64);
+ let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
+ self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
+ Ok(hdr)
+ }
+
+ fn recv_reply<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<T> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
+ if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(VhostUserError::InvalidMessage);
+ }
+ Ok(body)
+ }
+
+ fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || hdr.get_size() as usize <= mem::size_of::<T>()
+ || hdr.get_size() as usize > MAX_MSG_SIZE
+ || hdr.is_reply()
+ {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
+ let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
+ if !reply.is_reply_for(hdr)
+ || reply.get_size() as usize != mem::size_of::<T>() + bytes
+ || rfds.is_some()
+ || !body.is_valid()
+ {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(VhostUserError::InvalidMessage);
+ } else if bytes != buf.len() {
+ return Err(VhostUserError::InvalidMessage);
+ }
+ Ok((body, buf, rfds))
+ }
+
+ fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
+ || !hdr.is_need_reply()
+ {
+ return Ok(());
+ }
+ self.check_state()?;
+
+ let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
+ if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(VhostUserError::InvalidMessage);
+ }
+ if body.value != 0 {
+ return Err(VhostUserError::SlaveInternalError);
+ }
+ Ok(())
+ }
+
+ fn is_feature_mq_available(&self) -> bool {
+ self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0
+ }
+
+ fn check_state(&self) -> VhostUserResult<()> {
+ match self.error {
+ Some(e) => Err(VhostUserError::SocketBroken(
+ std::io::Error::from_raw_os_error(e),
+ )),
+ None => Ok(()),
+ }
+ }
+
+ #[inline]
+ fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
+ // TODO: handle NEED_REPLY flag
+ VhostUserMsgHeader::new(request, 0x1, size)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::super::connection::Listener;
+ use super::*;
+ use tempfile::{Builder, TempDir};
+
+ fn temp_dir() -> TempDir {
+ Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+ }
+
+ fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) {
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let master = Master::connect(path, 2).unwrap();
+ let slave = listener.accept().unwrap().unwrap();
+ (master, Endpoint::from_stream(slave))
+ }
+
+ #[test]
+ fn create_master() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+
+ let master = Master::connect(&path, 1).unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
+
+ assert!(master.as_raw_fd() > 0);
+ // Send two messages continuously
+ master.set_owner().unwrap();
+ master.reset_owner().unwrap();
+
+ let (hdr, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+
+ let (hdr, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+ }
+
+ #[test]
+ fn test_create_failure() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let _ = Listener::new(&path, true).unwrap();
+ let _ = Listener::new(&path, false).is_err();
+ assert!(Master::connect(&path, 1).is_err());
+
+ let listener = Listener::new(&path, true).unwrap();
+ assert!(Listener::new(&path, false).is_err());
+ listener.set_nonblocking(true).unwrap();
+
+ let _master = Master::connect(&path, 1).unwrap();
+ let _slave = listener.accept().unwrap().unwrap();
+ }
+
+ #[test]
+ fn test_features() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (master, mut peer) = create_pair(&path);
+
+ master.set_owner().unwrap();
+ let (hdr, rfds) = peer.recv_header().unwrap();
+ assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+ assert_eq!(hdr.get_size(), 0);
+ assert_eq!(hdr.get_version(), 0x1);
+ assert!(rfds.is_none());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_features().unwrap();
+ assert_eq!(features, 0x15u64);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ master.set_features(0x15).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, 0x15);
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = 0x15u32;
+ peer.send_message(&hdr, &msg, None).unwrap();
+ assert!(master.get_features().is_err());
+ }
+
+ #[test]
+ fn test_protocol_features() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (mut master, mut peer) = create_pair(&path);
+
+ master.set_owner().unwrap();
+ let (hdr, rfds) = peer.recv_header().unwrap();
+ assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+ assert!(rfds.is_none());
+
+ assert!(master.get_protocol_features().is_err());
+ assert!(master
+ .set_protocol_features(VhostUserProtocolFeatures::all())
+ .is_err());
+
+ let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(vfeatures);
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_features().unwrap();
+ assert_eq!(features, vfeatures);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ master.set_features(vfeatures).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, vfeatures);
+
+ let pfeatures = VhostUserProtocolFeatures::all();
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(pfeatures.bits());
+ peer.send_message(&hdr, &msg, None).unwrap();
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features, pfeatures);
+ let (_hdr, rfds) = peer.recv_header().unwrap();
+ assert!(rfds.is_none());
+
+ master.set_protocol_features(pfeatures).unwrap();
+ let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+ assert!(rfds.is_none());
+ let val = msg.value;
+ assert_eq!(val, pfeatures.bits());
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(pfeatures.bits());
+ peer.send_message(&hdr, &msg, None).unwrap();
+ assert!(master.get_protocol_features().is_err());
+ }
+
+ #[test]
+ fn test_master_set_config_negative() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (mut master, _peer) = create_pair(&path);
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap();
+ master
+ .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(
+ 0x100,
+ unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
+ &buf[0..4],
+ )
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf)
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[])
+ .unwrap_err();
+ }
+
+ fn create_pair2() -> (Master, Endpoint<MasterReq>) {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (master, peer) = create_pair(&path);
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ (master, peer)
+ }
+
+ #[test]
+ fn test_master_get_config_negative0() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_code(MasterReq::GET_FEATURES);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ hdr.set_code(MasterReq::GET_CONFIG);
+ }
+
+ #[test]
+ fn test_master_get_config_negative1() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_reply(false);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative2() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+ }
+
+ #[test]
+ fn test_master_get_config_negative3() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative4() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0x101;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative5() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = (MAX_MSG_SIZE + 1) as u32;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative6() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.size = 6;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_set_mem_table_failure() {
+ let (master, _peer) = create_pair2();
+
+ master.set_mem_table(&[]).unwrap_err();
+ let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
+ master.set_mem_table(&tables).unwrap_err();
+ }
+}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
new file mode 100644
index 0000000..8cba188
--- /dev/null
+++ b/src/vhost_user/master_req_handler.rs
@@ -0,0 +1,477 @@
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result};
+
+/// Define services provided by masters for the slave communication channel.
+///
+/// The vhost-user specification defines a slave communication channel, through which slaves can
+/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
+/// by masters, and it's used on both the master side and the slave side.
+/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
+/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
+/// - on the master side, the [MasterReqHandler] will forward service requests to a handler
+/// implementing [VhostUserMasterReqHandler].
+///
+/// The [VhostUserMasterReqHandler] trait is designed with interior mutability to improve
+/// performance for multi-threading.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
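+///
+/// # Example
+///
+/// A minimal sketch (illustrative only, not part of the crate's API surface): implement the
+/// mutable mirror trait [VhostUserMasterReqHandlerMut] and wrap the handler in a `Mutex` to
+/// obtain a [VhostUserMasterReqHandler] through the blanket implementation below.
+///
+/// ```ignore
+/// use std::sync::Mutex;
+///
+/// // `MyBackend` is a hypothetical handler type.
+/// struct MyBackend;
+///
+/// impl VhostUserMasterReqHandlerMut for MyBackend {
+///     fn handle_config_change(&mut self) -> HandlerResult<u64> {
+///         // React to the configuration change notification here.
+///         Ok(0)
+///     }
+/// }
+///
+/// // `Mutex<MyBackend>` now implements VhostUserMasterReqHandler.
+/// let backend = Mutex::new(MyBackend);
+/// ```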
+pub trait VhostUserMasterReqHandler {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub trait VhostUserMasterReqHandlerMut {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&mut self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ self.lock().unwrap().handle_config_change()
+ }
+
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_map(fs, fd)
+ }
+
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_unmap(fs)
+ }
+
+ fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_sync(fs)
+ }
+
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_io(fs, fd)
+ }
+}
+
+/// Server to handle service requests from slaves over the slave communication channel.
+///
+/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
+/// slaves on the slave communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserMasterReqHandler] to do the real work.
+///
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
+ // underlying Unix domain socket for communication
+ sub_sock: Endpoint<SlaveReq>,
+ tx_sock: UnixStream,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
+ // the VirtIO backend device object
+ backend: Arc<S>,
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
+ /// Create a server to handle service requests from slaves on the slave communication channel.
+ ///
+ /// This opens a pair of connected anonymous sockets to form the slave communication channel.
+ /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
+ /// [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
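+ ///
+ /// # Example
+ ///
+ /// An illustrative sketch; `MyBackend` is a hypothetical type implementing
+ /// [VhostUserMasterReqHandlerMut](trait.VhostUserMasterReqHandlerMut.html).
+ ///
+ /// ```ignore
+ /// use std::sync::{Arc, Mutex};
+ ///
+ /// let backend = Arc::new(Mutex::new(MyBackend {}));
+ /// let handler = MasterReqHandler::new(backend).unwrap();
+ /// // Send this fd to the slave via VhostUserMaster::set_slave_request_fd().
+ /// let tx_fd = handler.get_tx_raw_fd();
+ /// ```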
+ pub fn new(backend: Arc<S>) -> Result<Self> {
+ let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
+
+ Ok(MasterReqHandler {
+ sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
+ tx_sock: tx,
+ reply_ack_negotiated: false,
+ backend,
+ error: None,
+ })
+ }
+
+ /// Get the socket fd for the slave to communicate with the master.
+ ///
+ /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn get_tx_raw_fd(&self) -> RawFd {
+ self.tx_sock.as_raw_fd()
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave-to-master
+ /// request message.
+ pub fn set_reply_ack_flag(&mut self, enable: bool) {
+ self.reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed or in normal state.
+ pub fn set_failed(&mut self, error: i32) {
+ if error == 0 {
+ self.error = None;
+ } else {
+ self.error = Some(error);
+ }
+ }
+
+ /// Main entry point to serve requests from the slave communication channel.
+ ///
+ /// The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when an error happens
+ /// - optionally recover from failure
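+ ///
+ /// A hypothetical serving loop (illustrative only; `handler` is assumed to be a
+ /// `MasterReqHandler` created as in the example above):
+ ///
+ /// ```ignore
+ /// loop {
+ ///     match handler.handle_request() {
+ ///         Ok(_) => continue,
+ ///         // Decide whether to recover (e.g. via set_failed(0)) or tear down.
+ ///         Err(_) => break,
+ ///     }
+ /// }
+ /// ```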
+ pub fn handle_request(&mut self) -> Result<u64> {
+ // Return error if the endpoint is already in failed state.
+ self.check_state()?;
+
+ // The underlying communication channel is a Unix domain socket in
+ // stream mode, and recvmsg() is a little tricky here. To successfully
+ // receive attached file descriptors, we need to receive messages and
+ // corresponding attached file descriptors in this way:
+ // . recv message header and optional attached file
+ // . validate message header
+ // . recv optional message body and payload according to the size field
+ // in the message header
+ // . validate message body and optional payload
+ let (hdr, rfds) = self.sub_sock.recv_header()?;
+ let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (size, buf) = match hdr.get_size() {
+ 0 => (0, vec![0u8; 0]),
+ len => {
+ if len as usize > MAX_MSG_SIZE {
+ return Err(Error::InvalidMessage);
+ }
+ let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
+ if size2 != len as usize {
+ return Err(Error::InvalidMessage);
+ }
+ (size2, rbuf)
+ }
+ };
+
+ let res = match hdr.get_code() {
+ SlaveReq::CONFIG_CHANGE_MSG => {
+ self.check_msg_size(&hdr, size, 0)?;
+ self.backend
+ .handle_config_change()
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_MAP => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
+ self.backend
+ .fs_slave_map(&msg, rfds.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_UNMAP => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_unmap(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_SYNC => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_sync(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_IO => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
+ self.backend
+ .fs_slave_io(&msg, rfds.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ _ => Err(Error::InvalidMessage),
+ };
+
+ self.send_ack_message(&hdr, &res)?;
+
+ res
+ }
+
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
+ fn check_msg_size(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ expected: usize,
+ ) -> Result<()> {
+ if hdr.get_size() as usize != expected
+ || hdr.is_reply()
+ || hdr.get_version() != 0x1
+ || size != expected
+ {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(())
+ }
+
+ fn check_attached_rfds(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<Option<Vec<RawFd>>> {
+ match hdr.get_code() {
+ SlaveReq::FS_MAP | SlaveReq::FS_IO => {
+ // Expect an fd set with a single fd.
+ match rfds {
+ None => Err(Error::InvalidMessage),
+ Some(fds) => {
+ if fds.len() != 1 {
+ Endpoint::<SlaveReq>::close_rfds(Some(fds));
+ Err(Error::InvalidMessage)
+ } else {
+ Ok(Some(fds))
+ }
+ }
+ }
+ }
+ _ => {
+ if rfds.is_some() {
+ Endpoint::<SlaveReq>::close_rfds(rfds);
+ Err(Error::InvalidMessage)
+ } else {
+ Ok(rfds)
+ }
+ }
+ }
+ }
+
+ fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<T> {
+ self.check_msg_size(hdr, size, mem::size_of::<T>())?;
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(msg)
+ }
+
+ fn new_reply_header<T: Sized>(
+ &self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ ) -> Result<VhostUserMsgHeader<SlaveReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::InvalidParam);
+ }
+ self.check_state()?;
+ Ok(VhostUserMsgHeader::new(
+ req.get_code(),
+ VhostUserHeaderFlag::REPLY.bits(),
+ mem::size_of::<T>() as u32,
+ ))
+ }
+
+ fn send_ack_message(
+ &mut self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ res: &Result<u64>,
+ ) -> Result<()> {
+ if self.reply_ack_negotiated && req.is_need_reply() {
+ let hdr = self.new_reply_header::<VhostUserU64>(req)?;
+ let def_err = libc::EINVAL;
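+ // Reply with the handler's value on success, or a negated errno encoded as u64 on failure.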
+ let val = match res {
+ Ok(n) => *n,
+ Err(e) => match &*e {
+ Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
+ Some(rawerr) => -rawerr as u64,
+ None => -def_err as u64,
+ },
+ _ => -def_err as u64,
+ },
+ };
+ let msg = VhostUserU64::new(val);
+ self.sub_sock.send_message(&hdr, &msg, None)?;
+ }
+ Ok(())
+ }
+}
+
+impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sub_sock.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "vhost-user-slave")]
+ use crate::vhost_user::SlaveFsCacheReq;
+ #[cfg(feature = "vhost-user-slave")]
+ use std::os::unix::io::FromRawFd;
+
+ struct MockMasterReqHandler {}
+
+ impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
+ /// Handle virtio-fs map file requests from the slave.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Ok(0)
+ }
+
+ /// Handle virtio-fs unmap file requests from the slave.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+ }
+
+ #[test]
+ fn test_new_master_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ assert!(handler.get_tx_raw_fd() >= 0);
+ assert!(handler.as_raw_fd() >= 0);
+ handler.check_state().unwrap();
+
+ assert_eq!(handler.error, None);
+ handler.set_failed(libc::EAGAIN);
+ assert_eq!(handler.error, Some(libc::EAGAIN));
+ handler.check_state().unwrap_err();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicate tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
+ // slave side.
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler_with_ack() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+ handler.set_reply_ack_flag(true);
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicate tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+}
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
new file mode 100644
index 0000000..ea2df4e
--- /dev/null
+++ b/src/vhost_user/message.rs
@@ -0,0 +1,1042 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Define communication messages for the vhost-user protocol.
+//!
+//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt).
+
+#![allow(dead_code)]
+#![allow(non_camel_case_types)]
+
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+use crate::VringConfigData;
+
+/// The vhost-user specification uses a field of u32 to store message length.
+/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain
+/// socket, and preallocating a 4GB buffer for each vhost-user message would be pure overhead.
+/// Among all defined vhost-user messages, only VhostUserConfig and VhostUserMemory have variable
+/// message sizes. For the VhostUserConfig, a maximum size of 4K is enough because the user
+/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory,
+/// 4K should be enough too because it can support 255 memory regions at most.
+pub const MAX_MSG_SIZE: usize = 0x1000;
+
+/// The VhostUserMemory message has a variable message size and a variable number of attached file
+/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, so the
+/// maximum number of attached file descriptors could be derived from the maximum message size.
+/// But Rust only implements the Default and AsMut traits for arrays with 0 to 32 entries, so the
+/// maximum number is further reduced to 32.
+// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32;
+pub const MAX_ATTACHED_FD_ENTRIES: usize = 32;
+
+/// Starting position (inclusive) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100;
+
+/// Ending position (exclusive) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000;
+
+/// Maximum number of vrings supported.
+pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64;
+
+pub(super) trait Req:
+ Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32>
+{
+ fn is_valid(&self) -> bool;
+}
+
+/// Type of requests sent from masters to slaves.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum MasterReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Get from the underlying vhost implementation the features bit mask.
+ GET_FEATURES = 1,
+ /// Enable features in the underlying vhost implementation using a bit mask.
+ SET_FEATURES = 2,
+ /// Set the current Master as an owner of the session.
+ SET_OWNER = 3,
+ /// No longer used.
+ RESET_OWNER = 4,
+ /// Set the memory map regions on the slave so it can translate the vring addresses.
+ SET_MEM_TABLE = 5,
+ /// Set logging shared memory space.
+ SET_LOG_BASE = 6,
+ /// Set the logging file descriptor, which is passed as ancillary data.
+ SET_LOG_FD = 7,
+ /// Set the size of the queue.
+ SET_VRING_NUM = 8,
+ /// Set the addresses of the different aspects of the vring.
+ SET_VRING_ADDR = 9,
+ /// Set the base offset in the available vring.
+ SET_VRING_BASE = 10,
+ /// Get the available vring base offset.
+ GET_VRING_BASE = 11,
+ /// Set the event file descriptor for adding buffers to the vring.
+ SET_VRING_KICK = 12,
+ /// Set the event file descriptor to signal when buffers are used.
+ SET_VRING_CALL = 13,
+ /// Set the event file descriptor to signal when error occurs.
+ SET_VRING_ERR = 14,
+ /// Get the protocol feature bit mask from the underlying vhost implementation.
+ GET_PROTOCOL_FEATURES = 15,
+ /// Enable protocol features in the underlying vhost implementation.
+ SET_PROTOCOL_FEATURES = 16,
+ /// Query how many queues the backend supports.
+ GET_QUEUE_NUM = 17,
+ /// Signal slave to enable or disable corresponding vring.
+ SET_VRING_ENABLE = 18,
+ /// Ask the vhost-user backend to broadcast a fake RARP to notify that migration has
+ /// terminated, for guests that do not support GUEST_ANNOUNCE.
+ SEND_RARP = 19,
+ /// Set host MTU value exposed to the guest.
+ NET_SET_MTU = 20,
+ /// Set the socket file descriptor for slave initiated requests.
+ SET_SLAVE_REQ_FD = 21,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 22,
+ /// Set the endianness of a VQ for legacy devices.
+ SET_VRING_ENDIAN = 23,
+ /// Fetch the contents of the virtio device configuration space.
+ GET_CONFIG = 24,
+ /// Change the contents of the virtio device configuration space.
+ SET_CONFIG = 25,
+ /// Create a session for crypto operation.
+ CREATE_CRYPTO_SESSION = 26,
+ /// Close a session for crypto operation.
+ CLOSE_CRYPTO_SESSION = 27,
+ /// Advise slave that a migration with postcopy enabled is underway.
+ POSTCOPY_ADVISE = 28,
+ /// Advise slave that a transition to postcopy mode has happened.
+ POSTCOPY_LISTEN = 29,
+ /// Advise that postcopy migration has now completed.
+ POSTCOPY_END = 30,
+ /// Get a shared buffer from slave.
+ GET_INFLIGHT_FD = 31,
+ /// Send the shared inflight buffer back to slave.
+ SET_INFLIGHT_FD = 32,
+ /// Sets the GPU protocol socket file descriptor.
+ GPU_SET_SOCKET = 33,
+ /// Ask the vhost user backend to disable all rings and reset all internal
+ /// device state to the initial state.
+ RESET_DEVICE = 34,
+ /// Indicate that a buffer was added to the vring instead of signalling it
+ /// using the vring's kick file descriptor.
+ VRING_KICK = 35,
+ /// Return a u64 payload containing the maximum number of memory slots.
+ GET_MAX_MEM_SLOTS = 36,
+ /// Update the memory tables by adding the region described.
+ ADD_MEM_REG = 37,
+ /// Update the memory tables by removing the region described.
+ REM_MEM_REG = 38,
+ /// Notify the backend with updated device status as defined in the VIRTIO
+ /// specification.
+ SET_STATUS = 39,
+ /// Query the backend for its device status as defined in the VIRTIO
+ /// specification.
+ GET_STATUS = 40,
+ /// Upper bound of valid commands.
+ MAX_CMD = 41,
+}
+
+impl From<MasterReq> for u32 {
+ fn from(req: MasterReq) -> u32 {
+ req as u32
+ }
+}
+
+impl Req for MasterReq {
+ fn is_valid(&self) -> bool {
+ (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD)
+ }
+}
+
+/// Type of requests sent from slaves to masters.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum SlaveReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 1,
+ /// Notify that the virtio device's configuration space has changed.
+ CONFIG_CHANGE_MSG = 2,
+ /// Set host notifier for a specified queue.
+ VRING_HOST_NOTIFIER_MSG = 3,
+ /// Indicate that a buffer was used from the vring.
+ VRING_CALL = 4,
+ /// Indicate that an error occurred on the specific vring.
+ VRING_ERR = 5,
+ /// Virtio-fs draft: map file content into the window.
+ FS_MAP = 6,
+ /// Virtio-fs draft: unmap file content from the window.
+ FS_UNMAP = 7,
+ /// Virtio-fs draft: sync file content.
+ FS_SYNC = 8,
+ /// Virtio-fs draft: perform a read/write from an fd directly to GPA.
+ FS_IO = 9,
+ /// Upper bound of valid commands.
+ MAX_CMD = 10,
+}
+
+impl From<SlaveReq> for u32 {
+ fn from(req: SlaveReq) -> u32 {
+ req as u32
+ }
+}
+
+impl Req for SlaveReq {
+ fn is_valid(&self) -> bool {
+ (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD)
+ }
+}
+
+/// Vhost message Validator.
+pub trait VhostUserMsgValidator {
+ /// Validate message syntax only.
+ /// It doesn't validate message semantics, such as the protocol version number or
+ /// dependencies on feature flags.
+ fn is_valid(&self) -> bool {
+ true
+ }
+}
+
+// Bit mask for common message flags.
+bitflags! {
+ /// Common message flags for vhost-user requests and replies.
+ pub struct VhostUserHeaderFlag: u32 {
+ /// Bits 0 and 1 form the message version number.
+ const VERSION = 0x3;
+ /// Mark message as reply.
+ const REPLY = 0x4;
+ /// Sender anticipates a reply message from the peer.
+ const NEED_REPLY = 0x8;
+ /// All valid flag bits other than the version bits.
+ const ALL_FLAGS = 0xc;
+ /// All reserved bits.
+ const RESERVED_BITS = !0xf;
+ }
+}
+
+/// Common message header for vhost-user requests and replies.
+/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
+/// machine native byte order.
+#[allow(safe_packed_borrows)]
+#[repr(packed)]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub(super) struct VhostUserMsgHeader<R: Req> {
+ request: u32,
+ flags: u32,
+ size: u32,
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> VhostUserMsgHeader<R> {
+ /// Create a new instance of `VhostUserMsgHeader`.
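+ ///
+ /// # Example
+ ///
+ /// An illustrative sketch of constructing and manipulating a header:
+ ///
+ /// ```ignore
+ /// let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
+ /// assert_eq!(hdr.get_version(), 0x1); // the version defaults to 1
+ /// hdr.set_reply(true);
+ /// assert!(hdr.is_reply());
+ /// ```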
+ pub fn new(request: R, flags: u32, size: u32) -> Self {
+ // Default to protocol version 1
+ let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1;
+ VhostUserMsgHeader {
+ request: request.into(),
+ flags: fl,
+ size,
+ _r: PhantomData,
+ }
+ }
+
+ /// Get message type.
+ pub fn get_code(&self) -> R {
+ // It's safe because R is marked as repr(u32).
+ unsafe { std::mem::transmute_copy::<u32, R>(&self.request) }
+ }
+
+ /// Set message type.
+ pub fn set_code(&mut self, request: R) {
+ self.request = request.into();
+ }
+
+ /// Get message version number.
+ pub fn get_version(&self) -> u32 {
+ self.flags & 0x3
+ }
+
+ /// Set message version number.
+ pub fn set_version(&mut self, ver: u32) {
+ self.flags &= !0x3;
+ self.flags |= ver & 0x3;
+ }
+
+ /// Check whether it's a reply message.
+ pub fn is_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0
+ }
+
+ /// Mark message as reply.
+ pub fn set_reply(&mut self, is_reply: bool) {
+ if is_reply {
+ self.flags |= VhostUserHeaderFlag::REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::REPLY.bits();
+ }
+ }
+
+ /// Check whether reply for this message is requested.
+ pub fn is_need_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0
+ }
+
+ /// Mark that reply for this message is needed.
+ pub fn set_need_reply(&mut self, need_reply: bool) {
+ if need_reply {
+ self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits();
+ }
+ }
+
+ /// Check whether it's the reply message for the request `req`.
+ pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool {
+ self.is_reply() && !req.is_reply() && self.get_code() == req.get_code()
+ }
+
+ /// Get message size.
+ pub fn get_size(&self) -> u32 {
+ self.size
+ }
+
+ /// Set message size.
+ pub fn set_size(&mut self, size: u32) {
+ self.size = size;
+ }
+}
+
+impl<R: Req> Default for VhostUserMsgHeader<R> {
+ fn default() -> Self {
+ VhostUserMsgHeader {
+ request: 0,
+ flags: 0x1,
+ size: 0,
+ _r: PhantomData,
+ }
+ }
+}
+
+impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if !self.get_code().is_valid() {
+ return false;
+ } else if self.size as usize > MAX_MSG_SIZE {
+ return false;
+ } else if self.get_version() != 0x1 {
+ return false;
+ } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 {
+ return false;
+ }
+ true
+ }
+}
+
+// Bit mask for transport-specific flags in the VirtIO feature set defined by vhost-user.
+bitflags! {
+ /// Transport-specific flags in the VirtIO feature set defined by vhost-user.
+ pub struct VhostUserVirtioFeatures: u64 {
+ /// Feature flag for the protocol feature.
+ const PROTOCOL_FEATURES = 0x4000_0000;
+ }
+}
+
+// Bit mask for vhost-user protocol feature flags.
+bitflags! {
+ /// Vhost-user protocol feature flags.
+ pub struct VhostUserProtocolFeatures: u64 {
+ /// Support multiple queues.
+ const MQ = 0x0000_0001;
+ /// Support logging through shared memory fd.
+ const LOG_SHMFD = 0x0000_0002;
+ /// Support broadcasting fake RARP packet.
+ const RARP = 0x0000_0004;
+ /// Support sending reply messages for requests with NEED_REPLY flag set.
+ const REPLY_ACK = 0x0000_0008;
+ /// Support setting MTU for virtio-net devices.
+ const MTU = 0x0000_0010;
+ /// Allow the slave to send requests to the master through an optional communication channel.
+ const SLAVE_REQ = 0x0000_0020;
+ /// Support setting slave endian by SET_VRING_ENDIAN.
+ const CROSS_ENDIAN = 0x0000_0040;
+ /// Support crypto operations.
+ const CRYPTO_SESSION = 0x0000_0080;
+ /// Support sending userfault_fd from slaves to masters.
+ const PAGEFAULT = 0x0000_0100;
+ /// Support Virtio device configuration.
+ const CONFIG = 0x0000_0200;
+ /// Allow the slave to send fds (at most 8 descriptors in each message) to the master.
+ const SLAVE_SEND_FD = 0x0000_0400;
+ /// Allow the slave to register a host notifier.
+ const HOST_NOTIFIER = 0x0000_0800;
+ /// Support inflight shmfd.
+ const INFLIGHT_SHMFD = 0x0000_1000;
+ /// Support resetting the device.
+ const RESET_DEVICE = 0x0000_2000;
+ /// Support inband notifications.
+ const INBAND_NOTIFICATIONS = 0x0000_4000;
+ /// Support configuring memory slots.
+ const CONFIGURE_MEM_SLOTS = 0x0000_8000;
+ /// Support reporting status.
+ const STATUS = 0x0001_0000;
+ }
+}
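+
+// Illustrative sketch: negotiated protocol features are typically stored as a raw u64 and
+// tested against these flags with bit operations, e.g.:
+//
+//     let acked = (VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::REPLY_ACK).bits();
+//     assert!(acked & VhostUserProtocolFeatures::REPLY_ACK.bits() != 0);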
+
+/// A generic message to encapsulate a 64-bit value.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserU64 {
+ /// The encapsulated 64-bit common value.
+ pub value: u64,
+}
+
+impl VhostUserU64 {
+ /// Create a new instance.
+ pub fn new(value: u64) -> Self {
+ VhostUserU64 { value }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserU64 {}
+
+/// Memory region descriptor for the SET_MEM_TABLE request.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserMemory {
+ /// Number of memory regions in the payload.
+ pub num_regions: u32,
+ /// Padding for alignment.
+ pub padding1: u32,
+}
+
+impl VhostUserMemory {
+ /// Create a new instance.
+ pub fn new(cnt: u32) -> Self {
+ VhostUserMemory {
+ num_regions: cnt,
+ padding1: 0,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserMemory {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if self.padding1 != 0 {
+ return false;
+ } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 {
+ return false;
+ }
+ true
+ }
+}
+
+/// Memory region descriptors as payload for the SET_MEM_TABLE request.
+#[repr(packed)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserMemoryRegion {
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub user_addr: u64,
+ /// Offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+}
+
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
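+ ///
+ /// # Example
+ ///
+ /// An illustrative sketch: a 4KiB region at guest physical address 0. The validator only
+ /// accepts regions whose address arithmetic does not overflow u64.
+ ///
+ /// ```ignore
+ /// let region = VhostUserMemoryRegion::new(0, 0x1000, 0, 0);
+ /// assert!(region.is_valid());
+ /// ```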
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserMemoryRegion {
+ fn is_valid(&self) -> bool {
+ if self.memory_size == 0
+ || self.guest_phys_addr.checked_add(self.memory_size).is_none()
+ || self.user_addr.checked_add(self.memory_size).is_none()
+ || self.mmap_offset.checked_add(self.memory_size).is_none()
+ {
+ return false;
+ }
+ true
+ }
+}
+
+/// Payload of the VhostUserMemory message.
+pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>;
+
+/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG
+/// requests.
+#[repr(C)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserSingleMemoryRegion {
+ /// Padding for correct alignment
+ padding: u64,
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub user_addr: u64,
+ /// Offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+}
+
+impl VhostUserSingleMemoryRegion {
+ /// Create a new instance.
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserSingleMemoryRegion {
+ padding: 0,
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {
+ fn is_valid(&self) -> bool {
+ if self.memory_size == 0
+ || self.guest_phys_addr.checked_add(self.memory_size).is_none()
+ || self.user_addr.checked_add(self.memory_size).is_none()
+ || self.mmap_offset.checked_add(self.memory_size).is_none()
+ {
+ return false;
+ }
+ true
+ }
+}
+
+/// Vring state descriptor.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserVringState {
+ /// Vring index.
+ pub index: u32,
+ /// A common 32bit value to encapsulate vring state etc.
+ pub num: u32,
+}
+
+impl VhostUserVringState {
+ /// Create a new instance.
+ pub fn new(index: u32, num: u32) -> Self {
+ VhostUserVringState { index, num }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserVringState {}
+
+// Bit mask for vring address flags.
+bitflags! {
+ /// Flags for vring address.
+ pub struct VhostUserVringAddrFlags: u32 {
+ /// Support log of vring operations.
+ /// Modifications to "used" vring should be logged.
+ const VHOST_VRING_F_LOG = 0x1;
+ }
+}
+
+/// Vring address descriptor.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserVringAddr {
+ /// Vring index.
+ pub index: u32,
+ /// Vring flags defined by VhostUserVringAddrFlags.
+ pub flags: u32,
+ /// Ring address of the vring descriptor table.
+ pub descriptor: u64,
+ /// Ring address of the vring used ring.
+ pub used: u64,
+ /// Ring address of the vring available ring.
+ pub available: u64,
+ /// Guest address for logging.
+ pub log: u64,
+}
+
+impl VhostUserVringAddr {
+ /// Create a new instance.
+ pub fn new(
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Self {
+ VhostUserVringAddr {
+ index,
+ flags: flags.bits(),
+ descriptor,
+ used,
+ available,
+ log,
+ }
+ }
+
+ /// Create a new instance from `VringConfigData`.
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))]
+ pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self {
+ let log_addr = config_data.log_addr.unwrap_or(0);
+ VhostUserVringAddr {
+ index,
+ flags: config_data.flags,
+ descriptor: config_data.desc_table_addr,
+ used: config_data.used_ring_addr,
+ available: config_data.avail_ring_addr,
+ log: log_addr,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserVringAddr {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 {
+ return false;
+ } else if self.descriptor & 0xf != 0 {
+ return false;
+ } else if self.available & 0x1 != 0 {
+ return false;
+ } else if self.used & 0x3 != 0 {
+ return false;
+ }
+ true
+ }
+}
+
+// Bit mask for the vhost-user device configuration message.
+bitflags! {
+ /// Flags for the device configuration message.
+ pub struct VhostUserConfigFlags: u32 {
+ /// Vhost master messages used for writable fields.
+ const WRITABLE = 0x1;
+ /// Vhost master messages used for live migration.
+ const LIVE_MIGRATION = 0x2;
+ }
+}
+
+/// Message to read/write device configuration space.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserConfig {
+ /// Offset of virtio device's configuration space.
+ pub offset: u32,
+ /// Configuration space access size in bytes.
+ pub size: u32,
+ /// Flags for the device configuration operation.
+ pub flags: u32,
+}
+
+impl VhostUserConfig {
+ /// Create a new instance.
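+ ///
+ /// # Example
+ ///
+ /// An illustrative sketch: a request covering the whole device configuration
+ /// window [0x100, 0x1000).
+ ///
+ /// ```ignore
+ /// let msg = VhostUserConfig::new(
+ ///     VHOST_USER_CONFIG_OFFSET,
+ ///     VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET,
+ ///     VhostUserConfigFlags::WRITABLE,
+ /// );
+ /// assert!(msg.is_valid());
+ /// ```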
+ pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self {
+ VhostUserConfig {
+ offset,
+ size,
+ flags: flags.bits(),
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserConfig {
+ #[allow(clippy::if_same_then_else)]
+ fn is_valid(&self) -> bool {
+ if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
+ return false;
+ } else if self.offset < 0x100 {
+ return false;
+ } else if self.size == 0
+ || self.size > VHOST_USER_CONFIG_SIZE
+ || self.size + self.offset > VHOST_USER_CONFIG_SIZE
+ {
+ return false;
+ }
+ true
+ }
+}
+
+/// Payload for the VhostUserConfig message.
+pub type VhostUserConfigPayload = Vec<u8>;
+
+/*
+ * TODO: support dirty log, live migration and IOTLB operations.
+#[repr(packed)]
+pub struct VhostUserVringArea {
+ pub index: u32,
+ pub flags: u32,
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserLog {
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserIotlb {
+ pub iova: u64,
+ pub size: u64,
+ pub user_addr: u64,
+ pub permission: u8,
+ pub optype: u8,
+}
+*/
+
+// Bit mask for flags in virtio-fs slave messages
+bitflags! {
+ #[derive(Default)]
+ /// Flags for virtio-fs slave messages.
+ pub struct VhostUserFSSlaveMsgFlags: u64 {
+ /// Empty permission.
+ const EMPTY = 0x0;
+ /// Read permission.
+ const MAP_R = 0x1;
+ /// Write permission.
+ const MAP_W = 0x2;
+ }
+}
+
+/// Max entries in one virtio-fs slave request.
+pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;
+
+/// Slave request message to update the MMIO window.
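+///
+/// # Example
+///
+/// An illustrative sketch: ask the master to map the first 4KiB of the attached file,
+/// read-only, at the start of the DAX window. Unused entries keep their zeroed defaults.
+///
+/// ```ignore
+/// let mut msg = VhostUserFSSlaveMsg::default();
+/// msg.fd_offset[0] = 0;
+/// msg.cache_offset[0] = 0;
+/// msg.len[0] = 0x1000;
+/// msg.flags[0] = VhostUserFSSlaveMsgFlags::MAP_R;
+/// ```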
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserFSSlaveMsg {
+ /// File offset.
+ pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Offset into the DAX window.
+ pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Size of region to map.
+ pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Flags for the mmap operation
+ pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES],
+}
+
+impl VhostUserMsgValidator for VhostUserFSSlaveMsg {
+ fn is_valid(&self) -> bool {
+ for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
+ if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0
+ || self.fd_offset[i].checked_add(self.len[i]).is_none()
+ || self.cache_offset[i].checked_add(self.len[i]).is_none()
+ {
+ return false;
+ }
+ }
+ true
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::mem;
+
+ #[test]
+ fn check_master_request_code() {
+ let code = MasterReq::NOOP;
+ assert!(!code.is_valid());
+ let code = MasterReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > MasterReq::NOOP);
+ let code = MasterReq::GET_FEATURES;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn check_slave_request_code() {
+ let code = SlaveReq::NOOP;
+ assert!(!code.is_valid());
+ let code = SlaveReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > SlaveReq::NOOP);
+ let code = SlaveReq::CONFIG_CHANGE_MSG;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn msg_header_ops() {
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
+ assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES);
+ hdr.set_code(MasterReq::SET_FEATURES);
+ assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES);
+
+ assert_eq!(hdr.get_version(), 0x1);
+
+ assert_eq!(hdr.is_reply(), false);
+ hdr.set_reply(true);
+ assert_eq!(hdr.is_reply(), true);
+ hdr.set_reply(false);
+
+ assert_eq!(hdr.is_need_reply(), false);
+ hdr.set_need_reply(true);
+ assert_eq!(hdr.is_need_reply(), true);
+ hdr.set_need_reply(false);
+
+ assert_eq!(hdr.get_size(), 0x100);
+ hdr.set_size(0x200);
+ assert_eq!(hdr.get_size(), 0x200);
+
+ assert_eq!(hdr.is_need_reply(), false);
+ assert_eq!(hdr.is_reply(), false);
+ assert_eq!(hdr.get_version(), 0x1);
+
+ // Check message length
+ assert!(hdr.is_valid());
+ hdr.set_size(0x2000);
+ assert!(!hdr.is_valid());
+ hdr.set_size(0x100);
+ assert_eq!(hdr.get_size(), 0x100);
+ assert!(hdr.is_valid());
+ hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32);
+ assert!(hdr.is_valid());
+ hdr.set_size(0x0);
+ assert!(hdr.is_valid());
+
+ // Check version
+ hdr.set_version(0x0);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x2);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x1);
+ assert!(hdr.is_valid());
+
+ assert_eq!(hdr, hdr.clone());
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
+ }
+
+ #[test]
+ fn check_user_memory() {
+ let mut msg = VhostUserMemory::new(1);
+ assert!(msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ assert!(msg.is_valid());
+
+ msg.num_regions += 1;
+ assert!(!msg.is_valid());
+ msg.num_regions = 0xFFFFFFFF;
+ assert!(!msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ msg.padding1 = 1;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn check_user_memory_region() {
+ let mut msg = VhostUserMemoryRegion {
+ guest_phys_addr: 0,
+ memory_size: 0x1000,
+ user_addr: 0,
+ mmap_offset: 0,
+ };
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF;
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFF000;
+ assert!(!msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
+ msg.memory_size = 0;
+ assert!(!msg.is_valid());
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert_eq!(state.is_valid(), true);
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = state.num;
+ assert_eq!(a, 0);
+ assert_eq!(state.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_addr() {
+ let mut addr = VhostUserVringAddr::new(
+ 2,
+ VhostUserVringAddrFlags::VHOST_VRING_F_LOG,
+ 0x1000,
+ 0x2000,
+ 0x3000,
+ 0x4000,
+ );
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+
+ addr.descriptor = 0x1001;
+ assert_eq!(addr.is_valid(), false);
+ addr.descriptor = 0x1000;
+
+ addr.available = 0x3001;
+ assert_eq!(addr.is_valid(), false);
+ addr.available = 0x3000;
+
+ addr.used = 0x2001;
+ assert_eq!(addr.is_valid(), false);
+ addr.used = 0x2000;
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_state_from_config() {
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+            flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ let addr = VhostUserVringAddr::from_config_data(2, &config);
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn check_user_vring_addr() {
+ let mut msg =
+ VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0);
+ assert!(msg.is_valid());
+
+ msg.descriptor = 1;
+ assert!(!msg.is_valid());
+ msg.descriptor = 0;
+
+ msg.available = 1;
+ assert!(!msg.is_valid());
+ msg.available = 0;
+
+ msg.used = 1;
+ assert!(!msg.is_valid());
+ msg.used = 0;
+
+ msg.flags |= 0x80000000;
+ assert!(!msg.is_valid());
+ msg.flags &= !0x80000000;
+ }
+
+ #[test]
+ fn check_user_config_msg() {
+ let mut msg = VhostUserConfig::new(
+ VHOST_USER_CONFIG_OFFSET,
+ VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET,
+ VhostUserConfigFlags::WRITABLE,
+ );
+
+ assert!(msg.is_valid());
+ msg.size = 0;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ assert!(msg.is_valid());
+ msg.offset = 0;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE - 1;
+ assert!(msg.is_valid());
+ msg.size = 2;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits();
+ assert!(msg.is_valid());
+ msg.flags |= 0x4;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn test_vhost_user_fs_slave() {
+ let mut fs_slave = VhostUserFSSlaveMsg::default();
+
+ assert_eq!(fs_slave.is_valid(), true);
+
+ fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff;
+ fs_slave.len[0] = 0x1;
+ assert_eq!(fs_slave.is_valid(), false);
+
+ assert_ne!(
+ VhostUserFSSlaveMsgFlags::MAP_R,
+ VhostUserFSSlaveMsgFlags::MAP_W
+ );
+ assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0);
+ }
+}
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
new file mode 100644
index 0000000..9ef6453
--- /dev/null
+++ b/src/vhost_user/mod.rs
@@ -0,0 +1,456 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux
+//! Kernel. The protocol defines two sides of the communication, master and slave. Master is
+//! the application that shares its virtqueues. Slave is the consumer of the virtqueues.
+//!
+//! The communication channel between the master and the slave includes two sub channels. One is
+//! used to send requests from the master to the slave and optional replies from the slave to the
+//! master. This sub channel is created on master startup by connecting to the slave service
+//! endpoint. The other is used to send requests from the slave to the master and optional replies
+//! from the master to the slave. This sub channel is created by the master issuing a
+//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor.
+//!
+//! Unix domain socket is used as the underlying communication channel because the master needs to
+//! send file descriptors to the slave.
+//!
+//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
+//! equivalent ioctl in the kernel implementation.
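+//!
+//! # Example
+//!
+//! A minimal sketch wiring both sides together over a Unix domain socket. It assumes the
+//! `vhost-user-master` and `vhost-user-slave` features are enabled, and that `MyBackend`
+//! is a user-provided type implementing `VhostUserSlaveReqHandler`.
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//!
+//! // Slave side: bind a listener for incoming master connections.
+//! let listener = Listener::new("/tmp/vhost.sock", true).unwrap();
+//! let mut slave_listener = SlaveListener::new(listener, Arc::new(MyBackend::new())).unwrap();
+//!
+//! // Master side: connect to the slave service endpoint with one queue.
+//! let master = Master::connect("/tmp/vhost.sock", 1).unwrap();
+//!
+//! // Slave side: accept the connection and serve one request.
+//! let mut handler = slave_listener.accept().unwrap().unwrap();
+//! handler.handle_request().unwrap();
+//! ```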
+
+use std::io::Error as IOError;
+
+pub mod message;
+
+mod connection;
+pub use self::connection::Listener;
+
+#[cfg(feature = "vhost-user-master")]
+mod master;
+#[cfg(feature = "vhost-user-master")]
+pub use self::master::{Master, VhostUserMaster};
+#[cfg(feature = "vhost-user")]
+mod master_req_handler;
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
+
+#[cfg(feature = "vhost-user-slave")]
+mod slave;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave::SlaveListener;
+#[cfg(feature = "vhost-user-slave")]
+mod slave_req_handler;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
+#[cfg(feature = "vhost-user-slave")]
+mod slave_fs_cache;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_fs_cache::SlaveFsCacheReq;
+
+/// Errors for vhost-user operations
+#[derive(Debug)]
+pub enum Error {
+ /// Invalid parameters.
+ InvalidParam,
+    /// Unsupported operation because the required protocol feature hasn't been negotiated.
+ InvalidOperation,
+ /// Invalid message format, flag or content.
+ InvalidMessage,
+    /// Only part of a message has been sent or received successfully
+ PartialMessage,
+ /// Message is too large
+ OversizedMsg,
+    /// Attached fd array is too big or too small
+ IncorrectFds,
+ /// Can't connect to peer.
+ SocketConnect(std::io::Error),
+ /// Generic socket errors.
+ SocketError(std::io::Error),
+ /// The socket is broken or has been closed.
+ SocketBroken(std::io::Error),
+ /// Should retry the socket operation again.
+ SocketRetry(std::io::Error),
+ /// Failure from the slave side.
+ SlaveInternalError,
+ /// Failure from the master side.
+ MasterInternalError,
+ /// Virtio/protocol features mismatch.
+ FeatureMismatch,
+ /// Error from request handler
+ ReqHandlerError(IOError),
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ Error::InvalidParam => write!(f, "invalid parameters"),
+ Error::InvalidOperation => write!(f, "invalid operation"),
+ Error::InvalidMessage => write!(f, "invalid message"),
+ Error::PartialMessage => write!(f, "partial message"),
+ Error::OversizedMsg => write!(f, "oversized message"),
+ Error::IncorrectFds => write!(f, "wrong number of attached fds"),
+ Error::SocketError(e) => write!(f, "socket error: {}", e),
+ Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e),
+ Error::SocketBroken(e) => write!(f, "socket is broken: {}", e),
+ Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e),
+ Error::SlaveInternalError => write!(f, "slave internal error"),
+ Error::MasterInternalError => write!(f, "Master internal error"),
+ Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
+ Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e),
+ }
+ }
+}
+
+impl std::error::Error for Error {}
+
+impl Error {
+    /// Determine whether the underlying communication channel needs to be rebuilt.
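+    ///
+    /// A minimal sketch of how a connection manager might use this (the
+    /// `reconnect()` helper is hypothetical):
+    ///
+    /// ```ignore
+    /// if let Err(e) = master.set_owner() {
+    ///     if e.should_reconnect() {
+    ///         // Rebuild the communication channel and retry.
+    ///         reconnect();
+    ///     }
+    /// }
+    /// ```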
+ pub fn should_reconnect(&self) -> bool {
+ match *self {
+ // Should reconnect because it may be caused by temporary network errors.
+ Error::PartialMessage => true,
+            // Should reconnect because the underlying socket is broken.
+ Error::SocketBroken(_) => true,
+ // Slave internal error, hope it recovers on reconnect.
+ Error::SlaveInternalError => true,
+ // Master internal error, hope it recovers on reconnect.
+ Error::MasterInternalError => true,
+            // Should just retry the IO operation instead of rebuilding the underlying connection.
+ Error::SocketRetry(_) => false,
+ Error::InvalidParam | Error::InvalidOperation => false,
+ Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false,
+ Error::SocketError(_) | Error::SocketConnect(_) => false,
+ Error::FeatureMismatch => false,
+ Error::ReqHandlerError(_) => false,
+ }
+ }
+}
+
+impl std::convert::From<sys_util::Error> for Error {
+ /// Convert raw socket errors into meaningful vhost-user errors.
+ ///
+    /// The sys_util::Error is a simple wrapper over the raw errno, which doesn't mean
+ /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
+ /// the connection manager logic.
+ ///
+ /// # Return:
+ /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
+ /// * - Error::SocketBroken: the underline socket is broken.
+ /// * - Error::SocketError: other socket related errors.
+    #[allow(unreachable_patterns)] // EWOULDBLOCK equals EAGAIN on Linux
+ fn from(err: sys_util::Error) -> Self {
+ match err.errno() {
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
+ // A signal occurred before any data was transmitted
+ libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
+ // The output queue for a network interface was full. This generally indicates
+ // that the interface has stopped sending, but may be caused by transient congestion.
+ libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
+ // No memory available.
+ libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
+ // Connection reset by peer.
+ libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
+ // The local end has been shut down on a connection oriented socket. In this case the
+ // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
+ libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
+ // Write permission is denied on the destination socket file, or search permission is
+            // denied for one of the directories in the path prefix.
+ libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
+ // Catch all other errors
+ e => Error::SocketError(IOError::from_raw_os_error(e)),
+ }
+ }
+}
+
+/// Result of vhost-user operations
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Result of request handler.
+pub type HandlerResult<T> = std::result::Result<T, IOError>;
+
+#[cfg(all(test, feature = "vhost-user-slave"))]
+mod dummy_slave;
+
+#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+ use std::path::Path;
+ use std::sync::{Arc, Barrier, Mutex};
+ use std::thread;
+
+ use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
+ use super::message::*;
+ use super::*;
+ use crate::backend::VhostBackend;
+ use crate::{VhostUserMemoryRegionInfo, VringConfigData};
+ use tempfile::{tempfile, Builder, TempDir};
+
+ fn temp_dir() -> TempDir {
+ Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+ }
+
+ fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>)
+ where
+ P: AsRef<Path>,
+ S: VhostUserSlaveReqHandler,
+ {
+ let listener = Listener::new(&path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+ let master = Master::connect(&path, 1).unwrap();
+ (master, slave_listener.accept().unwrap().unwrap())
+ }
+
+ #[test]
+ fn create_dummy_slave() {
+ let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+
+ slave.set_owner().unwrap();
+ assert!(slave.set_owner().is_err());
+ }
+
+ #[test]
+ fn test_set_owner() {
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (master, mut slave) = create_slave(&path, slave_be.clone());
+
+ assert_eq!(slave_be.lock().unwrap().owned, false);
+ master.set_owner().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+ master.set_owner().unwrap();
+ assert!(slave.handle_request().is_err());
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+ }
+
+ #[test]
+ fn test_set_features() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ thread::spawn(move || {
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_master_slave_process() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ thread::spawn(move || {
+            // set_owner()
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ // get/set_features()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ // get_queue_num()
+ slave.handle_request().unwrap();
+
+ // set_mem_table()
+ slave.handle_request().unwrap();
+
+ // get/set_config()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // set_slave_request_fd
+ slave.handle_request().unwrap();
+
+ // set_vring_enable
+ slave.handle_request().unwrap();
+
+            // set_log_base(), set_log_fd()
+ slave.handle_request().unwrap_err();
+ slave.handle_request().unwrap_err();
+
+ // set_vring_xxx
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // get_max_mem_slots()
+ slave.handle_request().unwrap();
+
+ // add_mem_region()
+ slave.handle_request().unwrap();
+
+ // remove_mem_region()
+ slave.handle_request().unwrap();
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ let num = master.get_queue_num().unwrap();
+ assert_eq!(num, 2);
+
+ let eventfd = sys_util::EventFd::new().unwrap();
+ let mem = [VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: eventfd.as_raw_fd(),
+ }];
+ master.set_mem_table(&mem).unwrap();
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
+ .unwrap();
+ let buf = [0x0u8; 4];
+ let (reply_body, reply_payload) = master
+ .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ let offset = reply_body.offset;
+ assert_eq!(offset, 0x100);
+ assert_eq!(reply_payload[0], 0xa5);
+
+ master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+        // not implemented by the dummy slave yet
+ master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap();
+ master.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+ master.set_vring_base(0, 0).unwrap();
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ master.set_vring_addr(0, &config).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+
+ let max_mem_slots = master.get_max_mem_slots().unwrap();
+ assert_eq!(max_mem_slots, 32);
+
+ let region_file = tempfile().unwrap();
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0x10_0000,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: region_file.as_raw_fd(),
+ };
+ master.add_mem_region(&region).unwrap();
+
+ master.remove_mem_region(&region).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_error_display() {
+ assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
+ assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
+ }
+
+ #[test]
+ fn test_should_reconnect() {
+ assert_eq!(Error::PartialMessage.should_reconnect(), true);
+ assert_eq!(Error::SlaveInternalError.should_reconnect(), true);
+ assert_eq!(Error::MasterInternalError.should_reconnect(), true);
+ assert_eq!(Error::InvalidParam.should_reconnect(), false);
+ assert_eq!(Error::InvalidOperation.should_reconnect(), false);
+ assert_eq!(Error::InvalidMessage.should_reconnect(), false);
+ assert_eq!(Error::IncorrectFds.should_reconnect(), false);
+ assert_eq!(Error::OversizedMsg.should_reconnect(), false);
+ assert_eq!(Error::FeatureMismatch.should_reconnect(), false);
+ }
+
+ #[test]
+ fn test_error_from_sys_util_error() {
+ let e: Error = sys_util::Error::new(libc::EAGAIN).into();
+ if let Error::SocketRetry(e1) = e {
+ assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
+ } else {
+ panic!("invalid error code conversion!");
+ }
+ }
+}
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
new file mode 100644
index 0000000..fb65c41
--- /dev/null
+++ b/src/vhost_user/slave.rs
@@ -0,0 +1,86 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Structs for vhost-user slave.
+
+use std::sync::Arc;
+
+use super::connection::{Endpoint, Listener};
+use super::message::*;
+use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
+
+/// Vhost-user slave side connection listener.
+pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
+ listener: Listener,
+ backend: Option<Arc<S>>,
+}
+
+/// Sets up a listener for incoming master connections, and handles construction
+/// of a SlaveReqHandler on success.
+impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
+ /// Create a unix domain socket for incoming master connections.
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
+ Ok(SlaveListener {
+ listener,
+ backend: Some(backend),
+ })
+ }
+
+    /// Accept an incoming connection from the master, returning Some(SlaveReqHandler) on
+    /// success, or None if the socket is nonblocking and no incoming connection
+    /// was detected.
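+    ///
+    /// A minimal polling sketch for a nonblocking listener:
+    ///
+    /// ```ignore
+    /// slave_listener.set_nonblocking(true).unwrap();
+    /// let handler = loop {
+    ///     // None means no master has connected yet; poll again later.
+    ///     if let Some(handler) = slave_listener.accept().unwrap() {
+    ///         break handler;
+    ///     }
+    /// };
+    /// ```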
+ pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> {
+ if let Some(fd) = self.listener.accept()? {
+ return Ok(Some(SlaveReqHandler::new(
+ Endpoint::<MasterReq>::from_stream(fd),
+ self.backend.take().unwrap(),
+ )));
+ }
+ Ok(None)
+ }
+
+ /// Change blocking status on the listener.
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> {
+ self.listener.set_nonblocking(block)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Mutex;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_listener_set_nonblocking() {
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener =
+ Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
+ let slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-master")]
+ #[test]
+ fn test_slave_listener_accept() {
+ use super::super::Master;
+
+ let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener = Listener::new(path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ assert!(slave_listener.accept().unwrap().is_none());
+ assert!(slave_listener.accept().unwrap().is_none());
+
+ let _master = Master::connect(path, 1).unwrap();
+ let _slave = slave_listener.accept().unwrap().unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
new file mode 100644
index 0000000..a9c4ed2
--- /dev/null
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -0,0 +1,226 @@
+// Copyright (C) 2020 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::io;
+use std::mem;
+use std::os::unix::io::RawFd;
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
+
+struct SlaveFsCacheReqInternal {
+ sock: Endpoint<SlaveReq>,
+
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
+
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl SlaveFsCacheReqInternal {
+ fn check_state(&self) -> Result<u64> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(0),
+ }
+ }
+
+ fn send_message(
+ &mut self,
+ request: SlaveReq,
+ fs: &VhostUserFSSlaveMsg,
+ fds: Option<&[RawFd]>,
+ ) -> Result<u64> {
+ self.check_state()?;
+
+ let len = mem::size_of::<VhostUserFSSlaveMsg>();
+ let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
+ if self.reply_ack_negotiated {
+ hdr.set_need_reply(true);
+ }
+ self.sock.send_message(&hdr, fs, fds)?;
+
+ self.wait_for_ack(&hdr)
+ }
+
+ fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
+ self.check_state()?;
+ if !self.reply_ack_negotiated {
+ return Ok(0);
+ }
+
+ let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
+ if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
+ Endpoint::<SlaveReq>::close_rfds(rfds);
+ return Err(Error::InvalidMessage);
+ }
+ if body.value != 0 {
+ return Err(Error::MasterInternalError);
+ }
+
+ Ok(body.value)
+ }
+}
+
+/// Request proxy to send vhost-user-fs slave requests to the master through the slave
+/// communication channel.
+///
+/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the
+/// master through the vhost-user slave communication channel. The forwarded messages will be
+/// handled by the [MasterReqHandler] server.
+///
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
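+///
+/// # Example
+///
+/// A minimal sketch of the slave side forwarding a map request to the master (the
+/// `memfd` file descriptor and the default message contents are placeholders):
+///
+/// ```ignore
+/// let (sock, _peer) = UnixStream::pair().unwrap();
+/// let fs_cache = SlaveFsCacheReq::from_stream(sock);
+/// // Ask the master to mmap a file region into the cache window it manages.
+/// fs_cache.fs_slave_map(&VhostUserFSSlaveMsg::default(), memfd).unwrap();
+/// ```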
+#[derive(Clone)]
+pub struct SlaveFsCacheReq {
+ // underlying Unix domain socket for communication
+ node: Arc<Mutex<SlaveFsCacheReqInternal>>,
+}
+
+impl SlaveFsCacheReq {
+ fn new(ep: Endpoint<SlaveReq>) -> Self {
+ SlaveFsCacheReq {
+ node: Arc::new(Mutex::new(SlaveFsCacheReqInternal {
+ sock: ep,
+ reply_ack_negotiated: false,
+ error: None,
+ })),
+ }
+ }
+
+ fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> {
+ self.node.lock().unwrap()
+ }
+
+ fn send_message(
+ &self,
+ request: SlaveReq,
+ fs: &VhostUserFSSlaveMsg,
+ fds: Option<&[RawFd]>,
+ ) -> io::Result<u64> {
+ self.node()
+ .send_message(request, fs, fds)
+ .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
+ }
+
+ /// Create a new instance from a `UnixStream` object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Self::new(Endpoint::<SlaveReq>::from_stream(sock))
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&self, enable: bool) {
+ self.node().reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed with specified error code.
+ pub fn set_failed(&self, error: i32) {
+ self.node().error = Some(error);
+ }
+}
+
+impl VhostUserMasterReqHandler for SlaveFsCacheReq {
+ /// Forward vhost-user-fs map file requests to the slave.
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
+ }
+
+ /// Forward vhost-user-fs unmap file requests to the master.
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_UNMAP, fs, None)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+
+ #[test]
+ fn test_slave_fs_cache_req_set_failed() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ assert!(fs_cache.node().error.is_none());
+ fs_cache.set_failed(libc::EAGAIN);
+ assert_eq!(fs_cache.node().error, Some(libc::EAGAIN));
+ }
+
+ #[test]
+ fn test_slave_fs_cache_send_failure() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ fs_cache.set_failed(libc::ECONNRESET);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ fs_cache.node().error = None;
+
+ drop(p2);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_slave_fs_cache_recv_negative() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+ let mut master = Endpoint::<SlaveReq>::from_stream(p2);
+
+ let len = mem::size_of::<VhostUserFSSlaveMsg>();
+ let mut hdr = VhostUserMsgHeader::new(
+ SlaveReq::FS_MAP,
+ VhostUserHeaderFlag::REPLY.bits(),
+ len as u32,
+ );
+ let body = VhostUserU64::new(0);
+
+ master.send_message(&hdr, &body, Some(&[fd])).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ hdr.set_code(SlaveReq::FS_UNMAP);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ hdr.set_code(SlaveReq::FS_MAP);
+
+ let body = VhostUserU64::new(1);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ let body = VhostUserU64::new(0);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
new file mode 100644
index 0000000..18459a2
--- /dev/null
+++ b/src/vhost_user/slave_req_handler.rs
@@ -0,0 +1,828 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::slice;
+use std::sync::{Arc, Mutex};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::slave_fs_cache::SlaveFsCacheReq;
+use super::{Error, Result};
+
+/// Services provided to the master by the slave with interior mutability.
+///
+/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
+/// The [VhostUserSlaveReqHandlerMut] trait is a helper that mirrors [VhostUserSlaveReqHandler]
+/// but without interior mutability.
+/// The vhost-user specification defines a master communication channel, by which masters could
+/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
+/// slaves, and it's used both on the master side and slave side.
+///
+/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
+/// service requests to slaves.
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
+/// implementing [VhostUserSlaveReqHandler].
+///
+/// The [VhostUserSlaveReqHandler] trait is designed with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
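+///
+/// A typical backend implements the mutable variant and is wrapped in a `Mutex`, which
+/// then provides the interior-mutability implementation of this trait. A minimal sketch
+/// (`MyBackend` and its method bodies are placeholders):
+///
+/// ```ignore
+/// struct MyBackend { owned: bool }
+///
+/// impl VhostUserSlaveReqHandlerMut for MyBackend {
+///     fn set_owner(&mut self) -> Result<()> {
+///         self.owned = true;
+///         Ok(())
+///     }
+///     // ... the remaining trait methods are elided ...
+/// }
+///
+/// // Mutex<MyBackend> now implements VhostUserSlaveReqHandler.
+/// let backend = Arc::new(Mutex::new(MyBackend { owned: false }));
+/// ```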
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandler {
+ fn set_owner(&self) -> Result<()>;
+ fn reset_owner(&self) -> Result<()>;
+ fn get_features(&self) -> Result<u64>;
+ fn set_features(&self, features: u64) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&self, features: u64) -> Result<()>;
+ fn get_queue_num(&self) -> Result<u64>;
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
+ fn get_max_mem_slots(&self) -> Result<u64>;
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
+}
+
+/// Services provided to the master by the slave without interior mutability.
+///
+/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandlerMut {
+ fn set_owner(&mut self) -> Result<()>;
+ fn reset_owner(&mut self) -> Result<()>;
+ fn get_features(&mut self) -> Result<u64>;
+ fn set_features(&mut self, features: u64) -> Result<()>;
+ fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
+
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&mut self, features: u64) -> Result<()>;
+ fn get_queue_num(&mut self) -> Result<u64>;
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ flags: VhostUserConfigFlags,
+ ) -> Result<Vec<u8>>;
+ fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
+ fn get_max_mem_slots(&mut self) -> Result<u64>;
+ fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
+}
+
+impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
+ fn set_owner(&self) -> Result<()> {
+ self.lock().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.lock().unwrap().reset_owner()
+ }
+
+ fn get_features(&self) -> Result<u64> {
+ self.lock().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_features(features)
+ }
+
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, fds)
+ }
+
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_num(index, num)
+ }
+
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()> {
+ self.lock()
+ .unwrap()
+ .set_vring_addr(index, flags, descriptor, used, available, log)
+ }
+
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_base(index, base)
+ }
+
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
+ self.lock().unwrap().get_vring_base(index)
+ }
+
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_kick(index, fd)
+ }
+
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_call(index, fd)
+ }
+
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_err(index, fd)
+ }
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
+ self.lock().unwrap().get_protocol_features()
+ }
+
+ fn set_protocol_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_protocol_features(features)
+ }
+
+ fn get_queue_num(&self) -> Result<u64> {
+ self.lock().unwrap().get_queue_num()
+ }
+
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
+ self.lock().unwrap().set_vring_enable(index, enable)
+ }
+
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
+ self.lock().unwrap().get_config(offset, size, flags)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf, flags)
+ }
+
+ fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
+ self.lock().unwrap().set_slave_req_fd(vu_req)
+ }
+
+ fn get_max_mem_slots(&self) -> Result<u64> {
+ self.lock().unwrap().get_max_mem_slots()
+ }
+
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> {
+ self.lock().unwrap().add_mem_region(region, fd)
+ }
+
+ fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
+ self.lock().unwrap().remove_mem_region(region)
+ }
+}
+
+/// Server to handle service requests from masters from the master communication channel.
+///
+/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
+/// masters on the master communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
+///
+/// The lifetime of the SlaveReqHandler object should match that of the underlying Unix domain
+/// socket, which makes it simpler to recover from a disconnect.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
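+///
+/// A minimal serving loop (a sketch, assuming `backend` implements
+/// [VhostUserSlaveReqHandler] and a master is listening at `path`):
+///
+/// ```ignore
+/// let mut handler = SlaveReqHandler::connect(path, backend).unwrap();
+/// loop {
+///     // Calls must be serialized; decide per error whether to retry or rebuild.
+///     if let Err(e) = handler.handle_request() {
+///         if e.should_reconnect() {
+///             break; // tear down this handler and reconnect
+///         }
+///     }
+/// }
+/// ```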
+pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
+ // underlying Unix domain socket for communication
+ main_sock: Endpoint<MasterReq>,
+ // the vhost-user backend device object
+ backend: Arc<S>,
+
+ virtio_features: u64,
+ acked_virtio_features: u64,
+ protocol_features: VhostUserProtocolFeatures,
+ acked_protocol_features: u64,
+
+ // sending ack for messages without payload
+ reply_ack_enabled: bool,
+ // whether the endpoint has encountered any failure
+ error: Option<i32>,
+}
+
+impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
+ /// Create a vhost-user slave endpoint.
+ pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
+ SlaveReqHandler {
+ main_sock,
+ backend,
+ virtio_features: 0,
+ acked_virtio_features: 0,
+ protocol_features: VhostUserProtocolFeatures::empty(),
+ acked_protocol_features: 0,
+ reply_ack_enabled: false,
+ error: None,
+ }
+ }
+
+ /// Create a new vhost-user slave endpoint.
+ ///
+ /// # Arguments
+ /// * - `path` - path of Unix domain socket listener to connect to
+ /// * - `backend` - handler for requests from the master to the slave
+ pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
+ Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
+ }
+
+ /// Mark endpoint as failed with specified error code.
+ pub fn set_failed(&mut self, error: i32) {
+ self.error = Some(error);
+ }
+
+    /// Main entry point for serving requests from the master communication channel.
+ ///
+ /// Receive and handle one incoming request message from the master. The caller needs to:
+ /// - serialize calls to this function
+    /// - decide what to do when an error happens
+    /// - optionally recover from failures
+ pub fn handle_request(&mut self) -> Result<()> {
+ // Return error if the endpoint is already in failed state.
+ self.check_state()?;
+
+ // The underlying communication channel is a Unix domain socket in
+ // stream mode, and recvmsg() is a little tricky here. To successfully
+ // receive attached file descriptors, we need to receive messages and
+ // corresponding attached file descriptors in this way:
+        // . recv message header and optional attached file descriptors
+        // . validate message header
+        // . recv optional message body and payload according to the size field
+        //   in the message header
+ // . validate message body and optional payload
+ let (hdr, rfds) = self.main_sock.recv_header()?;
+ let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (size, buf) = match hdr.get_size() {
+ 0 => (0, vec![0u8; 0]),
+ len => {
+ let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
+ if size2 != len as usize {
+ return Err(Error::InvalidMessage);
+ }
+ (size2, rbuf)
+ }
+ };
+
+ match hdr.get_code() {
+ MasterReq::SET_OWNER => {
+ self.check_request_size(&hdr, size, 0)?;
+ self.backend.set_owner()?;
+ }
+ MasterReq::RESET_OWNER => {
+ self.check_request_size(&hdr, size, 0)?;
+ self.backend.reset_owner()?;
+ }
+ MasterReq::GET_FEATURES => {
+ self.check_request_size(&hdr, size, 0)?;
+ let features = self.backend.get_features()?;
+ let msg = VhostUserU64::new(features);
+ self.send_reply_message(&hdr, &msg)?;
+ self.virtio_features = features;
+ self.update_reply_ack_flag();
+ }
+ MasterReq::SET_FEATURES => {
+ let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
+ self.backend.set_features(msg.value)?;
+ self.acked_virtio_features = msg.value;
+ self.update_reply_ack_flag();
+ }
+ MasterReq::SET_MEM_TABLE => {
+ let res = self.set_mem_table(&hdr, size, &buf, rfds);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::SET_VRING_NUM => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let res = self.backend.set_vring_num(msg.index, msg.num);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::SET_VRING_ADDR => {
+ let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
+ let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
+ Some(val) => val,
+ None => return Err(Error::InvalidMessage),
+ };
+ let res = self.backend.set_vring_addr(
+ msg.index,
+ flags,
+ msg.descriptor,
+ msg.used,
+ msg.available,
+ msg.log,
+ );
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::SET_VRING_BASE => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let res = self.backend.set_vring_base(msg.index, msg.num);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::GET_VRING_BASE => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ let reply = self.backend.get_vring_base(msg.index)?;
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ MasterReq::SET_VRING_CALL => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
+ let res = self.backend.set_vring_call(index, rfds);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::SET_VRING_KICK => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
+ let res = self.backend.set_vring_kick(index, rfds);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::SET_VRING_ERR => {
+ self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
+ let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
+ let res = self.backend.set_vring_err(index, rfds);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::GET_PROTOCOL_FEATURES => {
+ self.check_request_size(&hdr, size, 0)?;
+ let features = self.backend.get_protocol_features()?;
+ let msg = VhostUserU64::new(features.bits());
+ self.send_reply_message(&hdr, &msg)?;
+ self.protocol_features = features;
+ self.update_reply_ack_flag();
+ }
+ MasterReq::SET_PROTOCOL_FEATURES => {
+ let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
+ self.backend.set_protocol_features(msg.value)?;
+ self.acked_protocol_features = msg.value;
+ self.update_reply_ack_flag();
+ }
+ MasterReq::GET_QUEUE_NUM => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ }
+ self.check_request_size(&hdr, size, 0)?;
+ let num = self.backend.get_queue_num()?;
+ let msg = VhostUserU64::new(num);
+ self.send_reply_message(&hdr, &msg)?;
+ }
+ MasterReq::SET_VRING_ENABLE => {
+ let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
+ if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0
+ && msg.index > 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+ let enable = match msg.num {
+ 1 => true,
+ 0 => false,
+ _ => return Err(Error::InvalidParam),
+ };
+
+ let res = self.backend.set_vring_enable(msg.index, enable);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::GET_CONFIG => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ }
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ self.get_config(&hdr, &buf)?;
+ }
+ MasterReq::SET_CONFIG => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ }
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ self.set_config(&hdr, size, &buf)?;
+ }
+ MasterReq::SET_SLAVE_REQ_FD => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ }
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+ self.set_slave_req_fd(&hdr, rfds)?;
+ }
+ MasterReq::GET_MAX_MEM_SLOTS => {
+ if self.acked_protocol_features
+ & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
+ == 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+ self.check_request_size(&hdr, size, 0)?;
+ let num = self.backend.get_max_mem_slots()?;
+ let msg = VhostUserU64::new(num);
+ self.send_reply_message(&hdr, &msg)?;
+ }
+ MasterReq::ADD_MEM_REG => {
+ if self.acked_protocol_features
+ & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
+ == 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+ let fd = if let Some(fds) = &rfds {
+ if fds.len() != 1 {
+ return Err(Error::InvalidParam);
+ }
+ fds[0]
+ } else {
+ return Err(Error::InvalidParam);
+ };
+
+ let msg =
+ self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
+ let res = self.backend.add_mem_region(&msg, fd);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::REM_MEM_REG => {
+ if self.acked_protocol_features
+ & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
+ == 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+
+ let msg =
+ self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
+ let res = self.backend.remove_mem_region(&msg);
+ self.send_ack_message(&hdr, res)?;
+ }
+ _ => {
+ return Err(Error::InvalidMessage);
+ }
+ }
+ Ok(())
+ }
+
+ fn set_mem_table(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<()> {
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
+
+ // check message size is consistent
+ let hdrsize = mem::size_of::<VhostUserMemory>();
+ if size < hdrsize {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
+ if !msg.is_valid() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(Error::InvalidMessage);
+ }
+ if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ return Err(Error::InvalidMessage);
+ }
+
+ // validate number of fds matching number of memory regions
+ let fds = match rfds {
+ None => return Err(Error::InvalidMessage),
+ Some(fds) => {
+ if fds.len() != msg.num_regions as usize {
+ Endpoint::<MasterReq>::close_rfds(Some(fds));
+ return Err(Error::InvalidMessage);
+ }
+ fds
+ }
+ };
+
+ // Validate memory regions
+ let regions = unsafe {
+ slice::from_raw_parts(
+ buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
+ msg.num_regions as usize,
+ )
+ };
+ for region in regions.iter() {
+ if !region.is_valid() {
+ Endpoint::<MasterReq>::close_rfds(Some(fds));
+ return Err(Error::InvalidMessage);
+ }
+ }
+
+ self.backend.set_mem_table(&regions, &fds)
+ }
+
+ fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
+ let payload_offset = mem::size_of::<VhostUserConfig>();
+ if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if buf.len() - payload_offset != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+ let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
+ Some(val) => val,
+ None => return Err(Error::InvalidMessage),
+ };
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
+
+        // The vhost-user slave's payload size MUST match the master's request on
+        // success, and a zero-length payload indicates an error to the vhost-user
+        // master.
+ match res {
+ Ok(ref buf) if buf.len() == msg.size as usize => {
+ let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
+ self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?;
+ }
+ Ok(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ Err(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn set_config(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<()> {
+ if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+        let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
+            Some(val) => val,
+            None => return Err(Error::InvalidMessage),
+        };
+
+ let res = self.backend.set_config(msg.offset, buf, flags);
+ self.send_ack_message(&hdr, res)?;
+ Ok(())
+ }
+
+ fn set_slave_req_fd(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<()> {
+ if let Some(fds) = rfds {
+ if fds.len() == 1 {
+ let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
+ let vu_req = SlaveFsCacheReq::from_stream(sock);
+ self.backend.set_slave_req_fd(vu_req);
+ self.send_ack_message(&hdr, Ok(()))
+ } else {
+ Err(Error::InvalidMessage)
+ }
+ } else {
+ Err(Error::InvalidMessage)
+ }
+ }
+
+ fn handle_vring_fd_request(
+ &mut self,
+ buf: &[u8],
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<(u8, Option<RawFd>)> {
+ if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ // Bits (0-7) of the payload contain the vring index. Bit 8 is the
+ // invalid FD flag. This flag is set when there is no file descriptor
+ // in the ancillary data. This signals that polling will be used
+ // instead of waiting for the call.
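+        // For example, a payload value of 0x104 encodes vring index 4 with the
+        // invalid-FD flag set.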
+ let nofd = (msg.value & 0x100u64) == 0x100u64;
+
+ let mut rfd = None;
+ match rfds {
+ Some(fds) => {
+ if !nofd && fds.len() == 1 {
+ rfd = Some(fds[0]);
+ } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) {
+ Endpoint::<MasterReq>::close_rfds(Some(fds));
+ return Err(Error::InvalidMessage);
+ }
+ }
+ None => {
+ if !nofd {
+ return Err(Error::InvalidMessage);
+ }
+ }
+ }
+ Ok((msg.value as u8, rfd))
+ }
+
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
+ fn check_request_size(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ expected: usize,
+ ) -> Result<()> {
+ if hdr.get_size() as usize != expected
+ || hdr.is_reply()
+ || hdr.get_version() != 0x1
+ || size != expected
+ {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(())
+ }
+
+ fn check_attached_rfds(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<Option<Vec<RawFd>>> {
+ match hdr.get_code() {
+ MasterReq::SET_MEM_TABLE => Ok(rfds),
+ MasterReq::SET_VRING_CALL => Ok(rfds),
+ MasterReq::SET_VRING_KICK => Ok(rfds),
+ MasterReq::SET_VRING_ERR => Ok(rfds),
+ MasterReq::SET_LOG_BASE => Ok(rfds),
+ MasterReq::SET_LOG_FD => Ok(rfds),
+ MasterReq::SET_SLAVE_REQ_FD => Ok(rfds),
+ MasterReq::SET_INFLIGHT_FD => Ok(rfds),
+ MasterReq::ADD_MEM_REG => Ok(rfds),
+ _ => {
+ if rfds.is_some() {
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ Err(Error::InvalidMessage)
+ } else {
+ Ok(rfds)
+ }
+ }
+ }
+ }
+
+ fn extract_request_body<T: Sized + VhostUserMsgValidator>(
+ &self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<T> {
+ self.check_request_size(hdr, size, mem::size_of::<T>())?;
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(msg)
+ }
+
+ fn update_reply_ack_flag(&mut self) {
+ let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let pflag = VhostUserProtocolFeatures::REPLY_ACK;
+ if (self.virtio_features & vflag) != 0
+ && (self.acked_virtio_features & vflag) != 0
+ && self.protocol_features.contains(pflag)
+ && (self.acked_protocol_features & pflag.bits()) != 0
+ {
+ self.reply_ack_enabled = true;
+ } else {
+ self.reply_ack_enabled = false;
+ }
+ }
+
+ fn new_reply_header<T: Sized>(
+ &self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ payload_size: usize,
+ ) -> Result<VhostUserMsgHeader<MasterReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || payload_size > MAX_MSG_SIZE
+ || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ self.check_state()?;
+ Ok(VhostUserMsgHeader::new(
+ req.get_code(),
+ VhostUserHeaderFlag::REPLY.bits(),
+ (mem::size_of::<T>() + payload_size) as u32,
+ ))
+ }
+
+ fn send_ack_message(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ res: Result<()>,
+ ) -> Result<()> {
+ if self.reply_ack_enabled && req.is_need_reply() {
+ let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
+ let val = match res {
+ Ok(_) => 0,
+ Err(_) => 1,
+ };
+ let msg = VhostUserU64::new(val);
+ self.main_sock.send_message(&hdr, &msg, None)?;
+ }
+ Ok(())
+ }
+
+ fn send_reply_message<T>(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ msg: &T,
+ ) -> Result<()> {
+ let hdr = self.new_reply_header::<T>(req, 0)?;
+ self.main_sock.send_message(&hdr, msg, None)?;
+ Ok(())
+ }
+
+ fn send_reply_with_payload<T: Sized>(
+ &mut self,
+ req: &VhostUserMsgHeader<MasterReq>,
+ msg: &T,
+ payload: &[u8],
+ ) -> Result<()> {
+ let hdr = self.new_reply_header::<T>(req, payload.len())?;
+ self.main_sock
+ .send_message_with_payload(&hdr, msg, payload, None)?;
+ Ok(())
+ }
+}
+
+impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.main_sock.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_req_handler_new() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let endpoint = Endpoint::<MasterReq>::from_stream(p1);
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let mut handler = SlaveReqHandler::new(endpoint, backend);
+
+ handler.check_state().unwrap();
+ handler.set_failed(libc::EAGAIN);
+ handler.check_state().unwrap_err();
+ assert!(handler.as_raw_fd() >= 0);
+ }
+}