diff options
Diffstat (limited to 'src/vhost_user')
-rw-r--r-- | src/vhost_user/connection.rs | 858 | ||||
-rw-r--r-- | src/vhost_user/dummy_slave.rs | 259 | ||||
-rw-r--r-- | src/vhost_user/master.rs | 1071 | ||||
-rw-r--r-- | src/vhost_user/master_req_handler.rs | 477 | ||||
-rw-r--r-- | src/vhost_user/message.rs | 1042 | ||||
-rw-r--r-- | src/vhost_user/mod.rs | 456 | ||||
-rw-r--r-- | src/vhost_user/slave.rs | 86 | ||||
-rw-r--r-- | src/vhost_user/slave_fs_cache.rs | 226 | ||||
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 828 |
9 files changed, 5303 insertions, 0 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs new file mode 100644 index 0000000..f92db45 --- /dev/null +++ b/src/vhost_user/connection.rs @@ -0,0 +1,858 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Structs for Unix Domain Socket listener and endpoint. + +#![allow(dead_code)] + +use std::io::ErrorKind; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::{Path, PathBuf}; +use std::{mem, slice}; + +use libc::{c_void, iovec}; +use sys_util::ScmSocket; + +use super::message::*; +use super::{Error, Result}; + +/// Unix domain socket listener for accepting incoming connections. +pub struct Listener { + fd: UnixListener, + path: PathBuf, +} + +impl Listener { + /// Create a unix domain socket listener. + /// + /// # Return: + /// * - the new Listener object on success. + /// * - SocketError: failed to create listener socket. + pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> { + if unlink { + let _ = std::fs::remove_file(&path); + } + let fd = UnixListener::bind(&path).map_err(Error::SocketError)?; + Ok(Listener { + fd, + path: path.as_ref().to_owned(), + }) + } + + /// Accept an incoming connection. + /// + /// # Return: + /// * - Some(UnixStream): new UnixStream object if new incoming connection is available. + /// * - None: no incoming connection available. + /// * - SocketError: errors from accept(). + pub fn accept(&self) -> Result<Option<UnixStream>> { + loop { + match self.fd.accept() { + Ok((socket, _addr)) => return Ok(Some(socket)), + Err(e) => { + match e.kind() { + // No incoming connection available. + ErrorKind::WouldBlock => return Ok(None), + // New connection closed by peer. + ErrorKind::ConnectionAborted => return Ok(None), + // Interrupted by signals, retry + ErrorKind::Interrupted => continue, + _ => return Err(Error::SocketError(e)), + } + } + } + } + } + + /// Change blocking status on the listener. + /// + /// # Return: + /// * - () on success. + /// * - SocketError: failure from set_nonblocking(). + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.fd.set_nonblocking(block).map_err(Error::SocketError) + } +} + +impl AsRawFd for Listener { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl Drop for Listener { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +/// Unix domain socket endpoint for vhost-user connection. +pub(super) struct Endpoint<R: Req> { + sock: UnixStream, + _r: PhantomData<R>, +} + +impl<R: Req> Endpoint<R> { + /// Create a new stream by connecting to server at `str`. + /// + /// # Return: + /// * - the new Endpoint object on success. + /// * - SocketConnect: failed to connect to peer. + pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> { + let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; + Ok(Self::from_stream(sock)) + } + + /// Create an endpoint from a stream object. + pub fn from_stream(sock: UnixStream) -> Self { + Endpoint { + sock, + _r: PhantomData, + } + } + + /// Sends bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let rfds = match fds { + Some(rfds) => rfds, + _ => &[], + }; + self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into) + } + + /// Sends all bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. Will loop until all data has been transfered. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let mut data_sent = 0; + let mut data_total = 0; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_sent) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); + let iov = &iovs[nr_skip][offset..]; + + let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); + let sfds = if data_sent == 0 { fds } else { None }; + + let sent = self.send_iovec(data, sfds); + match sent { + Ok(0) => return Ok(data_sent), + Ok(n) => data_sent += n, + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok(data_sent) + } + + /// Sends bytes from a slice over the socket with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> { + self.send_iovec(&[data], fds) + } + + /// Sends a header-only message with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_header( + &mut self, + hdr: &VhostUserMsgHeader<R>, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr. + let iovs = unsafe { + [slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + )] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header and body. Optional file descriptors may be attached to + /// the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_message<T: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + fds: Option<&[RawFd]>, + ) -> Result<()> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + // Safe because there can't be other mutable referance to hdr and body. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + ] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header, body and payload. Optional file descriptors + /// may also be attached to the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - OversizedMsg: message size is too big. + /// * - PartialMessage: received a partial message. + /// * - IncorrectFds: wrong number of attached fds. + pub fn send_message_with_payload<T: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + payload: &[u8], + fds: Option<&[RawFd]>, + ) -> Result<()> { + let len = payload.len(); + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + if len > MAX_MSG_SIZE - mem::size_of::<T>() { + return Err(Error::OversizedMsg); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(Error::IncorrectFds); + } + } + + // Safe because there can't be other mutable reference to hdr, body and payload. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + slice::from_raw_parts(payload.as_ptr() as *const u8, len), + ] + }; + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len; + let len = self.send_iovec_all(&iovs, fds)?; + if len != total { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Reads bytes from the socket into the given scatter/gather vectors. + /// + /// # Return: + /// * - (number of bytes received, buf) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { + let mut rbuf = vec![0u8; len]; + let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?; + Ok((bytes, rbuf)) + } + + /// Reads bytes from the socket into the given scatter/gather vectors with optional attached + /// file descriptors. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; + let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?; + let rfds = match fds { + 0 => None, + n => { + let mut fds = Vec::with_capacity(n); + fds.extend_from_slice(&fd_array[0..n]); + Some(fds) + } + }; + + Ok((bytes, rfds)) + } + + /// Reads all bytes from the socket into the given scatter/gather vectors with optional + /// attached file descriptors. Will loop until all data has been transfered. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_iovec_all( + &mut self, + iovs: &mut [iovec], + ) -> Result<(usize, Option<Vec<RawFd>>)> { + let mut data_read = 0; + let mut data_total = 0; + let mut rfds = None; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_read) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read); + let iov = &mut iovs[nr_skip]; + + let mut data = [ + &[iovec { + iov_base: (iov.iov_base as usize + offset) as *mut c_void, + iov_len: iov.iov_len - offset, + }], + &iovs[(nr_skip + 1)..], + ] + .concat(); + + let res = self.recv_into_iovec(&mut data); + match res { + Ok((0, _)) => return Ok((data_read, rfds)), + Ok((n, fds)) => { + if data_read == 0 { + rfds = fds; + } + data_read += n; + } + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok((data_read, rfds)) + } + + /// Reads bytes from the socket into a new buffer with optional attached + /// file descriptors. Received file descriptors are set close-on-exec. + /// + /// # Return: + /// * - (number of bytes received, buf, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_buf( + &mut self, + buf_size: usize, + ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + let mut buf = vec![0u8; buf_size]; + let (bytes, rfds) = { + let mut iovs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf_size, + }]; + self.recv_into_iovec(&mut iovs)? + }; + Ok((bytes, buf, rfds)) + } + + /// Receive a header-only message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, rfds)) + } + + /// Receive a message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes != total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, rfds)) + } + + /// Receive a message with header and optional content. Callers need to + /// pre-allocate a big enough buffer to receive the message body and + /// optional payload. If there are attached file descriptor associated + /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors + /// will be accepted and all other file descriptor will be discard + /// silently. + /// + /// # Return: + /// * - (message header, message size, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_body_into_buf( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + } + + /// Receive a message with optional payload and attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] + pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes < total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, bytes - total, rfds)) + } + + /// Close all raw file descriptors. + pub fn close_rfds(rfds: Option<Vec<RawFd>>) { + if let Some(fds) = rfds { + for fd in fds { + // safe because the rawfds are valid and we don't care about the result. + let _ = unsafe { libc::close(fd) }; + } + } + } +} + +impl<T: Req> AsRawFd for Endpoint<T> { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice. +// For example: +// let iov_lens = vec![4, 4, 5]; +// let size = 6; +// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2)); +fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { + let mut size = skip_size; + let mut nr_skip = 0; + + for len in iov_lens { + if size >= *len { + size -= *len; + nr_skip += 1; + } else { + break; + } + } + (nr_skip, size) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::File; + use std::io::{Read, Seek, SeekFrom, Write}; + use std::os::unix::io::FromRawFd; + use tempfile::{tempfile, Builder, TempDir}; + + fn temp_dir() -> TempDir { + Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + #[test] + fn create_listener() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); + } + + #[test] + fn accept_connection() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + // accept on a fd without incoming connection + let conn = listener.accept().unwrap(); + assert!(conn.is_none()); + } + + #[test] + fn send_data() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let mut len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..bytes]); + + len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + } + + #[test] + fn send_fd() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut fd = tempfile().unwrap(); + write!(fd, "test").unwrap(); + + // Normal case for sending/receiving file descriptors + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let len = master + .send_slice(&buf1[..], Some(&[fd.as_raw_fd()])) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 1); + let mut file = unsafe { File::from_raw_fd(fds[0]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + + // Following communication pattern should work: + // Sending side: data(header, body) with fds + // Receiving side: data(header) with fds, data(body) + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data(header, body) with fds + // Receiving side: data(header), data(body) with fds + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf4) = slave.recv_data(2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf4[..]); + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should work: + // Sending side: data, data with fds + // Receiving side: data, data with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_none()); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data1, data2 with fds + // Receiving side: data + partial of data2, left of data2 with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _) = slave.recv_data(5).unwrap(); + assert_eq!(bytes, 5); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 3); + assert!(rfds.is_none()); + + // If the target fd array is too small, extra file descriptors will get lost. + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert!(rfds.is_some()); + + Endpoint::<MasterReq>::close_rfds(rfds); + Endpoint::<MasterReq>::close_rfds(None); + } + + #[test] + fn send_recv() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut hdr1 = + VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32); + hdr1.set_need_reply(true); + let features1 = 0x1u64; + master.send_message(&hdr1, &features1, None).unwrap(); + + let mut features2 = 0u64; + let slice = unsafe { + slice::from_raw_parts_mut( + (&mut features2 as *mut u64) as *mut u8, + mem::size_of::<u64>(), + ) + }; + let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + assert_eq!(hdr1, hdr2); + assert_eq!(bytes, 8); + assert_eq!(features1, features2); + assert!(rfds.is_none()); + + master.send_header(&hdr1, None).unwrap(); + let (hdr2, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr1, hdr2); + assert!(rfds.is_none()); + } +} diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs new file mode 100644 index 0000000..b2b83d2 --- /dev/null +++ b/src/vhost_user/dummy_slave.rs @@ -0,0 +1,259 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::os::unix::io::RawFd; + +use super::message::*; +use super::*; + +pub const MAX_QUEUE_NUM: usize = 2; +pub const MAX_VRING_NUM: usize = 256; +pub const MAX_MEM_SLOTS: usize = 32; +pub const VIRTIO_FEATURES: u64 = 0x40000003; + +#[derive(Default)] +pub struct DummySlaveReqHandler { + pub owned: bool, + pub features_acked: bool, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub queue_num: usize, + pub vring_num: [u32; MAX_QUEUE_NUM], + pub vring_base: [u32; MAX_QUEUE_NUM], + pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub vring_started: [bool; MAX_QUEUE_NUM], + pub vring_enabled: [bool; MAX_QUEUE_NUM], +} + +impl DummySlaveReqHandler { + pub fn new() -> Self { + DummySlaveReqHandler { + queue_num: MAX_QUEUE_NUM, + ..Default::default() + } + } +} + +impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { + fn set_owner(&mut self) -> Result<()> { + if self.owned { + return Err(Error::InvalidOperation); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + self.owned = false; + self.features_acked = false; + self.acked_features = 0; + self.acked_protocol_features = 0; + Ok(()) + } + + fn get_features(&mut self) -> Result<u64> { + Ok(VIRTIO_FEATURES) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + if !self.owned || self.features_acked { + return Err(Error::InvalidOperation); + } else if (features & !VIRTIO_FEATURES) != 0 { + return Err(Error::InvalidParam); + } + + self.acked_features = features; + self.features_acked = true; + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, + // the ring is initialized in an enabled state. + // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, + // the ring is initialized in a disabled state. Client must not + // pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has + // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + let vring_enabled = + self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; + for enabled in &mut self.vring_enabled { + *enabled = vring_enabled; + } + + Ok(()) + } + + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { + Ok(()) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { + if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_num[index as usize] = num; + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + _descriptor: u64, + _used: u64, + _available: u64, + _log: u64, + ) -> Result<()> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { + if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_base[index as usize] = base; + Ok(()) + } + + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + self.vring_started[index as usize] = false; + Ok(VhostUserVringState::new( + index, + self.vring_base[index as usize], + )) + } + + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.kick_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) }; + } + self.kick_fd[index as usize] = fd; + + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + // + // So we should add fd to event monitor(select, poll, epoll) here. + self.vring_started[index as usize] = true; + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.call_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) }; + } + self.call_fd[index as usize] = fd; + Ok(()) + } + + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.err_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) }; + } + self.err_fd[index as usize] = fd; + Ok(()) + } + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? + self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + Ok(MAX_QUEUE_NUM as u64) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { + return Err(Error::InvalidOperation); + } else if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. + self.vring_enabled[index as usize] = enable; + Ok(()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + _flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>> { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + Ok(vec![0xa5; size as usize]) + } + + fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { + let size = buf.len() as u32; + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn get_max_mem_slots(&mut self) -> Result<u64> { + Ok(MAX_MEM_SLOTS as u64) + } + + fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> { + Ok(()) + } + + fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> { + Ok(()) + } +} diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs new file mode 100644 index 0000000..cc79871 --- /dev/null +++ b/src/vhost_user/master.rs @@ -0,0 +1,1071 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Struct for vhost-user master. + +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::sync::{Arc, Mutex, MutexGuard}; + +use sys_util::EventFd; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error as VhostUserError, Result as VhostUserResult}; +use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; +use crate::{Error, Result}; + +/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet. +pub trait VhostUserMaster: VhostBackend { + /// Get the protocol feature bitmask from the underlying vhost implementation. + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + + /// Enable protocol features in the underlying vhost implementation. + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>; + + /// Query how many queues the backend supports. + fn get_queue_num(&mut self) -> Result<u64>; + + /// Signal slave to enable or disable corresponding vring. + /// + /// Slave must not pass data to/from the backend until ring is enabled by + /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been + /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>; + + /// Fetch the contents of the virtio device configuration space. + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + buf: &[u8], + ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>; + + /// Change the virtio device configuration space. It also can be used for live migration on the + /// destination host to set readonly configuration space fields. + fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>; + + /// Setup slave communication channel. + fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>; + + /// Query the maximum amount of memory slots supported by the backend. + fn get_max_mem_slots(&mut self) -> Result<u64>; + + /// Add a new guest memory mapping for vhost to use. + fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; + + /// Remove a guest memory mapping from vhost. + fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; +} + +fn error_code<T>(err: VhostUserError) -> Result<T> { + Err(Error::VhostUserProtocol(err)) +} + +/// Struct for the vhost-user master endpoint. +#[derive(Clone)] +pub struct Master { + node: Arc<Mutex<MasterInternal>>, +} + +impl Master { + /// Create a new instance. + fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self { + Master { + node: Arc::new(Mutex::new(MasterInternal { + main_sock: ep, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: 0, + acked_protocol_features: 0, + protocol_features_ready: false, + max_queue_num, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard<MasterInternal> { + self.node.lock().unwrap() + } + + /// Create a new instance from a Unix stream socket. + pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { + Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num) + } + + /// Create a new vhost-user master endpoint. + /// + /// Will retry as the backend may not be ready to accept the connection. + /// + /// # Arguments + /// * `path` - path of Unix domain socket listener to connect to + pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> { + let mut retry_count = 5; + let endpoint = loop { + match Endpoint::<MasterReq>::connect(&path) { + Ok(endpoint) => break Ok(endpoint), + Err(e) => match &e { + VhostUserError::SocketConnect(why) => { + if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 { + std::thread::sleep(std::time::Duration::from_millis(100)); + retry_count -= 1; + continue; + } else { + break Err(e); + } + } + _ => break Err(e), + }, + } + }?; + + Ok(Self::new(endpoint, max_queue_num)) + } +} + +impl VhostBackend for Master { + /// Get from the underlying vhost implementation the feature bitmask. + fn get_features(&self) -> Result<u64> { + let mut node = self.node(); + let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.virtio_features = val.value; + Ok(node.virtio_features) + } + + /// Enable features in the underlying vhost implementation using a bitmask. + fn set_features(&self, features: u64) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(features); + let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + node.acked_virtio_features = features & node.virtio_features; + Ok(()) + } + + /// Set the current Master as an owner of the session. + fn set_owner(&self) -> Result<()> { + // We unwrap() the return value to assert that we are not expecting threads to ever fail + // while holding the lock. + let mut node = self.node(); + let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + Ok(()) + } + + fn reset_owner(&self) -> Result<()> { + let mut node = self.node(); + let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + Ok(()) + } + + /// Set the memory map regions on the slave so it can translate the vring + /// addresses. In the ancillary data there is an array of file descriptors + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { + return error_code(VhostUserError::InvalidParam); + } + + let mut ctx = VhostUserMemoryContext::new(); + for region in regions.iter() { + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + let reg = VhostUserMemoryRegion { + guest_phys_addr: region.guest_phys_addr, + memory_size: region.memory_size, + user_addr: region.userspace_addr, + mmap_offset: region.mmap_offset, + }; + ctx.append(®, region.mmap_handle); + } + + let mut node = self.node(); + let body = VhostUserMemory::new(ctx.regions.len() as u32); + let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() }; + let hdr = node.send_request_with_payload( + MasterReq::SET_MEM_TABLE, + &body, + payload, + Some(ctx.fds.as_slice()), + )?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + // Clippy doesn't seem to know that if let with && is still experimental + #[allow(clippy::unnecessary_unwrap)] + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(base); + + if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 + && fd.is_some() + { + let fds = [fd.unwrap()]; + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?; + } else { + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?; + } + Ok(()) + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + let mut node = self.node(); + let fds = [fd]; + node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; + Ok(()) + } + + /// Set the size of the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, num.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the addresses of the different aspects of the vring. + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num + || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the base offset in the available vring. + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, base.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let req = VhostUserVringState::new(queue_index as u32, 0); + let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?; + let reply = node.recv_reply::<VhostUserVringState>(&hdr)?; + Ok(reply.num) + } + + /// Set the event file descriptor to signal when buffers are used. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// will be used instead of waiting for the call. + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; + Ok(()) + } + + /// Set the event file descriptor for adding buffers to the vring. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// should be used instead of waiting for a kick. + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?; + Ok(()) + } + + /// Set the event file descriptor to signal when error occurs. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?; + Ok(()) + } +} + +impl VhostUserMaster for Master { + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + let mut node = self.node(); + let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { + return error_code(VhostUserError::InvalidOperation); + } + let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.protocol_features = val.value; + // Should we support forward compatibility? + // If so just mask out unrecognized flags instead of return errors. + match VhostUserProtocolFeatures::from_bits(node.protocol_features) { + Some(val) => Ok(val), + None => error_code(VhostUserError::InvalidMessage), + } + } + + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { + let mut node = self.node(); + let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { + return error_code(VhostUserError::InvalidOperation); + } + let val = VhostUserU64::new(features.bits()); + let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + node.acked_protocol_features = features.bits(); + node.protocol_features_ready = true; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + let mut node = self.node(); + if !node.is_feature_mq_available() { + return error_code(VhostUserError::InvalidOperation); + } + + let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + if val.value > VHOST_USER_MAX_VRINGS { + return error_code(VhostUserError::InvalidMessage); + } + node.max_queue_num = val.value; + Ok(node.max_queue_num) + } + + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> { + let mut node = self.node(); + // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled. + if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } else if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let flag = if enable { 1 } else { 0 }; + let val = VhostUserVringState::new(queue_index as u32, flag); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + buf: &[u8], + ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> { + let body = VhostUserConfig::new(offset, size, flags); + if !body.is_valid() { + return error_code(VhostUserError::InvalidParam); + } + + let mut node = self.node(); + // depends on VhostUserProtocolFeatures::CONFIG + if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } + + // vhost-user spec states that: + // "Master payload: virtio device config space" + // "Slave payload: virtio device config space" + let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?; + let (body_reply, buf_reply, rfds) = + node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?; + if rfds.is_some() { + Endpoint::<MasterReq>::close_rfds(rfds); + return error_code(VhostUserError::InvalidMessage); + } else if body_reply.size == 0 { + return error_code(VhostUserError::SlaveInternalError); + } else if body_reply.size != body.size + || body_reply.size as usize != buf.len() + || body_reply.offset != body.offset + { + return error_code(VhostUserError::InvalidMessage); + } + + Ok((body_reply, buf_reply)) + } + + fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> { + if buf.len() > MAX_MSG_SIZE { + return error_code(VhostUserError::InvalidParam); + } + let body = VhostUserConfig::new(offset, buf.len() as u32, flags); + if !body.is_valid() { + return error_code(VhostUserError::InvalidParam); + } + + let mut node = self.node(); + // depends on VhostUserProtocolFeatures::CONFIG + if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } + + let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { + return error_code(VhostUserError::InvalidOperation); + } + + let fds = [fd]; + node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; + Ok(()) + } + + fn get_max_mem_slots(&mut self) -> Result<u64> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 + { + return error_code(VhostUserError::InvalidOperation); + } + + let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + + Ok(val.value) + } + + fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 + { + return error_code(VhostUserError::InvalidOperation); + } + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + + let body = VhostUserSingleMemoryRegion::new( + region.guest_phys_addr, + region.memory_size, + region.userspace_addr, + region.mmap_offset, + ); + let fds = [region.mmap_handle]; + let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { + let mut node = self.node(); + if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 + { + return error_code(VhostUserError::InvalidOperation); + } + if region.memory_size == 0 { + return error_code(VhostUserError::InvalidParam); + } + + let body = VhostUserSingleMemoryRegion::new( + region.guest_phys_addr, + region.memory_size, + region.userspace_addr, + region.mmap_offset, + ); + let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } +} + +impl AsRawFd for Master { + fn as_raw_fd(&self) -> RawFd { + let node = self.node(); + node.main_sock.as_raw_fd() + } +} + +/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table(). +struct VhostUserMemoryContext { + regions: VhostUserMemoryPayload, + fds: Vec<RawFd>, +} + +impl VhostUserMemoryContext { + /// Create a context object. + pub fn new() -> Self { + VhostUserMemoryContext { + regions: VhostUserMemoryPayload::new(), + fds: Vec::new(), + } + } + + /// Append a user memory region and corresponding RawFd into the context object. + pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) { + self.regions.push(*region); + self.fds.push(fd); + } +} + +struct MasterInternal { + // Used to send requests to the slave. + main_sock: Endpoint<MasterReq>, + // Cached virtio features from the slave. + virtio_features: u64, + // Cached acked virtio features from the driver. + acked_virtio_features: u64, + // Cached vhost-user protocol features from the slave. + protocol_features: u64, + // Cached vhost-user protocol features. + acked_protocol_features: u64, + // Cached vhost-user protocol features are ready to use. + protocol_features_ready: bool, + // Cached maxinum number of queues supported from the slave. + max_queue_num: u64, + // Internal flag to mark failure state. + error: Option<i32>, +} + +impl MasterInternal { + fn send_request_header( + &mut self, + code: MasterReq, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + self.check_state()?; + let hdr = Self::new_request_header(code, 0); + self.main_sock.send_header(&hdr, fds)?; + Ok(hdr) + } + + fn send_request_with_body<T: Sized>( + &mut self, + code: MasterReq, + msg: &T, + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32); + self.main_sock.send_message(&hdr, msg, fds)?; + Ok(hdr) + } + + fn send_request_with_payload<T: Sized>( + &mut self, + code: MasterReq, + msg: &T, + payload: &[u8], + fds: Option<&[RawFd]>, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + let len = mem::size_of::<T>() + payload.len(); + if len > MAX_MSG_SIZE { + return Err(VhostUserError::InvalidParam); + } + if let Some(ref fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(VhostUserError::InvalidParam); + } + } + self.check_state()?; + + let hdr = Self::new_request_header(code, len as u32); + self.main_sock + .send_message_with_payload(&hdr, msg, payload, fds)?; + Ok(hdr) + } + + fn send_fd_for_vring( + &mut self, + code: MasterReq, + queue_index: usize, + fd: RawFd, + ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { + if queue_index as u64 >= self.max_queue_num { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. + // This flag is set when there is no file descriptor in the ancillary data. This signals + // that polling will be used instead of waiting for the call. + let msg = VhostUserU64::new(queue_index as u64); + let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32); + self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?; + Ok(hdr) + } + + fn recv_reply<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<T> { + if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<T>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } + Ok(body) + } + + fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> { + if mem::size_of::<T>() > MAX_MSG_SIZE + || hdr.get_size() as usize <= mem::size_of::<T>() + || hdr.get_size() as usize > MAX_MSG_SIZE + || hdr.is_reply() + { + return Err(VhostUserError::InvalidParam); + } + self.check_state()?; + + let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()]; + let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?; + if !reply.is_reply_for(hdr) + || reply.get_size() as usize != mem::size_of::<T>() + bytes + || rfds.is_some() + || !body.is_valid() + { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } else if bytes != buf.len() { + return Err(VhostUserError::InvalidMessage); + } + Ok((body, buf, rfds)) + } + + fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> { + if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0 + || !hdr.is_need_reply() + { + return Ok(()); + } + self.check_state()?; + + let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(VhostUserError::InvalidMessage); + } + if body.value != 0 { + return Err(VhostUserError::SlaveInternalError); + } + Ok(()) + } + + fn is_feature_mq_available(&self) -> bool { + self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 + } + + fn check_state(&self) -> VhostUserResult<()> { + match self.error { + Some(e) => Err(VhostUserError::SocketBroken( + std::io::Error::from_raw_os_error(e), + )), + None => Ok(()), + } + } + + #[inline] + fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> { + // TODO: handle NEED_REPLY flag + VhostUserMsgHeader::new(request, 0x1, size) + } +} + +#[cfg(test)] +mod tests { + use super::super::connection::Listener; + use super::*; + use tempfile::{Builder, TempDir}; + + fn temp_dir() -> TempDir { + Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) { + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let master = Master::connect(path, 2).unwrap(); + let slave = listener.accept().unwrap().unwrap(); + (master, Endpoint::from_stream(slave)) + } + + #[test] + fn create_master() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + let master = Master::connect(&path, 1).unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap()); + + assert!(master.as_raw_fd() > 0); + // Send two messages continuously + master.set_owner().unwrap(); + master.reset_owner().unwrap(); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + + let (hdr, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + } + + #[test] + fn test_create_failure() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let _ = Listener::new(&path, true).unwrap(); + let _ = Listener::new(&path, false).is_err(); + assert!(Master::connect(&path, 1).is_err()); + + let listener = Listener::new(&path, true).unwrap(); + assert!(Listener::new(&path, false).is_err()); + listener.set_nonblocking(true).unwrap(); + + let _master = Master::connect(&path, 1).unwrap(); + let _slave = listener.accept().unwrap().unwrap(); + } + + #[test] + fn test_features() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (master, mut peer) = create_pair(&path); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_size(), 0); + assert_eq!(hdr.get_version(), 0x1); + assert!(rfds.is_none()); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, 0x15u64); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); + master.set_features(0x15).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, 0x15); + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = 0x15u32; + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_features().is_err()); + } + + #[test] + fn test_protocol_features() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (mut master, mut peer) = create_pair(&path); + + master.set_owner().unwrap(); + let (hdr, rfds) = peer.recv_header().unwrap(); + assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert!(rfds.is_none()); + + assert!(master.get_protocol_features().is_err()); + assert!(master + .set_protocol_features(VhostUserProtocolFeatures::all()) + .is_err()); + + let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(vfeatures); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_features().unwrap(); + assert_eq!(features, vfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_features(vfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, vfeatures); + + let pfeatures = VhostUserProtocolFeatures::all(); + let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + let features = master.get_protocol_features().unwrap(); + assert_eq!(features, pfeatures); + let (_hdr, rfds) = peer.recv_header().unwrap(); + assert!(rfds.is_none()); + + master.set_protocol_features(pfeatures).unwrap(); + let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); + assert!(rfds.is_none()); + let val = msg.value; + assert_eq!(val, pfeatures.bits()); + + let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(pfeatures.bits()); + peer.send_message(&hdr, &msg, None).unwrap(); + assert!(master.get_protocol_features().is_err()); + } + + #[test] + fn test_master_set_config_negative() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (mut master, _peer) = create_pair(&path); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap(); + master + .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config( + 0x100, + unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) }, + &buf[0..4], + ) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[]) + .unwrap_err(); + } + + fn create_pair2() -> (Master, Endpoint<MasterReq>) { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (master, peer) = create_pair(&path); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + (master, peer) + } + + #[test] + fn test_master_get_config_negative0() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_code(MasterReq::GET_FEATURES); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + hdr.set_code(MasterReq::GET_CONFIG); + } + + #[test] + fn test_master_get_config_negative1() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_reply(false); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative2() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + } + + #[test] + fn test_master_get_config_negative3() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative4() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0x101; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative5() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = (MAX_MSG_SIZE + 1) as u32; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative6() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.size = 6; + peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_maset_set_mem_table_failure() { + let (master, _peer) = create_pair2(); + + master.set_mem_table(&[]).unwrap_err(); + let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; + master.set_mem_table(&tables).unwrap_err(); + } +} diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs new file mode 100644 index 0000000..8cba188 --- /dev/null +++ b/src/vhost_user/master_req_handler.rs @@ -0,0 +1,477 @@ +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result}; + +/// Define services provided by masters for the slave communication channel. +/// +/// The vhost-user specification defines a slave communication channel, by which slaves could +/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided +/// by masters, and it's used both on the master side and slave side. +/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy +/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. +/// - on the master side, the [MasterReqHandler] will forward service requests to a handler +/// implementing [VhostUserMasterReqHandler]. +/// +/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +pub trait VhostUserMasterReqHandler { + /// Handle device configuration change notifications. + fn handle_config_change(&self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub trait VhostUserMasterReqHandlerMut { + /// Handle device configuration change notifications. + fn handle_config_change(&mut self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { + fn handle_config_change(&self) -> HandlerResult<u64> { + self.lock().unwrap().handle_config_change() + } + + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_map(fs, fd) + } + + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_unmap(fs) + } + + fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_sync(fs) + } + + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_io(fs, fd) + } +} + +/// Server to handle service requests from slaves from the slave communication channel. +/// +/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from +/// slaves on the slave communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserMasterReqHandler] to do the real work. +/// +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { + // underlying Unix domain socket for communication + sub_sock: Endpoint<SlaveReq>, + tx_sock: UnixStream, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, + // the VirtIO backend device object + backend: Arc<S>, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { + /// Create a server to handle service requests from slaves on the slave communication channel. + /// + /// This opens a pair of connected anonymous sockets to form the slave communication channel. + /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. + /// + /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn new(backend: Arc<S>) -> Result<Self> { + let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; + + Ok(MasterReqHandler { + sub_sock: Endpoint::<SlaveReq>::from_stream(rx), + tx_sock: tx, + reply_ack_negotiated: false, + backend, + error: None, + }) + } + + /// Get the socket fd for the slave to communication with the master. + /// + /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn get_tx_raw_fd(&self) -> RawFd { + self.tx_sock.as_raw_fd() + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&mut self, enable: bool) { + self.reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed or in normal state. + pub fn set_failed(&mut self, error: i32) { + if error == 0 { + self.error = None; + } else { + self.error = Some(error); + } + } + + /// Main entrance to server slave request from the slave communication channel. + /// + /// The caller needs to: + /// - serialize calls to this function + /// - decide what to do when errer happens + /// - optional recover from failure + pub fn handle_request(&mut self) -> Result<u64> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . validate message body and optional payload + let (hdr, rfds) = self.sub_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + if len as usize > MAX_MSG_SIZE { + return Err(Error::InvalidMessage); + } + let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + let res = match hdr.get_code() { + SlaveReq::CONFIG_CHANGE_MSG => { + self.check_msg_size(&hdr, size, 0)?; + self.backend + .handle_config_change() + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_MAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds + self.backend + .fs_slave_map(&msg, rfds.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_UNMAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .fs_slave_unmap(&msg) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_SYNC => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .fs_slave_sync(&msg) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_IO => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds + self.backend + .fs_slave_io(&msg, rfds.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + _ => Err(Error::InvalidMessage), + }; + + self.send_ack_message(&hdr, &res)?; + + res + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_msg_size( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + SlaveReq::FS_MAP | SlaveReq::FS_IO => { + // Expect an fd set with a single fd. + match rfds { + None => Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != 1 { + Endpoint::<SlaveReq>::close_rfds(Some(fds)); + Err(Error::InvalidMessage) + } else { + Ok(Some(fds)) + } + } + } + } + _ => { + if rfds.is_some() { + Endpoint::<SlaveReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + fn extract_msg_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_msg_size(hdr, size, mem::size_of::<T>())?; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<SlaveReq>, + ) -> Result<VhostUserMsgHeader<SlaveReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + mem::size_of::<T>() as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<SlaveReq>, + res: &Result<u64>, + ) -> Result<()> { + if self.reply_ack_negotiated && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req)?; + let def_err = libc::EINVAL; + let val = match res { + Ok(n) => *n, + Err(e) => match &*e { + Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { + Some(rawerr) => -rawerr as u64, + None => -def_err as u64, + }, + _ => -def_err as u64, + }, + }; + let msg = VhostUserU64::new(val); + self.sub_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } +} + +impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.sub_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "vhost-user-slave")] + use crate::vhost_user::SlaveFsCacheReq; + #[cfg(feature = "vhost-user-slave")] + use std::os::unix::io::FromRawFd; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + assert!(handler.get_tx_raw_fd() >= 0); + assert!(handler.as_raw_fd() >= 0); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + handler.set_reply_ack_flag(true); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs new file mode 100644 index 0000000..ea2df4e --- /dev/null +++ b/src/vhost_user/message.rs @@ -0,0 +1,1042 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Define communication messages for the vhost-user protocol. +//! +//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt). + +#![allow(dead_code)] +#![allow(non_camel_case_types)] + +use std::fmt::Debug; +use std::marker::PhantomData; + +use crate::VringConfigData; + +/// The vhost-user specification uses a field of u32 to store message length. +/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain +/// socket. To preallocating a 4GB buffer for each vhost-user message is really just an overhead. +/// Among all defined vhost-user messages, only the VhostUserConfig and VhostUserMemory has variable +/// message size. For the VhostUserConfig, a maximum size of 4K is enough because the user +/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory, +/// 4K should be enough too because it can support 255 memory regions at most. +pub const MAX_MSG_SIZE: usize = 0x1000; + +/// The VhostUserMemory message has variable message size and variable number of attached file +/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, +/// so setting maximum number of attached file descriptors based on the maximum message size. +/// But rust only implements Default and AsMut traits for arrays with 0 - 32 entries, so further +/// reduce the maximum number... +// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32; +pub const MAX_ATTACHED_FD_ENTRIES: usize = 32; + +/// Starting position (inclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100; + +/// Ending position (exclusion) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000; + +/// Maximum number of vrings supported. +pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64; + +pub(super) trait Req: + Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32> +{ + fn is_valid(&self) -> bool; +} + +/// Type of requests sending from masters to slaves. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum MasterReq { + /// Null operation. + NOOP = 0, + /// Get from the underlying vhost implementation the features bit mask. + GET_FEATURES = 1, + /// Enable features in the underlying vhost implementation using a bit mask. + SET_FEATURES = 2, + /// Set the current Master as an owner of the session. + SET_OWNER = 3, + /// No longer used. + RESET_OWNER = 4, + /// Set the memory map regions on the slave so it can translate the vring addresses. + SET_MEM_TABLE = 5, + /// Set logging shared memory space. + SET_LOG_BASE = 6, + /// Set the logging file descriptor, which is passed as ancillary data. + SET_LOG_FD = 7, + /// Set the size of the queue. + SET_VRING_NUM = 8, + /// Set the addresses of the different aspects of the vring. + SET_VRING_ADDR = 9, + /// Set the base offset in the available vring. + SET_VRING_BASE = 10, + /// Get the available vring base offset. + GET_VRING_BASE = 11, + /// Set the event file descriptor for adding buffers to the vring. + SET_VRING_KICK = 12, + /// Set the event file descriptor to signal when buffers are used. + SET_VRING_CALL = 13, + /// Set the event file descriptor to signal when error occurs. + SET_VRING_ERR = 14, + /// Get the protocol feature bit mask from the underlying vhost implementation. + GET_PROTOCOL_FEATURES = 15, + /// Enable protocol features in the underlying vhost implementation. + SET_PROTOCOL_FEATURES = 16, + /// Query how many queues the backend supports. + GET_QUEUE_NUM = 17, + /// Signal slave to enable or disable corresponding vring. + SET_VRING_ENABLE = 18, + /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated + /// for guest that does not support GUEST_ANNOUNCE. + SEND_RARP = 19, + /// Set host MTU value exposed to the guest. + NET_SET_MTU = 20, + /// Set the socket file descriptor for slave initiated requests. + SET_SLAVE_REQ_FD = 21, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 22, + /// Set the endianness of a VQ for legacy devices. + SET_VRING_ENDIAN = 23, + /// Fetch the contents of the virtio device configuration space. + GET_CONFIG = 24, + /// Change the contents of the virtio device configuration space. + SET_CONFIG = 25, + /// Create a session for crypto operation. + CREATE_CRYPTO_SESSION = 26, + /// Close a session for crypto operation. + CLOSE_CRYPTO_SESSION = 27, + /// Advise slave that a migration with postcopy enabled is underway. + POSTCOPY_ADVISE = 28, + /// Advise slave that a transition to postcopy mode has happened. + POSTCOPY_LISTEN = 29, + /// Advise that postcopy migration has now completed. + POSTCOPY_END = 30, + /// Get a shared buffer from slave. + GET_INFLIGHT_FD = 31, + /// Send the shared inflight buffer back to slave. + SET_INFLIGHT_FD = 32, + /// Sets the GPU protocol socket file descriptor. + GPU_SET_SOCKET = 33, + /// Ask the vhost user backend to disable all rings and reset all internal + /// device state to the initial state. + RESET_DEVICE = 34, + /// Indicate that a buffer was added to the vring instead of signalling it + /// using the vring’s kick file descriptor. + VRING_KICK = 35, + /// Return a u64 payload containing the maximum number of memory slots. + GET_MAX_MEM_SLOTS = 36, + /// Update the memory tables by adding the region described. + ADD_MEM_REG = 37, + /// Update the memory tables by removing the region described. + REM_MEM_REG = 38, + /// Notify the backend with updated device status as defined in the VIRTIO + /// specification. + SET_STATUS = 39, + /// Query the backend for its device status as defined in the VIRTIO + /// specification. + GET_STATUS = 40, + /// Upper bound of valid commands. + MAX_CMD = 41, +} + +impl Into<u32> for MasterReq { + fn into(self) -> u32 { + self as u32 + } +} + +impl Req for MasterReq { + fn is_valid(&self) -> bool { + (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + } +} + +/// Type of requests sending from slaves to masters. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum SlaveReq { + /// Null operation. + NOOP = 0, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 1, + /// Notify that the virtio device's configuration space has changed. + CONFIG_CHANGE_MSG = 2, + /// Set host notifier for a specified queue. + VRING_HOST_NOTIFIER_MSG = 3, + /// Indicate that a buffer was used from the vring. + VRING_CALL = 4, + /// Indicate that an error occurred on the specific vring. + VRING_ERR = 5, + /// Virtio-fs draft: map file content into the window. + FS_MAP = 6, + /// Virtio-fs draft: unmap file content from the window. + FS_UNMAP = 7, + /// Virtio-fs draft: sync file content. + FS_SYNC = 8, + /// Virtio-fs draft: perform a read/write from an fd directly to GPA. + FS_IO = 9, + /// Upper bound of valid commands. + MAX_CMD = 10, +} + +impl Into<u32> for SlaveReq { + fn into(self) -> u32 { + self as u32 + } +} + +impl Req for SlaveReq { + fn is_valid(&self) -> bool { + (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD) + } +} + +/// Vhost message Validator. +pub trait VhostUserMsgValidator { + /// Validate message syntax only. + /// It doesn't validate message semantics such as protocol version number and dependency + /// on feature flags etc. + fn is_valid(&self) -> bool { + true + } +} + +// Bit mask for common message flags. +bitflags! { + /// Common message flags for vhost-user requests and replies. + pub struct VhostUserHeaderFlag: u32 { + /// Bits[0..2] is message version number. + const VERSION = 0x3; + /// Mark message as reply. + const REPLY = 0x4; + /// Sender anticipates a reply message from the peer. + const NEED_REPLY = 0x8; + /// All valid bits. + const ALL_FLAGS = 0xc; + /// All reserved bits. + const RESERVED_BITS = !0xf; + } +} + +/// Common message header for vhost-user requests and replies. +/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the +/// machine native byte order. +#[allow(safe_packed_borrows)] +#[repr(packed)] +#[derive(Debug, Clone, Copy, PartialEq)] +pub(super) struct VhostUserMsgHeader<R: Req> { + request: u32, + flags: u32, + size: u32, + _r: PhantomData<R>, +} + +impl<R: Req> VhostUserMsgHeader<R> { + /// Create a new instance of `VhostUserMsgHeader`. + pub fn new(request: R, flags: u32, size: u32) -> Self { + // Default to protocol version 1 + let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1; + VhostUserMsgHeader { + request: request.into(), + flags: fl, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> R { + // It's safe because R is marked as repr(u32). + unsafe { std::mem::transmute_copy::<u32, R>(&self.request) } + } + + /// Set message type. + pub fn set_code(&mut self, request: R) { + self.request = request.into(); + } + + /// Get message version number. + pub fn get_version(&self) -> u32 { + self.flags & 0x3 + } + + /// Set message version number. + pub fn set_version(&mut self, ver: u32) { + self.flags &= !0x3; + self.flags |= ver & 0x3; + } + + /// Check whether it's a reply message. + pub fn is_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0 + } + + /// Mark message as reply. + pub fn set_reply(&mut self, is_reply: bool) { + if is_reply { + self.flags |= VhostUserHeaderFlag::REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::REPLY.bits(); + } + } + + /// Check whether reply for this message is requested. + pub fn is_need_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0 + } + + /// Mark that reply for this message is needed. + pub fn set_need_reply(&mut self, need_reply: bool) { + if need_reply { + self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits(); + } + } + + /// Check whether it's the reply message for the request `req`. + pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool { + self.is_reply() && !req.is_reply() && self.get_code() == req.get_code() + } + + /// Get message size. + pub fn get_size(&self) -> u32 { + self.size + } + + /// Set message size. + pub fn set_size(&mut self, size: u32) { + self.size = size; + } +} + +impl<R: Req> Default for VhostUserMsgHeader<R> { + fn default() -> Self { + VhostUserMsgHeader { + request: 0, + flags: 0x1, + size: 0, + _r: PhantomData, + } + } +} + +impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if !self.get_code().is_valid() { + return false; + } else if self.size as usize > MAX_MSG_SIZE { + return false; + } else if self.get_version() != 0x1 { + return false; + } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 { + return false; + } + true + } +} + +// Bit mask for transport specific flags in VirtIO feature set defined by vhost-user. +bitflags! { + /// Transport specific flags in VirtIO feature set defined by vhost-user. + pub struct VhostUserVirtioFeatures: u64 { + /// Feature flag for the protocol feature. + const PROTOCOL_FEATURES = 0x4000_0000; + } +} + +// Bit mask for vhost-user protocol feature flags. +bitflags! { + /// Vhost-user protocol feature flags. + pub struct VhostUserProtocolFeatures: u64 { + /// Support multiple queues. + const MQ = 0x0000_0001; + /// Support logging through shared memory fd. + const LOG_SHMFD = 0x0000_0002; + /// Support broadcasting fake RARP packet. + const RARP = 0x0000_0004; + /// Support sending reply messages for requests with NEED_REPLY flag set. + const REPLY_ACK = 0x0000_0008; + /// Support setting MTU for virtio-net devices. + const MTU = 0x0000_0010; + /// Allow the slave to send requests to the master by an optional communication channel. + const SLAVE_REQ = 0x0000_0020; + /// Support setting slave endian by SET_VRING_ENDIAN. + const CROSS_ENDIAN = 0x0000_0040; + /// Support crypto operations. + const CRYPTO_SESSION = 0x0000_0080; + /// Support sending userfault_fd from slaves to masters. + const PAGEFAULT = 0x0000_0100; + /// Support Virtio device configuration. + const CONFIG = 0x0000_0200; + /// Allow the slave to send fds (at most 8 descriptors in each message) to the master. + const SLAVE_SEND_FD = 0x0000_0400; + /// Allow the slave to register a host notifier. + const HOST_NOTIFIER = 0x0000_0800; + /// Support inflight shmfd. + const INFLIGHT_SHMFD = 0x0000_1000; + /// Support resetting the device. + const RESET_DEVICE = 0x0000_2000; + /// Support inband notifications. + const INBAND_NOTIFICATIONS = 0x0000_4000; + /// Support configuring memory slots. + const CONFIGURE_MEM_SLOTS = 0x0000_8000; + /// Support reporting status. + const STATUS = 0x0001_0000; + } +} + +/// A generic message to encapsulate a 64-bit value. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserU64 { + /// The encapsulated 64-bit common value. + pub value: u64, +} + +impl VhostUserU64 { + /// Create a new instance. + pub fn new(value: u64) -> Self { + VhostUserU64 { value } + } +} + +impl VhostUserMsgValidator for VhostUserU64 {} + +/// Memory region descriptor for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserMemory { + /// Number of memory regions in the payload. + pub num_regions: u32, + /// Padding for alignment. + pub padding1: u32, +} + +impl VhostUserMemory { + /// Create a new instance. + pub fn new(cnt: u32) -> Self { + VhostUserMemory { + num_regions: cnt, + padding1: 0, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemory { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if self.padding1 != 0 { + return false; + } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 { + return false; + } + true + } +} + +/// Memory region descriptors as payload for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserMemoryRegion { + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, +} + +impl VhostUserMemoryRegion { + /// Create a new instance. + pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserMemoryRegion { + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemoryRegion { + fn is_valid(&self) -> bool { + if self.memory_size == 0 + || self.guest_phys_addr.checked_add(self.memory_size).is_none() + || self.user_addr.checked_add(self.memory_size).is_none() + || self.mmap_offset.checked_add(self.memory_size).is_none() + { + return false; + } + true + } +} + +/// Payload of the VhostUserMemory message. +pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>; + +/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG +/// requests. +#[repr(C)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserSingleMemoryRegion { + /// Padding for correct alignment + padding: u64, + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, +} + +impl VhostUserSingleMemoryRegion { + /// Create a new instance. + pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserSingleMemoryRegion { + padding: 0, + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + } + } +} + +impl VhostUserMsgValidator for VhostUserSingleMemoryRegion { + fn is_valid(&self) -> bool { + if self.memory_size == 0 + || self.guest_phys_addr.checked_add(self.memory_size).is_none() + || self.user_addr.checked_add(self.memory_size).is_none() + || self.mmap_offset.checked_add(self.memory_size).is_none() + { + return false; + } + true + } +} + +/// Vring state descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringState { + /// Vring index. + pub index: u32, + /// A common 32bit value to encapsulate vring state etc. + pub num: u32, +} + +impl VhostUserVringState { + /// Create a new instance. + pub fn new(index: u32, num: u32) -> Self { + VhostUserVringState { index, num } + } +} + +impl VhostUserMsgValidator for VhostUserVringState {} + +// Bit mask for vring address flags. +bitflags! { + /// Flags for vring address. + pub struct VhostUserVringAddrFlags: u32 { + /// Support log of vring operations. + /// Modifications to "used" vring should be logged. + const VHOST_VRING_F_LOG = 0x1; + } +} + +/// Vring address descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringAddr { + /// Vring index. + pub index: u32, + /// Vring flags defined by VhostUserVringAddrFlags. + pub flags: u32, + /// Ring address of the vring descriptor table. + pub descriptor: u64, + /// Ring address of the vring used ring. + pub used: u64, + /// Ring address of the vring available ring. + pub available: u64, + /// Guest address for logging. + pub log: u64, +} + +impl VhostUserVringAddr { + /// Create a new instance. + pub fn new( + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Self { + VhostUserVringAddr { + index, + flags: flags.bits(), + descriptor, + used, + available, + log, + } + } + + /// Create a new instance from `VringConfigData`. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))] + pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self { + let log_addr = config_data.log_addr.unwrap_or(0); + VhostUserVringAddr { + index, + flags: config_data.flags, + descriptor: config_data.desc_table_addr, + used: config_data.used_ring_addr, + available: config_data.avail_ring_addr, + log: log_addr, + } + } +} + +impl VhostUserMsgValidator for VhostUserVringAddr { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 { + return false; + } else if self.descriptor & 0xf != 0 { + return false; + } else if self.available & 0x1 != 0 { + return false; + } else if self.used & 0x3 != 0 { + return false; + } + true + } +} + +// Bit mask for the vhost-user device configuration message. +bitflags! { + /// Flags for the device configuration message. + pub struct VhostUserConfigFlags: u32 { + /// Vhost master messages used for writeable fields. + const WRITABLE = 0x1; + /// Vhost master messages used for live migration. + const LIVE_MIGRATION = 0x2; + } +} + +/// Message to read/write device configuration space. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserConfig { + /// Offset of virtio device's configuration space. + pub offset: u32, + /// Configuration space access size in bytes. + pub size: u32, + /// Flags for the device configuration operation. + pub flags: u32, +} + +impl VhostUserConfig { + /// Create a new instance. + pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self { + VhostUserConfig { + offset, + size, + flags: flags.bits(), + } + } +} + +impl VhostUserMsgValidator for VhostUserConfig { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { + return false; + } else if self.offset < 0x100 { + return false; + } else if self.size == 0 + || self.size > VHOST_USER_CONFIG_SIZE + || self.size + self.offset > VHOST_USER_CONFIG_SIZE + { + return false; + } + true + } +} + +/// Payload for the VhostUserConfig message. +pub type VhostUserConfigPayload = Vec<u8>; + +/* + * TODO: support dirty log, live migration and IOTLB operations. +#[repr(packed)] +pub struct VhostUserVringArea { + pub index: u32, + pub flags: u32, + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserLog { + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserIotlb { + pub iova: u64, + pub size: u64, + pub user_addr: u64, + pub permission: u8, + pub optype: u8, +} +*/ + +// Bit mask for flags in virtio-fs slave messages +bitflags! { + #[derive(Default)] + /// Flags for virtio-fs slave messages. + pub struct VhostUserFSSlaveMsgFlags: u64 { + /// Empty permission. + const EMPTY = 0x0; + /// Read permission. + const MAP_R = 0x1; + /// Write permission. + const MAP_W = 0x2; + } +} + +/// Max entries in one virtio-fs slave request. +pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; + +/// Slave request message to update the MMIO window. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserFSSlaveMsg { + /// File offset. + pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Offset into the DAX window. + pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Size of region to map. + pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Flags for the mmap operation + pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES], +} + +impl VhostUserMsgValidator for VhostUserFSSlaveMsg { + fn is_valid(&self) -> bool { + for i in 0..VHOST_USER_FS_SLAVE_ENTRIES { + if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0 + || self.fd_offset[i].checked_add(self.len[i]).is_none() + || self.cache_offset[i].checked_add(self.len[i]).is_none() + { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn check_master_request_code() { + let code = MasterReq::NOOP; + assert!(!code.is_valid()); + let code = MasterReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > MasterReq::NOOP); + let code = MasterReq::GET_FEATURES; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn check_slave_request_code() { + let code = SlaveReq::NOOP; + assert!(!code.is_valid()); + let code = SlaveReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > SlaveReq::NOOP); + let code = SlaveReq::CONFIG_CHANGE_MSG; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn msg_header_ops() { + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); + assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + hdr.set_code(MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + + assert_eq!(hdr.get_version(), 0x1); + + assert_eq!(hdr.is_reply(), false); + hdr.set_reply(true); + assert_eq!(hdr.is_reply(), true); + hdr.set_reply(false); + + assert_eq!(hdr.is_need_reply(), false); + hdr.set_need_reply(true); + assert_eq!(hdr.is_need_reply(), true); + hdr.set_need_reply(false); + + assert_eq!(hdr.get_size(), 0x100); + hdr.set_size(0x200); + assert_eq!(hdr.get_size(), 0x200); + + assert_eq!(hdr.is_need_reply(), false); + assert_eq!(hdr.is_reply(), false); + assert_eq!(hdr.get_version(), 0x1); + + // Check message length + assert!(hdr.is_valid()); + hdr.set_size(0x2000); + assert!(!hdr.is_valid()); + hdr.set_size(0x100); + assert_eq!(hdr.get_size(), 0x100); + assert!(hdr.is_valid()); + hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32); + assert!(hdr.is_valid()); + hdr.set_size(0x0); + assert!(hdr.is_valid()); + + // Check version + hdr.set_version(0x0); + assert!(!hdr.is_valid()); + hdr.set_version(0x2); + assert!(!hdr.is_valid()); + hdr.set_version(0x1); + assert!(hdr.is_valid()); + + assert_eq!(hdr, hdr.clone()); + } + + #[test] + fn test_vhost_user_message_u64() { + let val = VhostUserU64::default(); + let val1 = VhostUserU64::new(0); + + let a = val.value; + let b = val1.value; + assert_eq!(a, b); + let a = VhostUserU64::new(1).value; + assert_eq!(a, 1); + } + + #[test] + fn check_user_memory() { + let mut msg = VhostUserMemory::new(1); + assert!(msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + assert!(msg.is_valid()); + + msg.num_regions += 1; + assert!(!msg.is_valid()); + msg.num_regions = 0xFFFFFFFF; + assert!(!msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + msg.padding1 = 1; + assert!(!msg.is_valid()); + } + + #[test] + fn check_user_memory_region() { + let mut msg = VhostUserMemoryRegion { + guest_phys_addr: 0, + memory_size: 0x1000, + user_addr: 0, + mmap_offset: 0, + }; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFF000; + assert!(!msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; + msg.memory_size = 0; + assert!(!msg.is_valid()); + let a = msg.guest_phys_addr; + let b = msg.guest_phys_addr; + assert_eq!(a, b); + + let msg = VhostUserMemoryRegion::default(); + let a = msg.guest_phys_addr; + assert_eq!(a, 0); + let a = msg.memory_size; + assert_eq!(a, 0); + let a = msg.user_addr; + assert_eq!(a, 0); + let a = msg.mmap_offset; + assert_eq!(a, 0); + } + + #[test] + fn test_vhost_user_state() { + let state = VhostUserVringState::new(5, 8); + + let a = state.index; + assert_eq!(a, 5); + let a = state.num; + assert_eq!(a, 8); + assert_eq!(state.is_valid(), true); + + let state = VhostUserVringState::default(); + let a = state.index; + assert_eq!(a, 0); + let a = state.num; + assert_eq!(a, 0); + assert_eq!(state.is_valid(), true); + } + + #[test] + fn test_vhost_user_addr() { + let mut addr = VhostUserVringAddr::new( + 2, + VhostUserVringAddrFlags::VHOST_VRING_F_LOG, + 0x1000, + 0x2000, + 0x3000, + 0x4000, + ); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + + addr.descriptor = 0x1001; + assert_eq!(addr.is_valid(), false); + addr.descriptor = 0x1000; + + addr.available = 0x3001; + assert_eq!(addr.is_valid(), false); + addr.available = 0x3000; + + addr.used = 0x2001; + assert_eq!(addr.is_valid(), false); + addr.used = 0x2000; + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn test_vhost_user_state_from_config() { + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + let addr = VhostUserVringAddr::from_config_data(2, &config); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn check_user_vring_addr() { + let mut msg = + VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0); + assert!(msg.is_valid()); + + msg.descriptor = 1; + assert!(!msg.is_valid()); + msg.descriptor = 0; + + msg.available = 1; + assert!(!msg.is_valid()); + msg.available = 0; + + msg.used = 1; + assert!(!msg.is_valid()); + msg.used = 0; + + msg.flags |= 0x80000000; + assert!(!msg.is_valid()); + msg.flags &= !0x80000000; + } + + #[test] + fn check_user_config_msg() { + let mut msg = VhostUserConfig::new( + VHOST_USER_CONFIG_OFFSET, + VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET, + VhostUserConfigFlags::WRITABLE, + ); + + assert!(msg.is_valid()); + msg.size = 0; + assert!(!msg.is_valid()); + msg.size = 1; + assert!(msg.is_valid()); + msg.offset = 0; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE - 1; + assert!(msg.is_valid()); + msg.size = 2; + assert!(!msg.is_valid()); + msg.size = 1; + msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits(); + assert!(msg.is_valid()); + msg.flags |= 0x4; + assert!(!msg.is_valid()); + } + + #[test] + fn test_vhost_user_fs_slave() { + let mut fs_slave = VhostUserFSSlaveMsg::default(); + + assert_eq!(fs_slave.is_valid(), true); + + fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff; + fs_slave.len[0] = 0x1; + assert_eq!(fs_slave.is_valid(), false); + + assert_ne!( + VhostUserFSSlaveMsgFlags::MAP_R, + VhostUserFSSlaveMsgFlags::MAP_W + ); + assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0); + } +} diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs new file mode 100644 index 0000000..9ef6453 --- /dev/null +++ b/src/vhost_user/mod.rs @@ -0,0 +1,456 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux +//! Kernel. The protocol defines two sides of the communication, master and slave. Master is +//! the application that shares its virtqueues. Slave is the consumer of the virtqueues. +//! +//! The communication channel between the master and the slave includes two sub channels. One is +//! used to send requests from the master to the slave and optional replies from the slave to the +//! master. This sub channel is created on master startup by connecting to the slave service +//! endpoint. The other is used to send requests from the slave to the master and optional replies +//! from the master to the slave. This sub channel is created by the master issuing a +//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor. +//! +//! Unix domain socket is used as the underlying communication channel because the master needs to +//! send file descriptors to the slave. +//! +//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an +//! equivalent ioctl to the kernel implementation. + +use std::io::Error as IOError; + +pub mod message; + +mod connection; +pub use self::connection::Listener; + +#[cfg(feature = "vhost-user-master")] +mod master; +#[cfg(feature = "vhost-user-master")] +pub use self::master::{Master, VhostUserMaster}; +#[cfg(feature = "vhost-user")] +mod master_req_handler; +#[cfg(feature = "vhost-user")] +pub use self::master_req_handler::{ + MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut, +}; + +#[cfg(feature = "vhost-user-slave")] +mod slave; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave::SlaveListener; +#[cfg(feature = "vhost-user-slave")] +mod slave_req_handler; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_req_handler::{ + SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut, +}; +#[cfg(feature = "vhost-user-slave")] +mod slave_fs_cache; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_fs_cache::SlaveFsCacheReq; + +/// Errors for vhost-user operations +#[derive(Debug)] +pub enum Error { + /// Invalid parameters. + InvalidParam, + /// Unsupported operations due to that the protocol feature hasn't been negotiated. + InvalidOperation, + /// Invalid message format, flag or content. + InvalidMessage, + /// Only part of a message have been sent or received successfully + PartialMessage, + /// Message is too large + OversizedMsg, + /// Fd array in question is too big or too small + IncorrectFds, + /// Can't connect to peer. + SocketConnect(std::io::Error), + /// Generic socket errors. + SocketError(std::io::Error), + /// The socket is broken or has been closed. + SocketBroken(std::io::Error), + /// Should retry the socket operation again. + SocketRetry(std::io::Error), + /// Failure from the slave side. + SlaveInternalError, + /// Failure from the master side. + MasterInternalError, + /// Virtio/protocol features mismatch. + FeatureMismatch, + /// Error from request handler + ReqHandlerError(IOError), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::InvalidParam => write!(f, "invalid parameters"), + Error::InvalidOperation => write!(f, "invalid operation"), + Error::InvalidMessage => write!(f, "invalid message"), + Error::PartialMessage => write!(f, "partial message"), + Error::OversizedMsg => write!(f, "oversized message"), + Error::IncorrectFds => write!(f, "wrong number of attached fds"), + Error::SocketError(e) => write!(f, "socket error: {}", e), + Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e), + Error::SocketBroken(e) => write!(f, "socket is broken: {}", e), + Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e), + Error::SlaveInternalError => write!(f, "slave internal error"), + Error::MasterInternalError => write!(f, "Master internal error"), + Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"), + Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e), + } + } +} + +impl std::error::Error for Error {} + +impl Error { + /// Determine whether to rebuild the underline communication channel. + pub fn should_reconnect(&self) -> bool { + match *self { + // Should reconnect because it may be caused by temporary network errors. + Error::PartialMessage => true, + // Should reconnect because the underline socket is broken. + Error::SocketBroken(_) => true, + // Slave internal error, hope it recovers on reconnect. + Error::SlaveInternalError => true, + // Master internal error, hope it recovers on reconnect. + Error::MasterInternalError => true, + // Should just retry the IO operation instead of rebuilding the underline connection. + Error::SocketRetry(_) => false, + Error::InvalidParam | Error::InvalidOperation => false, + Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false, + Error::SocketError(_) | Error::SocketConnect(_) => false, + Error::FeatureMismatch => false, + Error::ReqHandlerError(_) => false, + } + } +} + +impl std::convert::From<sys_util::Error> for Error { + /// Convert raw socket errors into meaningful vhost-user errors. + /// + /// The sys_util::Error is a simple wrapper over the raw errno, which doesn't means + /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify + /// the connection manager logic. + /// + /// # Return: + /// * - Error::SocketRetry: temporary error caused by signals or short of resources. + /// * - Error::SocketBroken: the underline socket is broken. + /// * - Error::SocketError: other socket related errors. + #[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux + fn from(err: sys_util::Error) -> Self { + match err.errno() { + // The socket is marked nonblocking and the requested operation would block. + libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)), + // The socket is marked nonblocking and the requested operation would block. + libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)), + // A signal occurred before any data was transmitted + libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)), + // The output queue for a network interface was full. This generally indicates + // that the interface has stopped sending, but may be caused by transient congestion. + libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)), + // No memory available. + libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)), + // Connection reset by peer. + libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)), + // The local end has been shut down on a connection oriented socket. In this case the + // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set. + libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)), + // Write permission is denied on the destination socket file, or search permission is + // denied for one of the directories the path prefix. + libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)), + // Catch all other errors + e => Error::SocketError(IOError::from_raw_os_error(e)), + } + } +} + +/// Result of vhost-user operations +pub type Result<T> = std::result::Result<T, Error>; + +/// Result of request handler. +pub type HandlerResult<T> = std::result::Result<T, IOError>; + +#[cfg(all(test, feature = "vhost-user-slave"))] +mod dummy_slave; + +#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +mod tests { + use std::os::unix::io::AsRawFd; + use std::path::Path; + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + + use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES}; + use super::message::*; + use super::*; + use crate::backend::VhostBackend; + use crate::{VhostUserMemoryRegionInfo, VringConfigData}; + use tempfile::{tempfile, Builder, TempDir}; + + fn temp_dir() -> TempDir { + Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>) + where + P: AsRef<Path>, + S: VhostUserSlaveReqHandler, + { + let listener = Listener::new(&path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + let master = Master::connect(&path, 1).unwrap(); + (master, slave_listener.accept().unwrap().unwrap()) + } + + #[test] + fn create_dummy_slave() { + let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + + slave.set_owner().unwrap(); + assert!(slave.set_owner().is_err()); + } + + #[test] + fn test_set_owner() { + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (master, mut slave) = create_slave(&path, slave_be.clone()); + + assert_eq!(slave_be.lock().unwrap().owned, false); + master.set_owner().unwrap(); + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + master.set_owner().unwrap(); + assert!(slave.handle_request().is_err()); + assert_eq!(slave_be.lock().unwrap().owned, true); + } + + #[test] + fn test_set_features() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_master_slave_process() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + // set_own() + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + // get/set_features() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + // get_queue_num() + slave.handle_request().unwrap(); + + // set_mem_table() + slave.handle_request().unwrap(); + + // get/set_config() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // set_slave_request_fd + slave.handle_request().unwrap(); + + // set_vring_enable + slave.handle_request().unwrap(); + + // set_log_base,set_log_fd() + slave.handle_request().unwrap_err(); + slave.handle_request().unwrap_err(); + + // set_vring_xxx + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // get_max_mem_slots() + slave.handle_request().unwrap(); + + // add_mem_region() + slave.handle_request().unwrap(); + + // remove_mem_region() + slave.handle_request().unwrap(); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + let num = master.get_queue_num().unwrap(); + assert_eq!(num, 2); + + let eventfd = sys_util::EventFd::new().unwrap(); + let mem = [VhostUserMemoryRegionInfo { + guest_phys_addr: 0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: eventfd.as_raw_fd(), + }]; + master.set_mem_table(&mem).unwrap(); + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8]) + .unwrap(); + let buf = [0x0u8; 4]; + let (reply_body, reply_payload) = master + .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + let offset = reply_body.offset; + assert_eq!(offset, 0x100); + assert_eq!(reply_payload[0], 0xa5); + + master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_vring_enable(0, true).unwrap(); + + // unimplemented yet + master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap(); + master.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + master.set_vring_num(0, 256).unwrap(); + master.set_vring_base(0, 0).unwrap(); + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(), + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + master.set_vring_addr(0, &config).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + + let max_mem_slots = master.get_max_mem_slots().unwrap(); + assert_eq!(max_mem_slots, 32); + + let region_file = tempfile().unwrap(); + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x10_0000, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: region_file.as_raw_fd(), + }; + master.add_mem_region(®ion).unwrap(); + + master.remove_mem_region(®ion).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_error_display() { + assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters"); + assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation"); + } + + #[test] + fn test_should_reconnect() { + assert_eq!(Error::PartialMessage.should_reconnect(), true); + assert_eq!(Error::SlaveInternalError.should_reconnect(), true); + assert_eq!(Error::MasterInternalError.should_reconnect(), true); + assert_eq!(Error::InvalidParam.should_reconnect(), false); + assert_eq!(Error::InvalidOperation.should_reconnect(), false); + assert_eq!(Error::InvalidMessage.should_reconnect(), false); + assert_eq!(Error::IncorrectFds.should_reconnect(), false); + assert_eq!(Error::OversizedMsg.should_reconnect(), false); + assert_eq!(Error::FeatureMismatch.should_reconnect(), false); + } + + #[test] + fn test_error_from_sys_util_error() { + let e: Error = sys_util::Error::new(libc::EAGAIN).into(); + if let Error::SocketRetry(e1) = e { + assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); + } else { + panic!("invalid error code conversion!"); + } + } +} diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs new file mode 100644 index 0000000..fb65c41 --- /dev/null +++ b/src/vhost_user/slave.rs @@ -0,0 +1,86 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs for vhost-user slave. + +use std::sync::Arc; + +use super::connection::{Endpoint, Listener}; +use super::message::*; +use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; + +/// Vhost-user slave side connection listener. +pub struct SlaveListener<S: VhostUserSlaveReqHandler> { + listener: Listener, + backend: Option<Arc<S>>, +} + +/// Sets up a listener for incoming master connections, and handles construction +/// of a Slave on success. +impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { + /// Create a unix domain socket for incoming master connections. + pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> { + Ok(SlaveListener { + listener, + backend: Some(backend), + }) + } + + /// Accept an incoming connection from the master, returning Some(Slave) on + /// success, or None if the socket is nonblocking and no incoming connection + /// was detected + pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> { + if let Some(fd) = self.listener.accept()? { + return Ok(Some(SlaveReqHandler::new( + Endpoint::<MasterReq>::from_stream(fd), + self.backend.take().unwrap(), + ))); + } + Ok(None) + } + + /// Change blocking status on the listener. + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.listener.set_nonblocking(block) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_listener_set_nonblocking() { + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = + Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap(); + let slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + } + + #[cfg(feature = "vhost-user-master")] + #[test] + fn test_slave_listener_accept() { + use super::super::Master; + + let path = "/tmp/vhost_user_lib_unit_test_slave_accept"; + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = Listener::new(path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + assert!(slave_listener.accept().unwrap().is_none()); + assert!(slave_listener.accept().unwrap().is_none()); + + let _master = Master::connect(path, 1).unwrap(); + let _slave = slave_listener.accept().unwrap().unwrap(); + } +} diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs new file mode 100644 index 0000000..a9c4ed2 --- /dev/null +++ b/src/vhost_user/slave_fs_cache.rs @@ -0,0 +1,226 @@ +// Copyright (C) 2020 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::io; +use std::mem; +use std::os::unix::io::RawFd; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex, MutexGuard}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; + +struct SlaveFsCacheReqInternal { + sock: Endpoint<SlaveReq>, + + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, + + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl SlaveFsCacheReqInternal { + fn check_state(&self) -> Result<u64> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(0), + } + } + + fn send_message( + &mut self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> Result<u64> { + self.check_state()?; + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32); + if self.reply_ack_negotiated { + hdr.set_need_reply(true); + } + self.sock.send_message(&hdr, fs, fds)?; + + self.wait_for_ack(&hdr) + } + + fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> { + self.check_state()?; + if !self.reply_ack_negotiated { + return Ok(0); + } + + let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<SlaveReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + if body.value != 0 { + return Err(Error::MasterInternalError); + } + + Ok(body.value) + } +} + +/// Request proxy to send vhost-user-fs slave requests to the master through the slave +/// communication channel. +/// +/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the +/// master through the vhost-user slave communication channel. The forwarded messages will be +/// handled by the [MasterReqHandler] server. +/// +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +#[derive(Clone)] +pub struct SlaveFsCacheReq { + // underlying Unix domain socket for communication + node: Arc<Mutex<SlaveFsCacheReqInternal>>, +} + +impl SlaveFsCacheReq { + fn new(ep: Endpoint<SlaveReq>) -> Self { + SlaveFsCacheReq { + node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { + sock: ep, + reply_ack_negotiated: false, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> { + self.node.lock().unwrap() + } + + fn send_message( + &self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> io::Result<u64> { + self.node() + .send_message(request, fs, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::<SlaveReq>::from_stream(sock)) + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&self, enable: bool) { + self.node().reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&self, error: i32) { + self.node().error = Some(error); + } +} + +impl VhostUserMasterReqHandler for SlaveFsCacheReq { + /// Forward vhost-user-fs map file requests to the slave. + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) + } + + /// Forward vhost-user-fs unmap file requests to the master. + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_UNMAP, fs, None) + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + + #[test] + fn test_slave_fs_cache_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + assert!(fs_cache.node().error.is_none()); + fs_cache.set_failed(libc::EAGAIN); + assert_eq!(fs_cache.node().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_slave_fs_cache_send_failure() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + fs_cache.set_failed(libc::ECONNRESET); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + fs_cache.node().error = None; + + drop(p2); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } + + #[test] + fn test_slave_fs_cache_recv_negative() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + let mut master = Endpoint::<SlaveReq>::from_stream(p2); + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new( + SlaveReq::FS_MAP, + VhostUserHeaderFlag::REPLY.bits(), + len as u32, + ); + let body = VhostUserU64::new(0); + + master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + hdr.set_code(SlaveReq::FS_UNMAP); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + hdr.set_code(SlaveReq::FS_MAP); + + let body = VhostUserU64::new(1); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + let body = VhostUserU64::new(0); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + } +} diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs new file mode 100644 index 0000000..18459a2 --- /dev/null +++ b/src/vhost_user/slave_req_handler.rs @@ -0,0 +1,828 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::mem; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::slice; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::slave_fs_cache::SlaveFsCacheReq; +use super::{Error, Result}; + +/// Services provided to the master by the slave with interior mutability. +/// +/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. +/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], +/// but without interior mutability. +/// The vhost-user specification defines a master communication channel, by which masters could +/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by +/// slaves, and it's used both on the master side and slave side. +/// +/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy +/// service requests to slaves. +/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler +/// implementing [VhostUserSlaveReqHandler]. +/// +/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandler { + fn set_owner(&self) -> Result<()>; + fn reset_owner(&self) -> Result<()>; + fn get_features(&self) -> Result<u64>; + fn set_features(&self, features: u64) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&self, features: u64) -> Result<()>; + fn get_queue_num(&self) -> Result<u64>; + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} + fn get_max_mem_slots(&self) -> Result<u64>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +/// Services provided to the master by the slave without interior mutability. +/// +/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandlerMut { + fn set_owner(&mut self) -> Result<()>; + fn reset_owner(&mut self) -> Result<()>; + fn get_features(&mut self) -> Result<u64>; + fn set_features(&mut self, features: u64) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &mut self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&mut self, features: u64) -> Result<()>; + fn get_queue_num(&mut self) -> Result<u64>; + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>>; + fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} + fn get_max_mem_slots(&mut self) -> Result<u64>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { + fn set_owner(&self) -> Result<()> { + self.lock().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.lock().unwrap().reset_owner() + } + + fn get_features(&self) -> Result<u64> { + self.lock().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_features(features) + } + + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, fds) + } + + fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { + self.lock().unwrap().set_vring_num(index, num) + } + + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()> { + self.lock() + .unwrap() + .set_vring_addr(index, flags, descriptor, used, available, log) + } + + fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { + self.lock().unwrap().set_vring_base(index, base) + } + + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { + self.lock().unwrap().get_vring_base(index) + } + + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_kick(index, fd) + } + + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_call(index, fd) + } + + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_err(index, fd) + } + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { + self.lock().unwrap().get_protocol_features() + } + + fn set_protocol_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_protocol_features(features) + } + + fn get_queue_num(&self) -> Result<u64> { + self.lock().unwrap().get_queue_num() + } + + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { + self.lock().unwrap().set_vring_enable(index, enable) + } + + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { + self.lock().unwrap().get_config(offset, size, flags) + } + + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { + self.lock().unwrap().set_config(offset, buf, flags) + } + + fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + self.lock().unwrap().set_slave_req_fd(vu_req) + } + + fn get_max_mem_slots(&self) -> Result<u64> { + self.lock().unwrap().get_max_mem_slots() + } + + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> { + self.lock().unwrap().add_mem_region(region, fd) + } + + fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> { + self.lock().unwrap().remove_mem_region(region) + } +} + +/// Server to handle service requests from masters from the master communication channel. +/// +/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from +/// masters on the master communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserSlaveReqHandler] to do the real work. +/// +/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain +/// Socket, so it gets simpler to recover from disconnect. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { + // underlying Unix domain socket for communication + main_sock: Endpoint<MasterReq>, + // the vhost-user backend device object + backend: Arc<S>, + + virtio_features: u64, + acked_virtio_features: u64, + protocol_features: VhostUserProtocolFeatures, + acked_protocol_features: u64, + + // sending ack for messages without payload + reply_ack_enabled: bool, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { + /// Create a vhost-user slave endpoint. + pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self { + SlaveReqHandler { + main_sock, + backend, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: VhostUserProtocolFeatures::empty(), + acked_protocol_features: 0, + reply_ack_enabled: false, + error: None, + } + } + + /// Create a new vhost-user slave endpoint. + /// + /// # Arguments + /// * - `path` - path of Unix domain socket listener to connect to + /// * - `backend` - handler for requests from the master to the slave + pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { + Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&mut self, error: i32) { + self.error = Some(error); + } + + /// Main entrance to server slave request from the slave communication channel. + /// + /// Receive and handle one incoming request message from the master. The caller needs to: + /// - serialize calls to this function + /// - decide what to do when error happens + /// - optional recover from failure + pub fn handle_request(&mut self) -> Result<()> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . validate message body and optional payload + let (hdr, rfds) = self.main_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + let (size2, rbuf) = self.main_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + match hdr.get_code() { + MasterReq::SET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.set_owner()?; + } + MasterReq::RESET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.reset_owner()?; + } + MasterReq::GET_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_features()?; + let msg = VhostUserU64::new(features); + self.send_reply_message(&hdr, &msg)?; + self.virtio_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_FEATURES => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + self.backend.set_features(msg.value)?; + self.acked_virtio_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::SET_MEM_TABLE => { + let res = self.set_mem_table(&hdr, size, &buf, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_NUM => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_num(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ADDR => { + let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; + let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.set_vring_addr( + msg.index, + flags, + msg.descriptor, + msg.used, + msg.available, + msg.log, + ); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_base(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let reply = self.backend.get_vring_base(msg.index)?; + self.send_reply_message(&hdr, &reply)?; + } + MasterReq::SET_VRING_CALL => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_call(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_KICK => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_kick(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ERR => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_err(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_PROTOCOL_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_protocol_features()?; + let msg = VhostUserU64::new(features.bits()); + self.send_reply_message(&hdr, &msg)?; + self.protocol_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_PROTOCOL_FEATURES => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + self.backend.set_protocol_features(msg.value)?; + self.acked_protocol_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::GET_QUEUE_NUM => { + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_queue_num()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + MasterReq::SET_VRING_ENABLE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 + && msg.index > 0 + { + return Err(Error::InvalidOperation); + } + let enable = match msg.num { + 1 => true, + 0 => false, + _ => return Err(Error::InvalidParam), + }; + + let res = self.backend.set_vring_enable(msg.index, enable); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.get_config(&hdr, &buf)?; + } + MasterReq::SET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.set_config(&hdr, size, &buf)?; + } + MasterReq::SET_SLAVE_REQ_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.set_slave_req_fd(&hdr, rfds)?; + } + MasterReq::GET_MAX_MEM_SLOTS => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_max_mem_slots()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + MasterReq::ADD_MEM_REG => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + let fd = if let Some(fds) = &rfds { + if fds.len() != 1 { + return Err(Error::InvalidParam); + } + fds[0] + } else { + return Err(Error::InvalidParam); + }; + + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.add_mem_region(&msg, fd); + self.send_ack_message(&hdr, res)?; + } + MasterReq::REM_MEM_REG => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.remove_mem_region(&msg); + self.send_ack_message(&hdr, res)?; + } + _ => { + return Err(Error::InvalidMessage); + } + } + Ok(()) + } + + fn set_mem_table( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<()> { + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + + // check message size is consistent + let hdrsize = mem::size_of::<VhostUserMemory>(); + if size < hdrsize { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; + if !msg.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + + // validate number of fds matching number of memory regions + let fds = match rfds { + None => return Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != msg.num_regions as usize { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + fds + } + }; + + // Validate memory regions + let regions = unsafe { + slice::from_raw_parts( + buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, + msg.num_regions as usize, + ) + }; + for region in regions.iter() { + if !region.is_valid() { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + + self.backend.set_mem_table(®ions, &fds) + } + + fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { + let payload_offset = mem::size_of::<VhostUserConfig>(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if buf.len() - payload_offset != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags = match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.get_config(msg.offset, msg.size, flags); + + // vhost-user slave's payload size MUST match master's request + // on success, uses zero length of payload to indicate an error + // to vhost-user master. + match res { + Ok(ref buf) if buf.len() == msg.size as usize => { + let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); + self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?; + } + Ok(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + Err(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + } + Ok(()) + } + + fn set_config( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + ) -> Result<()> { + if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags: VhostUserConfigFlags; + match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => flags = val, + None => return Err(Error::InvalidMessage), + } + + let res = self.backend.set_config(msg.offset, buf, flags); + self.send_ack_message(&hdr, res)?; + Ok(()) + } + + fn set_slave_req_fd( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<()> { + if let Some(fds) = rfds { + if fds.len() == 1 { + let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; + let vu_req = SlaveFsCacheReq::from_stream(sock); + self.backend.set_slave_req_fd(vu_req); + self.send_ack_message(&hdr, Ok(())) + } else { + Err(Error::InvalidMessage) + } + } else { + Err(Error::InvalidMessage) + } + } + + fn handle_vring_fd_request( + &mut self, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<(u8, Option<RawFd>)> { + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the + // invalid FD flag. This flag is set when there is no file descriptor + // in the ancillary data. This signals that polling will be used + // instead of waiting for the call. + let nofd = (msg.value & 0x100u64) == 0x100u64; + + let mut rfd = None; + match rfds { + Some(fds) => { + if !nofd && fds.len() == 1 { + rfd = Some(fds[0]); + } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + None => { + if !nofd { + return Err(Error::InvalidMessage); + } + } + } + Ok((msg.value as u8, rfd)) + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_request_size( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + MasterReq::SET_MEM_TABLE => Ok(rfds), + MasterReq::SET_VRING_CALL => Ok(rfds), + MasterReq::SET_VRING_KICK => Ok(rfds), + MasterReq::SET_VRING_ERR => Ok(rfds), + MasterReq::SET_LOG_BASE => Ok(rfds), + MasterReq::SET_LOG_FD => Ok(rfds), + MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), + MasterReq::SET_INFLIGHT_FD => Ok(rfds), + MasterReq::ADD_MEM_REG => Ok(rfds), + _ => { + if rfds.is_some() { + Endpoint::<MasterReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + fn extract_request_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_request_size(hdr, size, mem::size_of::<T>())?; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn update_reply_ack_flag(&mut self) { + let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let pflag = VhostUserProtocolFeatures::REPLY_ACK; + if (self.virtio_features & vflag) != 0 + && (self.acked_virtio_features & vflag) != 0 + && self.protocol_features.contains(pflag) + && (self.acked_protocol_features & pflag.bits()) != 0 + { + self.reply_ack_enabled = true; + } else { + self.reply_ack_enabled = false; + } + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<MasterReq>, + payload_size: usize, + ) -> Result<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE + || payload_size > MAX_MSG_SIZE + || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE + { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + (mem::size_of::<T>() + payload_size) as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + res: Result<()>, + ) -> Result<()> { + if self.reply_ack_enabled && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; + let val = match res { + Ok(_) => 0, + Err(_) => 1, + }; + let msg = VhostUserU64::new(val); + self.main_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } + + fn send_reply_message<T>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, 0)?; + self.main_sock.send_message(&hdr, msg, None)?; + Ok(()) + } + + fn send_reply_with_payload<T: Sized>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + payload: &[u8], + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, payload.len())?; + self.main_sock + .send_message_with_payload(&hdr, msg, payload, None)?; + Ok(()) + } +} + +impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.main_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_req_handler_new() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let endpoint = Endpoint::<MasterReq>::from_stream(p1); + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let mut handler = SlaveReqHandler::new(endpoint, backend); + + handler.check_state().unwrap(); + handler.set_failed(libc::EAGAIN); + handler.check_state().unwrap_err(); + assert!(handler.as_raw_fd() >= 0); + } +} |