diff options
Diffstat (limited to 'src/vhost_user')
-rw-r--r-- | src/vhost_user/connection.rs | 56 | ||||
-rw-r--r-- | src/vhost_user/dummy_slave.rs | 58 | ||||
-rw-r--r-- | src/vhost_user/master.rs | 352 | ||||
-rw-r--r-- | src/vhost_user/master_req_handler.rs | 285 | ||||
-rw-r--r-- | src/vhost_user/message.rs | 162 | ||||
-rw-r--r-- | src/vhost_user/mod.rs | 200 | ||||
-rw-r--r-- | src/vhost_user/slave.rs | 46 | ||||
-rw-r--r-- | src/vhost_user/slave_fs_cache.rs | 210 | ||||
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 303 | ||||
-rw-r--r-- | src/vhost_user/sock_ctrl_msg.rs | 499 |
10 files changed, 1363 insertions, 808 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index deafdeb..01bf124 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,15 +5,16 @@ #![allow(dead_code)] -use libc::{c_void, iovec}; use std::io::ErrorKind; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::{mem, slice}; +use libc::{c_void, iovec}; +use vmm_sys_util::sock_ctrl_msg::ScmSocket; + use super::message::*; -use super::sock_ctrl_msg::ScmSocket; use super::{Error, Result}; /// Unix domain socket listener for accepting incoming connections. @@ -215,6 +216,9 @@ impl<R: Req> Endpoint<R> { body: &T, fds: Option<&[RawFd]>, ) -> Result<()> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } // Safe because there can't be other mutable referance to hdr and body. let iovs = unsafe { [ @@ -243,14 +247,17 @@ impl<R: Req> Endpoint<R> { /// * - OversizedMsg: message size is too big. /// * - PartialMessage: received a partial message. /// * - IncorrectFds: wrong number of attached fds. - pub fn send_message_with_payload<T: Sized, P: Sized>( + pub fn send_message_with_payload<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, - payload: &[P], + payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()> { - let len = payload.len() * mem::size_of::<P>(); + let len = payload.len(); + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } if len > MAX_MSG_SIZE - mem::size_of::<T>() { return Err(Error::OversizedMsg); } @@ -599,27 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { - use super::*; use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::io::FromRawFd; + use vmm_sys_util::rand::rand_alphanumerics; use vmm_sys_util::tempfile::TempFile; - const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener"; - const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection"; - const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data"; - const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd"; - const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send"; + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } #[test] fn create_listener() { - let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); } #[test] fn accept_connection() { - let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); // accept on a fd without incoming connection @@ -628,11 +640,11 @@ mod tests { } #[test] - #[ignore] fn send_data() { - let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); @@ -654,11 +666,11 @@ mod tests { } #[test] - #[ignore] fn send_fd() { - let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); @@ -808,11 +820,11 @@ mod tests { } #[test] - #[ignore] fn send_recv() { - let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index 53887e2..9eedcbb 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -1,9 +1,10 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::os::unix::io::RawFd; + use super::message::*; use super::*; -use std::os::unix::io::RawFd; pub const MAX_QUEUE_NUM: usize = 2; pub const MAX_VRING_NUM: usize = 256; @@ -34,7 +35,7 @@ impl DummySlaveReqHandler { } } -impl VhostUserSlaveReqHandler for DummySlaveReqHandler { +impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { fn set_owner(&mut self) -> Result<()> { if self.owned { return Err(Error::InvalidOperation); @@ -56,9 +57,7 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { } fn set_features(&mut self, features: u64) -> Result<()> { - if !self.owned { - return Err(Error::InvalidOperation); - } else if self.features_acked { + if !self.owned || self.features_acked { return Err(Error::InvalidOperation); } else if (features & !VIRTIO_FEATURES) != 0 { return Err(Error::InvalidParam); @@ -83,30 +82,10 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { Ok(()) } - fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { - Ok(VhostUserProtocolFeatures::all()) - } - - fn set_protocol_features(&mut self, features: u64) -> Result<()> { - // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must - // support this message even before VHOST_USER_SET_FEATURES was - // called. - // What happens if the master calls set_features() with - // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this - // interface? - self.acked_protocol_features = features; - Ok(()) - } - fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { - // TODO Ok(()) } - fn get_queue_num(&mut self) -> Result<u64> { - Ok(MAX_QUEUE_NUM as u64) - } - fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { return Err(Error::InvalidParam); @@ -199,6 +178,25 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { Ok(()) } + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? + self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + Ok(MAX_QUEUE_NUM as u64) + } + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES // has been negotiated. @@ -222,10 +220,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { size: u32, _flags: VhostUserConfigFlags, ) -> Result<Vec<u8>> { - if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); - } else if offset < VHOST_USER_CONFIG_OFFSET - || offset >= VHOST_USER_CONFIG_SIZE + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { @@ -236,10 +233,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { let size = buf.len() as u32; - if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); - } else if offset < VHOST_USER_CONFIG_OFFSET - || offset >= VHOST_USER_CONFIG_SIZE + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index ffed909..35ca471 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -6,7 +6,7 @@ use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use vmm_sys_util::eventfd::EventFd; @@ -78,6 +78,10 @@ impl Master { } } + fn node(&self) -> MutexGuard<MasterInternal> { + self.node.lock().unwrap() + } + /// Create a new instance from a Unix stream socket. pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num) @@ -115,8 +119,8 @@ impl Master { impl VhostBackend for Master { /// Get from the underlying vhost implementation the feature bitmask. - fn get_features(&mut self) -> Result<u64> { - let mut node = self.node.lock().unwrap(); + fn get_features(&self) -> Result<u64> { + let mut node = self.node(); let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; let val = node.recv_reply::<VhostUserU64>(&hdr)?; node.virtio_features = val.value; @@ -124,8 +128,8 @@ impl VhostBackend for Master { } /// Enable features in the underlying vhost implementation using a bitmask. - fn set_features(&mut self, features: u64) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_features(&self, features: u64) -> Result<()> { + let mut node = self.node(); let val = VhostUserU64::new(features); let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been @@ -135,18 +139,18 @@ impl VhostBackend for Master { } /// Set the current Master as an owner of the session. - fn set_owner(&mut self) -> Result<()> { + fn set_owner(&self) -> Result<()> { // We unwrap() the return value to assert that we are not expecting threads to ever fail // while holding the lock. - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been // completed yet. Ok(()) } - fn reset_owner(&mut self) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn reset_owner(&self) -> Result<()> { + let mut node = self.node(); let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been // completed yet. @@ -155,7 +159,7 @@ impl VhostBackend for Master { /// Set the memory map regions on the slave so it can translate the vring /// addresses. In the ancillary data there is an array of file descriptors - fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { return error_code(VhostUserError::InvalidParam); } @@ -174,12 +178,13 @@ impl VhostBackend for Master { ctx.append(®, region.mmap_handle); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let body = VhostUserMemory::new(ctx.regions.len() as u32); + let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() }; let hdr = node.send_request_with_payload( MasterReq::SET_MEM_TABLE, &body, - ctx.regions.as_slice(), + payload, Some(ctx.fds.as_slice()), )?; node.wait_for_ack(&hdr).map_err(|e| e.into()) @@ -187,8 +192,8 @@ impl VhostBackend for Master { // Clippy doesn't seem to know that if let with && is still experimental #[allow(clippy::unnecessary_unwrap)] - fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + let mut node = self.node(); let val = VhostUserU64::new(base); if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 @@ -202,16 +207,16 @@ impl VhostBackend for Master { Ok(()) } - fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + let mut node = self.node(); let fds = [fd]; node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; Ok(()) } /// Set the size of the queue. - fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -222,8 +227,8 @@ impl VhostBackend for Master { } /// Sets the addresses of the different aspects of the vring. - fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 { @@ -236,8 +241,8 @@ impl VhostBackend for Master { } /// Sets the base offset in the available vring. - fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -247,8 +252,8 @@ impl VhostBackend for Master { node.wait_for_ack(&hdr).map_err(|e| e.into()) } - fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> { - let mut node = self.node.lock().unwrap(); + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -263,8 +268,8 @@ impl VhostBackend for Master { /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. This signals that polling /// will be used instead of waiting for the call. - fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -276,8 +281,8 @@ impl VhostBackend for Master { /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. This signals that polling /// should be used instead of waiting for a kick. - fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -288,8 +293,8 @@ impl VhostBackend for Master { /// Set the event file descriptor to signal when error occurs. /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. - fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -300,7 +305,7 @@ impl VhostBackend for Master { impl VhostUserMaster for Master { fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); @@ -317,7 +322,7 @@ impl VhostUserMaster for Master { } fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); @@ -332,7 +337,7 @@ impl VhostUserMaster for Master { } fn get_queue_num(&mut self) -> Result<u64> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if !node.is_feature_mq_available() { return error_code(VhostUserError::InvalidOperation); } @@ -347,7 +352,7 @@ impl VhostUserMaster for Master { } fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled. if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -373,7 +378,7 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidParam); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -390,9 +395,13 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidMessage); } else if body_reply.size == 0 { return error_code(VhostUserError::SlaveInternalError); - } else if body_reply.size != body.size || body_reply.size as usize != buf.len() { + } else if body_reply.size != body.size + || body_reply.size as usize != buf.len() + || body_reply.offset != body.offset + { return error_code(VhostUserError::InvalidMessage); } + Ok((body_reply, buf_reply)) } @@ -405,7 +414,7 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidParam); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -416,7 +425,7 @@ impl VhostUserMaster for Master { } fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return error_code(VhostUserError::InvalidOperation); } @@ -429,7 +438,7 @@ impl VhostUserMaster for Master { impl AsRawFd for Master { fn as_raw_fd(&self) -> RawFd { - let node = self.node.lock().unwrap(); + let node = self.node(); node.main_sock.as_raw_fd() } } @@ -503,14 +512,14 @@ impl MasterInternal { Ok(hdr) } - fn send_request_with_payload<T: Sized, P: Sized>( + fn send_request_with_payload<T: Sized>( &mut self, code: MasterReq, msg: &T, - payload: &[P], + payload: &[u8], fds: Option<&[RawFd]>, ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> { - let len = mem::size_of::<T>() + payload.len() * mem::size_of::<P>(); + let len = mem::size_of::<T>() + payload.len(); if len > MAX_MSG_SIZE { return Err(VhostUserError::InvalidParam); } @@ -568,7 +577,11 @@ impl MasterInternal { &mut self, hdr: &VhostUserMsgHeader<MasterReq>, ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> { - if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() { + if mem::size_of::<T>() > MAX_MSG_SIZE + || hdr.get_size() as usize <= mem::size_of::<T>() + || hdr.get_size() as usize > MAX_MSG_SIZE + || hdr.is_reply() + { return Err(VhostUserError::InvalidParam); } self.check_state()?; @@ -582,11 +595,8 @@ impl MasterInternal { { Endpoint::<MasterReq>::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); - } else if bytes > MAX_MSG_SIZE - mem::size_of::<T>() { + } else if bytes != buf.len() { return Err(VhostUserError::InvalidMessage); - } else if bytes < buf.len() { - // It's safe because we have checked the buffer size - unsafe { buf.set_len(bytes) }; } Ok((body, buf, rfds)) } @@ -634,11 +644,14 @@ impl MasterInternal { mod tests { use super::super::connection::Listener; use super::*; + use vmm_sys_util::rand::rand_alphanumerics; - const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master"; - const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2"; - const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3"; - const UNIX_SOCKET_MASTER4: &'static str = "/tmp/vhost_user_test_rust_master4"; + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } fn create_pair(path: &str) -> (Master, Endpoint<MasterReq>) { let listener = Listener::new(path, true).unwrap(); @@ -649,14 +662,15 @@ mod tests { } #[test] - #[ignore] fn create_master() { - let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap(); + let master = Master::connect(&path, 1).unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap()); + assert!(master.as_raw_fd() > 0); // Send two messages continuously master.set_owner().unwrap(); master.reset_owner().unwrap(); @@ -675,24 +689,24 @@ mod tests { } #[test] - #[ignore] fn test_create_failure() { - let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); - let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err(); - assert!(Master::connect(UNIX_SOCKET_MASTER2, 1).is_err()); + let path = temp_path(); + let _ = Listener::new(&path, true).unwrap(); + let _ = Listener::new(&path, false).is_err(); + assert!(Master::connect(&path, 1).is_err()); - let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); - assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err()); + let listener = Listener::new(&path, true).unwrap(); + assert!(Listener::new(&path, false).is_err()); listener.set_nonblocking(true).unwrap(); - let _master = Master::connect(UNIX_SOCKET_MASTER2, 1).unwrap(); + let _master = Master::connect(&path, 1).unwrap(); let _slave = listener.accept().unwrap().unwrap(); } #[test] - #[ignore] fn test_features() { - let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3); + let path = temp_path(); + let (master, mut peer) = create_pair(&path); master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); @@ -709,6 +723,9 @@ mod tests { let (_hdr, rfds) = peer.recv_header().unwrap(); assert!(rfds.is_none()); + let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); master.set_features(0x15).unwrap(); let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap(); assert!(rfds.is_none()); @@ -722,9 +739,9 @@ mod tests { } #[test] - #[ignore] fn test_protocol_features() { - let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4); + let path = temp_path(); + let (mut master, mut peer) = create_pair(&path); master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); @@ -773,12 +790,209 @@ mod tests { } #[test] - fn test_set_mem_table() { - // TODO + fn test_master_set_config_negative() { + let path = temp_path(); + let (mut master, _peer) = create_pair(&path); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap(); + master + .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config( + 0x100, + unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) }, + &buf[0..4], + ) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[]) + .unwrap_err(); + } + + fn create_pair2() -> (Master, Endpoint<MasterReq>) { + let path = temp_path(); + let (master, peer) = create_pair(&path); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + (master, peer) + } + + #[test] + fn test_master_get_config_negative0() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_code(MasterReq::GET_FEATURES); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + hdr.set_code(MasterReq::GET_CONFIG); + } + + #[test] + fn test_master_get_config_negative1() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_reply(false); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); } #[test] - fn test_get_ring_num() { - // TODO + fn test_master_get_config_negative2() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + } + + #[test] + fn test_master_get_config_negative3() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative4() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0x101; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative5() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = (MAX_MSG_SIZE + 1) as u32; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative6() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.size = 6; + peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_maset_set_mem_table_failure() { + let (master, _peer) = create_pair2(); + + master.set_mem_table(&[]).unwrap_err(); + let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; + master.set_mem_table(&tables).unwrap_err(); } } diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index aadfeee..8cba188 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -1,9 +1,6 @@ -// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Traits and Structs to handle vhost-user requests from the slave to the master. - -use libc; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -13,83 +10,189 @@ use super::connection::Endpoint; use super::message::*; use super::{Error, HandlerResult, Result}; -/// Trait to handle vhost-user requests from the slave to the master. +/// Define services provided by masters for the slave communication channel. +/// +/// The vhost-user specification defines a slave communication channel, by which slaves could +/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided +/// by masters, and it's used both on the master side and slave side. +/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy +/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. +/// - on the master side, the [MasterReqHandler] will forward service requests to a handler +/// implementing [VhostUserMasterReqHandler]. +/// +/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html pub trait VhostUserMasterReqHandler { + /// Handle device configuration change notifications. + fn handle_config_change(&self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} - /// Handle device configuration change notifications from the slave. +/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub trait VhostUserMasterReqHandlerMut { + /// Handle device configuration change notifications. fn handle_config_change(&mut self) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs map file requests from the slave. + /// Handle virtio-fs map file requests. fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { // Safe because we have just received the rawfd from kernel. unsafe { libc::close(fd) }; Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs unmap file requests from the slave. + /// Handle virtio-fs unmap file requests. fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs sync file requests from the slave. + /// Handle virtio-fs sync file requests. fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs file IO requests from the slave. + /// Handle virtio-fs file IO requests. fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { // Safe because we have just received the rawfd from kernel. unsafe { libc::close(fd) }; Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { + fn handle_config_change(&self) -> HandlerResult<u64> { + self.lock().unwrap().handle_config_change() + } + + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_map(fs, fd) + } + + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_unmap(fs) + } + + fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_sync(fs) + } + + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_io(fs, fd) + } } -/// A vhost-user master request endpoint which relays all received requests from the slave to the -/// provided request handler. +/// Server to handle service requests from slaves from the slave communication channel. +/// +/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from +/// slaves on the slave communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserMasterReqHandler] to do the real work. +/// +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { // underlying Unix domain socket for communication sub_sock: Endpoint<SlaveReq>, tx_sock: UnixStream, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, // the VirtIO backend device object - backend: Arc<Mutex<S>>, + backend: Arc<S>, // whether the endpoint has encountered any failure error: Option<i32>, } impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { - /// Create a vhost-user slave request handler. - /// This opens a pair of connected anonymous sockets. - /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD. - pub fn new(backend: Arc<Mutex<S>>) -> Result<Self> { + /// Create a server to handle service requests from slaves on the slave communication channel. + /// + /// This opens a pair of connected anonymous sockets to form the slave communication channel. + /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. + /// + /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn new(backend: Arc<S>) -> Result<Self> { let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; Ok(MasterReqHandler { sub_sock: Endpoint::<SlaveReq>::from_stream(rx), tx_sock: tx, + reply_ack_negotiated: false, backend, error: None, }) } - /// Get the raw fd to send to the slave as slave communication channel. + /// Get the socket fd for the slave to communication with the master. + /// + /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd pub fn get_tx_raw_fd(&self) -> RawFd { self.tx_sock.as_raw_fd() } - /// Mark endpoint as failed or normal state. + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&mut self, enable: bool) { + self.reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed or in normal state. pub fn set_failed(&mut self, error: i32) { - self.error = Some(error); + if error == 0 { + self.error = None; + } else { + self.error = Some(error); + } } - /// Receive and handle one incoming request message from the slave. + /// Main entrance to server slave request from the slave communication channel. + /// /// The caller needs to: - /// . serialize calls to this function - /// . decide what to do when errer happens - /// . optional recover from failure + /// - serialize calls to this function + /// - decide what to do when errer happens + /// - optional recover from failure pub fn handle_request(&mut self) -> Result<u64> { // Return error if the endpoint is already in failed state. self.check_state()?; @@ -108,6 +211,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { + if len as usize > MAX_MSG_SIZE { + return Err(Error::InvalidMessage); + } let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; if size2 != len as usize { return Err(Error::InvalidMessage); @@ -120,41 +226,33 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { SlaveReq::CONFIG_CHANGE_MSG => { self.check_msg_size(&hdr, size, 0)?; self.backend - .lock() - .unwrap() .handle_config_change() .map_err(Error::ReqHandlerError) } SlaveReq::FS_MAP => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds self.backend - .lock() - .unwrap() - .fs_slave_map(msg, rfds.unwrap()[0]) + .fs_slave_map(&msg, rfds.unwrap()[0]) .map_err(Error::ReqHandlerError) } SlaveReq::FS_UNMAP => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; self.backend - .lock() - .unwrap() - .fs_slave_unmap(msg) + .fs_slave_unmap(&msg) .map_err(Error::ReqHandlerError) } SlaveReq::FS_SYNC => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; self.backend - .lock() - .unwrap() - .fs_slave_sync(msg) + .fs_slave_sync(&msg) .map_err(Error::ReqHandlerError) } SlaveReq::FS_IO => { let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds self.backend - .lock() - .unwrap() - .fs_slave_io(msg, rfds.unwrap()[0]) + .fs_slave_io(&msg, rfds.unwrap()[0]) .map_err(Error::ReqHandlerError) } _ => Err(Error::InvalidMessage), @@ -211,7 +309,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { _ => { if rfds.is_some() { Endpoint::<SlaveReq>::close_rfds(rfds); - return Err(Error::InvalidMessage); + Err(Error::InvalidMessage) } else { Ok(rfds) } @@ -219,14 +317,14 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { } } - fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>( + fn extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, - buf: &'a [u8], - ) -> Result<&'a T> { + buf: &[u8], + ) -> Result<T> { self.check_msg_size(hdr, size, mem::size_of::<T>())?; - let msg = unsafe { &*(buf.as_ptr() as *const T) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -253,7 +351,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()> { - if req.is_need_reply() { + if self.reply_ack_negotiated && req.is_need_reply() { let hdr = self.new_reply_header::<VhostUserU64>(req)?; let def_err = libc::EINVAL; let val = match res { @@ -278,3 +376,102 @@ impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { self.sub_sock.as_raw_fd() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "vhost-user-slave")] + use crate::vhost_user::SlaveFsCacheReq; + #[cfg(feature = "vhost-user-slave")] + use std::os::unix::io::FromRawFd; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + assert!(handler.get_tx_raw_fd() >= 0); + assert!(handler.as_raw_fd() >= 0); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + handler.set_reply_ack_flag(true); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs index 4109b61..8600410 100644 --- a/src/vhost_user/message.rs +++ b/src/vhost_user/message.rs @@ -562,9 +562,9 @@ bitflags! { /// Flags for the device configuration message. pub struct VhostUserConfigFlags: u32 { /// Vhost master messages used for writeable fields. - const WRITABLE = 0x0; + const WRITABLE = 0x1; /// Vhost master messages used for live migration. - const LIVE_MIGRATION = 0x1; + const LIVE_MIGRATION = 0x2; } } @@ -596,9 +596,11 @@ impl VhostUserMsgValidator for VhostUserConfig { fn is_valid(&self) -> bool { if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { return false; + } else if self.offset < 0x100 { + return false; } else if self.size == 0 || self.size > VHOST_USER_CONFIG_SIZE - || self.size + self.offset >= VHOST_USER_CONFIG_SIZE + || self.size + self.offset > VHOST_USER_CONFIG_SIZE { return false; } @@ -656,9 +658,9 @@ pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; #[repr(packed)] #[derive(Default)] pub struct VhostUserFSSlaveMsg { - /// TODO: + /// File offset. pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], - /// TODO: + /// Offset into the DAX window. pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], /// Size of region to map. pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], @@ -686,13 +688,31 @@ mod tests { use std::mem; #[test] - fn check_request_code() { + fn check_master_request_code() { let code = MasterReq::NOOP; assert!(!code.is_valid()); let code = MasterReq::MAX_CMD; assert!(!code.is_valid()); + assert!(code > MasterReq::NOOP); let code = MasterReq::GET_FEATURES; assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn check_slave_request_code() { + let code = SlaveReq::NOOP; + assert!(!code.is_valid()); + let code = SlaveReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > SlaveReq::NOOP); + let code = SlaveReq::CONFIG_CHANGE_MSG; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) }; + assert!(!code.is_valid()); } #[test] @@ -741,6 +761,20 @@ mod tests { assert!(!hdr.is_valid()); hdr.set_version(0x1); assert!(hdr.is_valid()); + + assert_eq!(hdr, hdr.clone()); + } + + #[test] + fn test_vhost_user_message_u64() { + let val = VhostUserU64::default(); + let val1 = VhostUserU64::new(0); + + let a = val.value; + let b = val1.value; + assert_eq!(a, b); + let a = VhostUserU64::new(1).value; + assert_eq!(a, 1); } #[test] @@ -775,6 +809,104 @@ mod tests { msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; msg.memory_size = 0; assert!(!msg.is_valid()); + let a = msg.guest_phys_addr; + let b = msg.guest_phys_addr; + assert_eq!(a, b); + + let msg = VhostUserMemoryRegion::default(); + let a = msg.guest_phys_addr; + assert_eq!(a, 0); + let a = msg.memory_size; + assert_eq!(a, 0); + let a = msg.user_addr; + assert_eq!(a, 0); + let a = msg.mmap_offset; + assert_eq!(a, 0); + } + + #[test] + fn test_vhost_user_state() { + let state = VhostUserVringState::new(5, 8); + + let a = state.index; + assert_eq!(a, 5); + let a = state.num; + assert_eq!(a, 8); + assert_eq!(state.is_valid(), true); + + let state = VhostUserVringState::default(); + let a = state.index; + assert_eq!(a, 0); + let a = state.num; + assert_eq!(a, 0); + assert_eq!(state.is_valid(), true); + } + + #[test] + fn test_vhost_user_addr() { + let mut addr = VhostUserVringAddr::new( + 2, + VhostUserVringAddrFlags::VHOST_VRING_F_LOG, + 0x1000, + 0x2000, + 0x3000, + 0x4000, + ); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + + addr.descriptor = 0x1001; + assert_eq!(addr.is_valid(), false); + addr.descriptor = 0x1000; + + addr.available = 0x3001; + assert_eq!(addr.is_valid(), false); + addr.available = 0x3000; + + addr.used = 0x2001; + assert_eq!(addr.is_valid(), false); + addr.used = 0x2000; + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn test_vhost_user_state_from_config() { + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + let addr = VhostUserVringAddr::from_config_data(2, &config); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); } #[test] @@ -801,7 +933,6 @@ mod tests { } #[test] - #[ignore] fn check_user_config_msg() { let mut msg = VhostUserConfig::new( VHOST_USER_CONFIG_OFFSET, @@ -828,4 +959,21 @@ mod tests { msg.flags |= 0x4; assert!(!msg.is_valid()); } + + #[test] + fn test_vhost_user_fs_slave() { + let mut fs_slave = VhostUserFSSlaveMsg::default(); + + assert_eq!(fs_slave.is_valid(), true); + + fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff; + fs_slave.len[0] = 0x1; + assert_eq!(fs_slave.is_valid(), false); + + assert_ne!( + VhostUserFSSlaveMsgFlags::MAP_R, + VhostUserFSSlaveMsgFlags::MAP_W + ); + assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0); + } } diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 48a93ff..6a5b6a1 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -18,20 +18,23 @@ //! Most messages that can be sent via the Unix domain socket implementing vhost-user have an //! equivalent ioctl to the kernel implementation. -use libc; use std::io::Error as IOError; -mod connection; pub mod message; + +mod connection; pub use self::connection::Listener; + #[cfg(feature = "vhost-user-master")] mod master; #[cfg(feature = "vhost-user-master")] pub use self::master::{Master, VhostUserMaster}; -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] +#[cfg(feature = "vhost-user")] mod master_req_handler; -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] -pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler}; +#[cfg(feature = "vhost-user")] +pub use self::master_req_handler::{ + MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut, +}; #[cfg(feature = "vhost-user-slave")] mod slave; @@ -40,14 +43,14 @@ pub use self::slave::SlaveListener; #[cfg(feature = "vhost-user-slave")] mod slave_req_handler; #[cfg(feature = "vhost-user-slave")] -pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler}; +pub use self::slave_req_handler::{ + SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut, +}; #[cfg(feature = "vhost-user-slave")] mod slave_fs_cache; #[cfg(feature = "vhost-user-slave")] pub use self::slave_fs_cache::SlaveFsCacheReq; -pub mod sock_ctrl_msg; - /// Errors for vhost-user operations #[derive(Debug)] pub enum Error { @@ -102,6 +105,8 @@ impl std::fmt::Display for Error { } } +impl std::error::Error for Error {} + impl Error { /// Determine whether to rebuild the underline communication channel. pub fn should_reconnect(&self) -> bool { @@ -170,21 +175,32 @@ pub type Result<T> = std::result::Result<T, Error>; /// Result of request handler. pub type HandlerResult<T> = std::result::Result<T, IOError>; -#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +#[cfg(all(test, feature = "vhost-user-slave"))] mod dummy_slave; #[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] mod tests { + use std::os::unix::io::AsRawFd; + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + use vmm_sys_util::rand::rand_alphanumerics; + use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES}; use super::message::*; use super::*; use crate::backend::VhostBackend; - use std::sync::{Arc, Barrier, Mutex}; - use std::thread; + use crate::{VhostUserMemoryRegionInfo, VringConfigData}; + + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } fn create_slave<S: VhostUserSlaveReqHandler>( path: &str, - backend: Arc<Mutex<S>>, + backend: Arc<S>, ) -> (Master, SlaveReqHandler<S>) { let listener = Listener::new(path, true).unwrap(); let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); @@ -194,7 +210,7 @@ mod tests { #[test] fn create_dummy_slave() { - let mut slave = DummySlaveReqHandler::new(); + let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new())); slave.set_owner().unwrap(); assert!(slave.set_owner().is_err()); @@ -203,8 +219,8 @@ mod tests { #[test] fn test_set_owner() { let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); - let (mut master, mut slave) = - create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone()); + let path = temp_path(); + let (master, mut slave) = create_slave(&path, slave_be.clone()); assert_eq!(slave_be.lock().unwrap().owned, false); master.set_owner().unwrap(); @@ -219,14 +235,60 @@ mod tests { fn test_set_features() { let mbar = Arc::new(Barrier::new(2)); let sbar = mbar.clone(); + let path = temp_path(); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_master_slave_process() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let path = temp_path(); let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); - let (mut master, mut slave) = - create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone()); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); thread::spawn(move || { + // set_own() slave.handle_request().unwrap(); assert_eq!(slave_be.lock().unwrap().owned, true); + // get/set_features() slave.handle_request().unwrap(); slave.handle_request().unwrap(); assert_eq!( @@ -241,6 +303,34 @@ mod tests { VhostUserProtocolFeatures::all().bits() ); + // get_queue_num() + slave.handle_request().unwrap(); + + // set_mem_table() + slave.handle_request().unwrap(); + + // get/set_config() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // set_slave_request_fd + slave.handle_request().unwrap(); + + // set_vring_enable + slave.handle_request().unwrap(); + + // set_log_base,set_log_fd() + slave.handle_request().unwrap_err(); + slave.handle_request().unwrap_err(); + + // set_vring_xxx + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + sbar.wait(); }); @@ -256,6 +346,82 @@ mod tests { assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); master.set_protocol_features(features).unwrap(); + let num = master.get_queue_num().unwrap(); + assert_eq!(num, 2); + + let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap(); + let mem = [VhostUserMemoryRegionInfo { + guest_phys_addr: 0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: eventfd.as_raw_fd(), + }]; + master.set_mem_table(&mem).unwrap(); + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8]) + .unwrap(); + let buf = [0x0u8; 4]; + let (reply_body, reply_payload) = master + .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + let offset = reply_body.offset; + assert_eq!(offset, 0x100); + assert_eq!(reply_payload[0], 0xa5); + + master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_vring_enable(0, true).unwrap(); + + // unimplemented yet + master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap(); + master.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + master.set_vring_num(0, 256).unwrap(); + master.set_vring_base(0, 0).unwrap(); + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(), + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + master.set_vring_addr(0, &config).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + mbar.wait(); } + + #[test] + fn test_error_display() { + assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters"); + assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation"); + } + + #[test] + fn test_should_reconnect() { + assert_eq!(Error::PartialMessage.should_reconnect(), true); + assert_eq!(Error::SlaveInternalError.should_reconnect(), true); + assert_eq!(Error::MasterInternalError.should_reconnect(), true); + assert_eq!(Error::InvalidParam.should_reconnect(), false); + assert_eq!(Error::InvalidOperation.should_reconnect(), false); + assert_eq!(Error::InvalidMessage.should_reconnect(), false); + assert_eq!(Error::IncorrectFds.should_reconnect(), false); + assert_eq!(Error::OversizedMsg.should_reconnect(), false); + assert_eq!(Error::FeatureMismatch.should_reconnect(), false); + } + + #[test] + fn test_error_from_sys_util_error() { + let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into(); + if let Error::SocketRetry(e1) = e { + assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); + } else { + panic!("invalid error code conversion!"); + } + } } diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs index 5ac99af..fb65c41 100644 --- a/src/vhost_user/slave.rs +++ b/src/vhost_user/slave.rs @@ -3,7 +3,7 @@ //! Traits and Structs for vhost-user slave. -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use super::connection::{Endpoint, Listener}; use super::message::*; @@ -12,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; /// Vhost-user slave side connection listener. pub struct SlaveListener<S: VhostUserSlaveReqHandler> { listener: Listener, - backend: Option<Arc<Mutex<S>>>, + backend: Option<Arc<S>>, } /// Sets up a listener for incoming master connections, and handles construction /// of a Slave on success. impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { /// Create a unix domain socket for incoming master connections. - pub fn new(listener: Listener, backend: Arc<Mutex<S>>) -> Result<Self> { + pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> { Ok(SlaveListener { listener, backend: Some(backend), @@ -44,3 +44,43 @@ impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { self.listener.set_nonblocking(block) } } + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_listener_set_nonblocking() { + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = + Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap(); + let slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + } + + #[cfg(feature = "vhost-user-master")] + #[test] + fn test_slave_listener_accept() { + use super::super::Master; + + let path = "/tmp/vhost_user_lib_unit_test_slave_accept"; + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = Listener::new(path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + assert!(slave_listener.accept().unwrap().is_none()); + assert!(slave_listener.accept().unwrap().is_none()); + + let _master = Master::connect(path, 1).unwrap(); + let _slave = slave_listener.accept().unwrap().unwrap(); + } +} diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index 1804c7a..a9c4ed2 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -1,61 +1,59 @@ -// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2020 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::connection::Endpoint; -use super::message::*; -use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; use std::io; use std::mem; use std::os::unix::io::RawFd; use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; struct SlaveFsCacheReqInternal { sock: Endpoint<SlaveReq>, -} -/// A vhost-user slave endpoint which sends fs cache requests to the master -#[derive(Clone)] -pub struct SlaveFsCacheReq { - // underlying Unix domain socket for communication - node: Arc<Mutex<SlaveFsCacheReqInternal>>, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, // whether the endpoint has encountered any failure error: Option<i32>, } -impl SlaveFsCacheReq { - fn new(ep: Endpoint<SlaveReq>) -> Self { - SlaveFsCacheReq { - node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })), - error: None, +impl SlaveFsCacheReqInternal { + fn check_state(&self) -> Result<u64> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(0), } } - /// Create a new instance. - pub fn from_stream(sock: UnixStream) -> Self { - Self::new(Endpoint::<SlaveReq>::from_stream(sock)) - } - fn send_message( &mut self, - flags: SlaveReq, + request: SlaveReq, fs: &VhostUserFSSlaveMsg, fds: Option<&[RawFd]>, ) -> Result<u64> { self.check_state()?; let len = mem::size_of::<VhostUserFSSlaveMsg>(); - let mut hdr = VhostUserMsgHeader::new(flags, 0, len as u32); - hdr.set_need_reply(true); - self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?; + let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32); + if self.reply_ack_negotiated { + hdr.set_need_reply(true); + } + self.sock.send_message(&hdr, fs, fds)?; self.wait_for_ack(&hdr) } fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> { self.check_state()?; - let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::<VhostUserU64>()?; + if !self.reply_ack_negotiated { + return Ok(0); + } + + let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { Endpoint::<SlaveReq>::close_rfds(rfds); return Err(Error::InvalidMessage); @@ -63,32 +61,166 @@ impl SlaveFsCacheReq { if body.value != 0 { return Err(Error::MasterInternalError); } - Ok(0) + + Ok(body.value) } +} - fn check_state(&self) -> Result<u64> { - match self.error { - Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), - None => Ok(0), +/// Request proxy to send vhost-user-fs slave requests to the master through the slave +/// communication channel. +/// +/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the +/// master through the vhost-user slave communication channel. The forwarded messages will be +/// handled by the [MasterReqHandler] server. +/// +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +#[derive(Clone)] +pub struct SlaveFsCacheReq { + // underlying Unix domain socket for communication + node: Arc<Mutex<SlaveFsCacheReqInternal>>, +} + +impl SlaveFsCacheReq { + fn new(ep: Endpoint<SlaveReq>) -> Self { + SlaveFsCacheReq { + node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { + sock: ep, + reply_ack_negotiated: false, + error: None, + })), } } + fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> { + self.node.lock().unwrap() + } + + fn send_message( + &self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> io::Result<u64> { + self.node() + .send_message(request, fs, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::<SlaveReq>::from_stream(sock)) + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&self, enable: bool) { + self.node().reply_ack_negotiated = enable; + } + /// Mark endpoint as failed with specified error code. - pub fn set_failed(&mut self, error: i32) { - self.error = Some(error); + pub fn set_failed(&self, error: i32) { + self.node().error = Some(error); } } impl VhostUserMasterReqHandler for SlaveFsCacheReq { - /// Handle virtio-fs map file requests from the slave. - fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + /// Forward vhost-user-fs map file requests to the slave. + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) - .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) } - /// Handle virtio-fs unmap file requests from the slave. - fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + /// Forward vhost-user-fs unmap file requests to the master. + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { self.send_message(SlaveReq::FS_UNMAP, fs, None) - .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + + #[test] + fn test_slave_fs_cache_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + assert!(fs_cache.node().error.is_none()); + fs_cache.set_failed(libc::EAGAIN); + assert_eq!(fs_cache.node().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_slave_fs_cache_send_failure() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + fs_cache.set_failed(libc::ECONNRESET); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + fs_cache.node().error = None; + + drop(p2); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } + + #[test] + fn test_slave_fs_cache_recv_negative() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + let mut master = Endpoint::<SlaveReq>::from_stream(p2); + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new( + SlaveReq::FS_MAP, + VhostUserHeaderFlag::REPLY.bits(), + len as u32, + ); + let body = VhostUserU64::new(0); + + master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + hdr.set_code(SlaveReq::FS_UNMAP); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + hdr.set_code(SlaveReq::FS_MAP); + + let body = VhostUserU64::new(1); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + let body = VhostUserU64::new(0); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); } } diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index f3b0770..3b44e4c 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -1,8 +1,6 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Traits and Structs to handle vhost-user requests from the master to the slave. - use std::mem; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -14,9 +12,63 @@ use super::message::*; use super::slave_fs_cache::SlaveFsCacheReq; use super::{Error, Result}; -/// Trait to handle vhost-user requests from the master to the slave. +/// Services provided to the master by the slave with interior mutability. +/// +/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. +/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], +/// but without interior mutability. +/// The vhost-user specification defines a master communication channel, by which masters could +/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by +/// slaves, and it's used both on the master side and slave side. +/// +/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy +/// service requests to slaves. +/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler +/// implementing [VhostUserSlaveReqHandler]. +/// +/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html #[allow(missing_docs)] pub trait VhostUserSlaveReqHandler { + fn set_owner(&self) -> Result<()>; + fn reset_owner(&self) -> Result<()>; + fn get_features(&self) -> Result<u64>; + fn set_features(&self, features: u64) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&self, features: u64) -> Result<()>; + fn get_queue_num(&self) -> Result<u64>; + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} +} + +/// Services provided to the master by the slave without interior mutability. +/// +/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandlerMut { fn set_owner(&mut self) -> Result<()>; fn reset_owner(&mut self) -> Result<()>; fn get_features(&mut self) -> Result<u64>; @@ -52,16 +104,110 @@ pub trait VhostUserSlaveReqHandler { fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} } -/// A vhost-user slave endpoint which relays all received requests from the -/// master to the virtio backend device object. +impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { + fn set_owner(&self) -> Result<()> { + self.lock().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.lock().unwrap().reset_owner() + } + + fn get_features(&self) -> Result<u64> { + self.lock().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_features(features) + } + + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, fds) + } + + fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { + self.lock().unwrap().set_vring_num(index, num) + } + + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()> { + self.lock() + .unwrap() + .set_vring_addr(index, flags, descriptor, used, available, log) + } + + fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { + self.lock().unwrap().set_vring_base(index, base) + } + + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { + self.lock().unwrap().get_vring_base(index) + } + + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_kick(index, fd) + } + + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_call(index, fd) + } + + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_err(index, fd) + } + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { + self.lock().unwrap().get_protocol_features() + } + + fn set_protocol_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_protocol_features(features) + } + + fn get_queue_num(&self) -> Result<u64> { + self.lock().unwrap().get_queue_num() + } + + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { + self.lock().unwrap().set_vring_enable(index, enable) + } + + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { + self.lock().unwrap().get_config(offset, size, flags) + } + + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { + self.lock().unwrap().set_config(offset, buf, flags) + } + + fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + self.lock().unwrap().set_slave_req_fd(vu_req) + } +} + +/// Server to handle service requests from masters from the master communication channel. +/// +/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from +/// masters on the master communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserSlaveReqHandler] to do the real work. /// /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain /// Socket, so it gets simpler to recover from disconnect. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { // underlying Unix domain socket for communication main_sock: Endpoint<MasterReq>, // the vhost-user backend device object - backend: Arc<Mutex<S>>, + backend: Arc<S>, virtio_features: u64, acked_virtio_features: u64, @@ -76,7 +222,7 @@ pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { /// Create a vhost-user slave endpoint. - pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self { + pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self { SlaveReqHandler { main_sock, backend, @@ -94,7 +240,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { /// # Arguments /// * - `path` - path of Unix domain socket listener to connect to /// * - `backend` - handler for requests from the master to the slave - pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> { + pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) } @@ -103,11 +249,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { self.error = Some(error); } - /// Receive and handle one incoming request message from the master. - /// The caller needs to: - /// . serialize calls to this function - /// . decide what to do when error happens - /// . optional recover from failure + /// Main entrance to server slave request from the slave communication channel. + /// + /// Receive and handle one incoming request message from the master. The caller needs to: + /// - serialize calls to this function + /// - decide what to do when error happens + /// - optional recover from failure pub fn handle_request(&mut self) -> Result<()> { // Return error if the endpoint is already in failed state. self.check_state()?; @@ -137,15 +284,15 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { match hdr.get_code() { MasterReq::SET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.lock().unwrap().set_owner()?; + self.backend.set_owner()?; } MasterReq::RESET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.lock().unwrap().reset_owner()?; + self.backend.reset_owner()?; } MasterReq::GET_FEATURES => { self.check_request_size(&hdr, size, 0)?; - let features = self.backend.lock().unwrap().get_features()?; + let features = self.backend.get_features()?; let msg = VhostUserU64::new(features); self.send_reply_message(&hdr, &msg)?; self.virtio_features = features; @@ -153,7 +300,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend.lock().unwrap().set_features(msg.value)?; + self.backend.set_features(msg.value)?; self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); } @@ -163,11 +310,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_VRING_NUM => { let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; - let res = self - .backend - .lock() - .unwrap() - .set_vring_num(msg.index, msg.num); + let res = self.backend.set_vring_num(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ADDR => { @@ -176,7 +319,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Some(val) => val, None => return Err(Error::InvalidMessage), }; - let res = self.backend.lock().unwrap().set_vring_addr( + let res = self.backend.set_vring_addr( msg.index, flags, msg.descriptor, @@ -188,39 +331,35 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_VRING_BASE => { let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; - let res = self - .backend - .lock() - .unwrap() - .set_vring_base(msg.index, msg.num); + let res = self.backend.set_vring_base(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } MasterReq::GET_VRING_BASE => { let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; - let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?; + let reply = self.backend.get_vring_base(msg.index)?; self.send_reply_message(&hdr, &reply)?; } MasterReq::SET_VRING_CALL => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_call(index, rfds); + let res = self.backend.set_vring_call(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_KICK => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_kick(index, rfds); + let res = self.backend.set_vring_kick(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ERR => { self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_err(index, rfds); + let res = self.backend.set_vring_err(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::GET_PROTOCOL_FEATURES => { self.check_request_size(&hdr, size, 0)?; - let features = self.backend.lock().unwrap().get_protocol_features()?; + let features = self.backend.get_protocol_features()?; let msg = VhostUserU64::new(features.bits()); self.send_reply_message(&hdr, &msg)?; self.protocol_features = features; @@ -228,10 +367,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } MasterReq::SET_PROTOCOL_FEATURES => { let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; - self.backend - .lock() - .unwrap() - .set_protocol_features(msg.value)?; + self.backend.set_protocol_features(msg.value)?; self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); } @@ -240,7 +376,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, 0)?; - let num = self.backend.lock().unwrap().get_queue_num()?; + let num = self.backend.get_queue_num()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } @@ -257,17 +393,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { _ => return Err(Error::InvalidParam), }; - let res = self - .backend - .lock() - .unwrap() - .set_vring_enable(msg.index, enable); + let res = self.backend.set_vring_enable(msg.index, enable); self.send_ack_message(&hdr, res)?; } MasterReq::GET_CONFIG => { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } MasterReq::SET_CONFIG => { @@ -281,6 +414,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.set_slave_req_fd(&hdr, rfds)?; } _ => { @@ -341,15 +475,18 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } - self.backend.lock().unwrap().set_mem_table(®ions, &fds) + self.backend.set_mem_table(®ions, &fds) } fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let payload_offset = mem::size_of::<VhostUserConfig>(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } - let payload_offset = mem::size_of::<VhostUserConfig>(); if buf.len() - payload_offset != msg.size as usize { return Err(Error::InvalidMessage); } @@ -357,11 +494,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Some(val) => val, None => return Err(Error::InvalidMessage), }; - let res = self - .backend - .lock() - .unwrap() - .get_config(msg.offset, msg.size, flags); + let res = self.backend.get_config(msg.offset, msg.size, flags); // vhost-user slave's payload size MUST match master's request // on success, uses zero length of payload to indicate an error @@ -389,10 +522,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { size: usize, buf: &[u8], ) -> Result<()> { - if size < mem::size_of::<VhostUserConfig>() { + if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { return Err(Error::InvalidMessage); } - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -405,11 +538,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { None => return Err(Error::InvalidMessage), } - let res = self - .backend - .lock() - .unwrap() - .set_config(msg.offset, buf, flags); + let res = self.backend.set_config(msg.offset, buf, flags); self.send_ack_message(&hdr, res)?; Ok(()) } @@ -423,7 +552,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { if fds.len() == 1 { let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; let vu_req = SlaveFsCacheReq::from_stream(sock); - self.backend.lock().unwrap().set_slave_req_fd(vu_req); + self.backend.set_slave_req_fd(vu_req); self.send_ack_message(&hdr, Ok(())) } else { Err(Error::InvalidMessage) @@ -438,7 +567,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { buf: &[u8], rfds: Option<Vec<RawFd>>, ) -> Result<(u8, Option<RawFd>)> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) }; + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -447,10 +579,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { // invalid FD flag. This flag is set when there is no file descriptor // in the ancillary data. This signals that polling will be used // instead of waiting for the call. - let nofd = match msg.value & 0x100u64 { - 0x100u64 => true, - _ => false, - }; + let nofd = (msg.value & 0x100u64) == 0x100u64; let mut rfd = None; match rfds { @@ -519,14 +648,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } - fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>( + fn extract_request_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, - buf: &'a [u8], - ) -> Result<&'a T> { + buf: &[u8], + ) -> Result<T> { self.check_request_size(hdr, size, mem::size_of::<T>())?; - let msg = unsafe { &*(buf.as_ptr() as *const T) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -552,7 +681,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>> { - if mem::size_of::<T>() > MAX_MSG_SIZE { + if mem::size_of::<T>() > MAX_MSG_SIZE + || payload_size > MAX_MSG_SIZE + || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE + { return Err(Error::InvalidParam); } self.check_state()?; @@ -568,7 +700,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { req: &VhostUserMsgHeader<MasterReq>, res: Result<()>, ) -> Result<()> { - if self.reply_ack_enabled { + if self.reply_ack_enabled && req.is_need_reply() { let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; let val = match res { Ok(_) => 0, @@ -590,16 +722,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { Ok(()) } - fn send_reply_with_payload<T, P>( + fn send_reply_with_payload<T: Sized>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, - payload: &[P], - ) -> Result<()> - where - T: Sized, - P: Sized, - { + payload: &[u8], + ) -> Result<()> { let hdr = self.new_reply_header::<T>(req, payload.len())?; self.main_sock .send_message_with_payload(&hdr, msg, payload, None)?; @@ -612,3 +740,24 @@ impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { self.main_sock.as_raw_fd() } } + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_req_handler_new() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let endpoint = Endpoint::<MasterReq>::from_stream(p1); + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let mut handler = SlaveReqHandler::new(endpoint, backend); + + handler.check_state().unwrap(); + handler.set_failed(libc::EAGAIN); + handler.check_state().unwrap_err(); + assert!(handler.as_raw_fd() >= 0); + } +} diff --git a/src/vhost_user/sock_ctrl_msg.rs b/src/vhost_user/sock_ctrl_msg.rs deleted file mode 100644 index db3ec2e..0000000 --- a/src/vhost_user/sock_ctrl_msg.rs +++ /dev/null @@ -1,499 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -//! Used to send and receive messages with file descriptors on sockets that accept control messages -//! (e.g. Unix domain sockets). - -// TODO: move this file into the vmm-sys-util crate - -use std::fs::File; -use std::mem::size_of; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use std::os::unix::net::{UnixDatagram, UnixStream}; -use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; - -use libc::{ - c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, -}; -use vmm_sys_util::errno::{Error, Result}; - -// Each of the following macros performs the same function as their C counterparts. They are each -// macros because they are used to size statically allocated arrays. - -macro_rules! CMSG_ALIGN { - ($len:expr) => { - (($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1) - }; -} - -macro_rules! CMSG_SPACE { - ($len:expr) => { - size_of::<cmsghdr>() + CMSG_ALIGN!($len) - }; -} - -#[cfg(not(target_env = "musl"))] -macro_rules! CMSG_LEN { - ($len:expr) => { - size_of::<cmsghdr>() + ($len) - }; -} - -#[cfg(target_env = "musl")] -macro_rules! CMSG_LEN { - ($len:expr) => {{ - let sz = size_of::<cmsghdr>() + ($len); - assert!(sz <= (std::u32::MAX as usize)); - sz as u32 - }}; -} - -#[cfg(not(target_env = "musl"))] -fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { - msghdr { - msg_name: null_mut(), - msg_namelen: 0, - msg_iov: iovecs.as_mut_ptr(), - msg_iovlen: iovecs.len(), - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - } -} - -#[cfg(target_env = "musl")] -fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { - assert!(iovecs.len() <= (std::i32::MAX as usize)); - let mut msg: msghdr = unsafe { std::mem::zeroed() }; - msg.msg_name = null_mut(); - msg.msg_iov = iovecs.as_mut_ptr(); - msg.msg_iovlen = iovecs.len() as i32; - msg.msg_control = null_mut(); - msg -} - -#[cfg(not(target_env = "musl"))] -fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { - msg.msg_controllen = cmsg_capacity; -} - -#[cfg(target_env = "musl")] -fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { - assert!(cmsg_capacity <= (std::u32::MAX as usize)); - msg.msg_controllen = cmsg_capacity as u32; -} - -// This function (macro in the C version) is not used in any compile time constant slots, so is just -// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this -// module supports. -#[allow(non_snake_case)] -#[inline(always)] -fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { - // Essentially returns a pointer to just past the header. - cmsg_buffer.wrapping_offset(1) as *mut RawFd -} - -// This function is like CMSG_NEXT, but safer because it reads only from references, although it -// does some pointer arithmetic on cmsg_ptr. -#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))] -fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr { - let next_cmsg = - (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len as usize)) as *mut cmsghdr; - if next_cmsg - .wrapping_offset(1) - .wrapping_sub(msghdr.msg_control as usize) as usize - > msghdr.msg_controllen as usize - { - null_mut() - } else { - next_cmsg - } -} - -const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32); - -enum CmsgBuffer { - Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]), - Heap(Box<[cmsghdr]>), -} - -impl CmsgBuffer { - fn with_capacity(capacity: usize) -> CmsgBuffer { - let cap_in_cmsghdr_units = - (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>(); - if capacity <= CMSG_BUFFER_INLINE_CAPACITY { - CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]) - } else { - CmsgBuffer::Heap( - vec![ - cmsghdr { - cmsg_len: 0, - cmsg_level: 0, - cmsg_type: 0, - #[cfg(target_env = "musl")] - __pad1: 0, - }; - cap_in_cmsghdr_units - ] - .into_boxed_slice(), - ) - } - } - - fn as_mut_ptr(&mut self) -> *mut cmsghdr { - match self { - CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr, - CmsgBuffer::Heap(a) => a.as_mut_ptr(), - } - } -} - -fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> { - let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len()); - let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - - let mut iovecs = Vec::with_capacity(out_data.len()); - for data in out_data { - iovecs.push(iovec { - iov_base: data.as_ptr() as *mut c_void, - iov_len: data.size(), - }); - } - - let mut msg = new_msghdr(&mut iovecs); - - if !out_fds.is_empty() { - let cmsg = cmsghdr { - cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()), - cmsg_level: SOL_SOCKET, - cmsg_type: SCM_RIGHTS, - #[cfg(target_env = "musl")] - __pad1: 0, - }; - unsafe { - // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. - write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); - // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() - // file descriptors. - copy_nonoverlapping( - out_fds.as_ptr(), - CMSG_DATA(cmsg_buffer.as_mut_ptr()), - out_fds.len(), - ); - } - - msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; - set_msg_controllen(&mut msg, cmsg_capacity); - } - - // Safe because the msghdr was properly constructed from valid (or null) pointers of the - // indicated length and we check the return value. - let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; - - if write_count == -1 { - Err(Error::last()) - } else { - Ok(write_count as usize) - } -} - -fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { - let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len()); - let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - let mut msg = new_msghdr(iovecs); - - if !in_fds.is_empty() { - msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; - set_msg_controllen(&mut msg, cmsg_capacity); - } - - // Safe because the msghdr was properly constructed from valid (or null) pointers of the - // indicated length and we check the return value. - let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) }; - - if total_read == -1 { - return Err(Error::last()); - } - - // When the connection is closed recvmsg() doesn't give an explicit error - if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() { - return Err(Error::new(libc::ECONNRESET)); - } - - let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; - let mut in_fds_count = 0; - while !cmsg_ptr.is_null() { - // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that - // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. - let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; - - if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { - let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) as usize / size_of::<RawFd>(); - unsafe { - copy_nonoverlapping( - CMSG_DATA(cmsg_ptr), - in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), - fd_count, - ); - } - in_fds_count += fd_count; - } - - cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); - } - - Ok((total_read as usize, in_fds_count)) -} - -/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and -/// `recvmsg`. -pub trait ScmSocket { - /// Gets the file descriptor of this socket. - fn socket_fd(&self) -> RawFd; - - /// Sends the given data and file descriptor over the socket. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `buf` - A buffer of data to send on the `socket`. - /// * `fd` - A file descriptors to be sent. - fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> { - self.send_with_fds(&[buf], &[fd]) - } - - /// Sends the given data and file descriptors over the socket. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `bufs` - A list of data buffer to send on the `socket`. - /// * `fds` - A list of file descriptors to be sent. - fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> { - raw_sendmsg(self.socket_fd(), bufs, fds) - } - - /// Receives data and potentially a file descriptor from the socket. - /// - /// On success, returns the number of bytes and an optional file descriptor. - /// - /// # Arguments - /// - /// * `buf` - A buffer to receive data from the socket. - fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> { - let mut fd = [0]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - - let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?; - let file = if fd_count == 0 { - None - } else { - // Safe because the first fd from recv_with_fds is owned by us and valid because this - // branch was taken. - Some(unsafe { File::from_raw_fd(fd[0]) }) - }; - Ok((read_count, file)) - } - - /// Receives data and file descriptors from the socket. - /// - /// On success, returns the number of bytes and file descriptors received as a tuple - /// `(bytes count, files count)`. - /// - /// # Arguments - /// - /// * `iovecs` - A list of iovec to receive data from the socket. - /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the - /// number of valid file descriptors is indicated by the second element of the - /// returned tuple. The caller owns these file descriptors, but they will not be - /// closed on drop like a `File`-like type would be. It is recommended that each valid - /// file descriptor gets wrapped in a drop type that closes it after this returns. - fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> { - raw_recvmsg(self.socket_fd(), iovecs, fds) - } -} - -impl ScmSocket for UnixDatagram { - fn socket_fd(&self) -> RawFd { - self.as_raw_fd() - } -} - -impl ScmSocket for UnixStream { - fn socket_fd(&self) -> RawFd { - self.as_raw_fd() - } -} - -/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for -/// the lifetime of this object. -/// -/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size -/// being accurate. -pub unsafe trait IntoIovec { - /// Gets the base pointer of this `iovec`. - fn as_ptr(&self) -> *const c_void; - - /// Gets the size in bytes of this `iovec`. - fn size(&self) -> usize; -} - -// Safe because this slice can not have another mutable reference and it's pointer and size are -// guaranteed to be valid. -unsafe impl<'a> IntoIovec for &'a [u8] { - // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480 - #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))] - fn as_ptr(&self) -> *const c_void { - self.as_ref().as_ptr() as *const c_void - } - - fn size(&self) -> usize { - self.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::io::Write; - use std::mem::size_of; - use std::os::raw::c_long; - use std::os::unix::net::UnixDatagram; - use std::slice::from_raw_parts; - - use libc::cmsghdr; - - use vmm_sys_util::eventfd::EventFd; - - #[test] - fn buffer_len() { - assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>()); - assert_eq!( - CMSG_SPACE!(1 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() - ); - if size_of::<RawFd>() == 4 { - assert_eq!( - CMSG_SPACE!(2 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() * 2 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() * 2 - ); - } else if size_of::<RawFd>() == 8 { - assert_eq!( - CMSG_SPACE!(2 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() * 2 - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() * 3 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::<RawFd>()), - size_of::<cmsghdr>() + size_of::<c_long>() * 4 - ); - } - } - - #[test] - fn send_recv_no_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let write_count = s1 - .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[]) - .expect("failed to send data"); - - assert_eq!(write_count, 6); - - let mut buf = [0u8; 6]; - let mut files = [0; 1]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - let (read_count, file_count) = s2 - .recv_with_fds(&mut iovecs[..], &mut files) - .expect("failed to recv data"); - - assert_eq!(read_count, 6); - assert_eq!(file_count, 0); - assert_eq!(buf, [1, 1, 2, 21, 34, 55]); - } - - #[test] - fn send_recv_only_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let evt = EventFd::new(0).expect("failed to create eventfd"); - let write_count = s1 - .send_with_fd([].as_ref(), evt.as_raw_fd()) - .expect("failed to send fd"); - - assert_eq!(write_count, 0); - - let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); - - let mut file = file_opt.unwrap(); - - assert_eq!(read_count, 0); - assert!(file.as_raw_fd() >= 0); - assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); - assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); - assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); - - file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) - .expect("failed to write to sent fd"); - - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); - } - - #[test] - fn send_recv_with_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let evt = EventFd::new(0).expect("failed to create eventfd"); - let write_count = s1 - .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()]) - .expect("failed to send fd"); - - assert_eq!(write_count, 1); - - let mut files = [0; 2]; - let mut buf = [0u8]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - let (read_count, file_count) = s2 - .recv_with_fds(&mut iovecs[..], &mut files) - .expect("failed to recv fd"); - - assert_eq!(read_count, 1); - assert_eq!(buf[0], 237); - assert_eq!(file_count, 1); - assert!(files[0] >= 0); - assert_ne!(files[0], s1.as_raw_fd()); - assert_ne!(files[0], s2.as_raw_fd()); - assert_ne!(files[0], evt.as_raw_fd()); - - let mut file = unsafe { File::from_raw_fd(files[0]) }; - - file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) - .expect("failed to write to sent fd"); - - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); - } -} |