From 8c6919bf60bd641398ddd53864fbc74d75548837 Mon Sep 17 00:00:00 2001 From: "dependabot-preview[bot]" <27856297+dependabot-preview[bot]@users.noreply.github.com> Date: Wed, 24 Feb 2021 07:08:05 +0000 Subject: build(deps): bump rust-vmm-ci from `e58ea74` to `ebc7016` Bumps [rust-vmm-ci](https://github.com/rust-vmm/rust-vmm-ci) from `e58ea74` to `ebc7016`. - [Release notes](https://github.com/rust-vmm/rust-vmm-ci/releases) - [Commits](https://github.com/rust-vmm/rust-vmm-ci/compare/e58ea7445ace0cb984f8002ba2436c34cf592efe...ebc701641fa57f78d03f3f5ecac617b7bf7470b4) Signed-off-by: dependabot-preview[bot] --- rust-vmm-ci | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust-vmm-ci b/rust-vmm-ci index e58ea74..ebc7016 160000 --- a/rust-vmm-ci +++ b/rust-vmm-ci @@ -1 +1 @@ -Subproject commit e58ea7445ace0cb984f8002ba2436c34cf592efe +Subproject commit ebc701641fa57f78d03f3f5ecac617b7bf7470b4 -- cgit v1.2.3 From d748d5bdcf70b7d565a0600a57e8cf08976babb4 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 18:59:54 +0800 Subject: Introduce VhostBackendMut trait Originally the VhostBackend trait is designed to take a mutable self, which causes a common usage pattern Arc>. This pattern may enforce serialization among multiple threads. So rename the original VhostBackend as VhostBackendMut, and introduce a new VhostBackend trait with interior mutability to improve performance by removing the serialization. Signed-off-by: Liu Jiang --- src/backend.rs | 328 ++++++++++++++++++++++++++++++++++++++++++++++- src/vhost_kern/mod.rs | 28 ++-- src/vhost_user/master.rs | 32 ++--- src/vhost_user/mod.rs | 2 +- 4 files changed, 354 insertions(+), 36 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index 2d1a4a2..9dafef7 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause // // Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. @@ -9,14 +9,18 @@ //! Common traits and structs for vhost-kern and vhost-user backend drivers. -use super::Result; +use std::cell::RefCell; use std::os::unix::io::RawFd; +use std::sync::RwLock; + use vmm_sys_util::eventfd::EventFd; +use super::Result; + /// Maximum number of memory regions supported. pub const VHOST_MAX_MEMORY_REGIONS: usize = 255; -/// Vring/virtque configuration data. +/// Vring configuration data. pub struct VringConfigData { /// Maximum queue size supported by the driver. pub queue_max_size: u16, @@ -65,21 +69,108 @@ pub struct VhostUserMemoryRegionInfo { pub userspace_addr: u64, /// Optional offset where region starts in the mapped memory. pub mmap_offset: u64, - /// Optional file diescriptor for mmap + /// Optional file descriptor for mmap. pub mmap_handle: RawFd, } -/// An interface for setting up vhost-based backend drivers. +/// An interface for setting up vhost-based backend drivers with interior mutability. /// /// Vhost devices are subset of virtio devices, which improve virtio device's performance by /// delegating data plane operations to dedicated IO service processes. Vhost devices use the /// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to /// virtio devices. +/// /// The purpose of vhost is to implement a subset of a virtio device's functionality outside the /// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service /// processes, and slow path for device configuration are still handled by the VMM process. It may /// also be used to control access permissions of virtio backend devices. pub trait VhostBackend: std::marker::Sized { + /// Get a bitmask of supported virtio/vhost features. + fn get_features(&self) -> Result; + + /// Inform the vhost subsystem which features to enable. + /// This should be a subset of supported features from get_features(). + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_features(&self, features: u64) -> Result<()>; + + /// Set the current process as the owner of the vhost backend. + /// This must be run before any other vhost commands. + fn set_owner(&self) -> Result<()>; + + /// Used to be sent to request disabling all rings + /// This is no longer used. + fn reset_owner(&self) -> Result<()>; + + /// Set the guest memory mappings for vhost to use. + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>; + + /// Set base address for page modification logging. + fn set_log_base(&self, base: u64, fd: Option) -> Result<()>; + + /// Specify an eventfd file descriptor to signal on log write. + fn set_log_fd(&self, fd: RawFd) -> Result<()>; + + /// Set the number of descriptors in the vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set descriptor count for. + /// * `num` - Number of descriptors in the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>; + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Configuration data for a vring. + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()>; + + /// Set the first index to look for available descriptors. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `num` - Index where available descriptors start. + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()>; + + /// Get the available vring base offset. + fn get_vring_base(&self, queue_index: usize) -> Result; + + /// Set the eventfd to trigger when buffers have been used by the host. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd to trigger. + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the guest when buffers are + /// available for the host to process. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest. + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the guest when error happens. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest. + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()>; +} + +/// An interface for setting up vhost-based backend drivers. +/// +/// Vhost devices are subset of virtio devices, which improve virtio device's performance by +/// delegating data plane operations to dedicated IO service processes. Vhost devices use the +/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to +/// virtio devices. +/// +/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the +/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service +/// processes, and slow path for device configuration are still handled by the VMM process. It may +/// also be used to control access permissions of virtio backend devices. +pub trait VhostBackendMut: std::marker::Sized { /// Get a bitmask of supported virtio/vhost features. fn get_features(&mut self) -> Result; @@ -154,10 +245,237 @@ pub trait VhostBackend: std::marker::Sized { fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>; } +impl VhostBackend for RwLock { + fn get_features(&self) -> Result { + self.write().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.write().unwrap().set_features(features) + } + + fn set_owner(&self) -> Result<()> { + self.write().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.write().unwrap().reset_owner() + } + + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + self.write().unwrap().set_mem_table(regions) + } + + fn set_log_base(&self, base: u64, fd: Option) -> Result<()> { + self.write().unwrap().set_log_base(base, fd) + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + self.write().unwrap().set_log_fd(fd) + } + + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + self.write().unwrap().set_vring_num(queue_index, num) + } + + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + self.write() + .unwrap() + .set_vring_addr(queue_index, config_data) + } + + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + self.write().unwrap().set_vring_base(queue_index, base) + } + + fn get_vring_base(&self, queue_index: usize) -> Result { + self.write().unwrap().get_vring_base(queue_index) + } + + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_call(queue_index, fd) + } + + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_kick(queue_index, fd) + } + + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_err(queue_index, fd) + } +} + +impl VhostBackend for RefCell { + fn get_features(&self) -> Result { + self.borrow_mut().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.borrow_mut().set_features(features) + } + + fn set_owner(&self) -> Result<()> { + self.borrow_mut().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.borrow_mut().reset_owner() + } + + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + self.borrow_mut().set_mem_table(regions) + } + + fn set_log_base(&self, base: u64, fd: Option) -> Result<()> { + self.borrow_mut().set_log_base(base, fd) + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + self.borrow_mut().set_log_fd(fd) + } + + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + self.borrow_mut().set_vring_num(queue_index, num) + } + + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + self.borrow_mut().set_vring_addr(queue_index, config_data) + } + + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + self.borrow_mut().set_vring_base(queue_index, base) + } + + fn get_vring_base(&self, queue_index: usize) -> Result { + self.borrow_mut().get_vring_base(queue_index) + } + + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_call(queue_index, fd) + } + + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_kick(queue_index, fd) + } + + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_err(queue_index, fd) + } +} #[cfg(test)] mod tests { use VringConfigData; + struct MockBackend {} + + impl VhostBackendMut for MockBackend { + fn get_features(&mut self) -> Result { + Ok(0x1) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + assert_eq!(features, 0x1); + Ok(()) + } + + fn set_owner(&mut self) -> Result<()> { + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + Ok(()) + } + + fn set_mem_table(&mut self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + Ok(()) + } + + fn set_log_base(&mut self, base: u64, fd: Option) -> Result<()> { + assert_eq!(base, 0x100); + assert_eq!(fd, Some(100)); + Ok(()) + } + + fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { + assert_eq!(fd, 100); + Ok(()) + } + + fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { + assert_eq!(queue_index, 1); + assert_eq!(num, 256); + Ok(()) + } + + fn set_vring_addr( + &mut self, + queue_index: usize, + _config_data: &VringConfigData, + ) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { + assert_eq!(queue_index, 1); + assert_eq!(base, 2); + Ok(()) + } + + fn get_vring_base(&mut self, queue_index: usize) -> Result { + assert_eq!(queue_index, 1); + Ok(2) + } + + fn set_vring_call(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_kick(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_err(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + } + + #[test] + fn test_vring_backend_mut() { + let b = RwLock::new(MockBackend {}); + + assert_eq!(b.get_features().unwrap(), 0x1); + b.set_features(0x1).unwrap(); + b.set_owner().unwrap(); + b.reset_owner().unwrap(); + b.set_mem_table(&[]).unwrap(); + b.set_log_base(0x100, Some(100)).unwrap(); + b.set_log_fd(100).unwrap(); + b.set_vring_num(1, 256).unwrap(); + + let config = VringConfigData { + queue_max_size: 0x1000, + queue_size: 0x2000, + flags: 0x0, + desc_table_addr: 0x4000, + used_ring_addr: 0x5000, + avail_ring_addr: 0x6000, + log_addr: None, + }; + b.set_vring_addr(1, &config).unwrap(); + + b.set_vring_base(1, 2).unwrap(); + assert_eq!(b.get_vring_base(1).unwrap(), 2); + + let eventfd = EventFd::new(0).unwrap(); + b.set_vring_call(1, &eventfd).unwrap(); + b.set_vring_kick(1, &eventfd).unwrap(); + b.set_vring_err(1, &eventfd).unwrap(); + } + #[test] fn test_vring_config_data() { let mut config = VringConfigData { diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs index 350e134..248cbae 100644 --- a/src/vhost_kern/mod.rs +++ b/src/vhost_kern/mod.rs @@ -87,20 +87,20 @@ pub trait VhostKernBackend: AsRawFd { impl VhostBackend for T { /// Set the current process as the owner of this file descriptor. /// This must be run before any other vhost ioctls. - fn set_owner(&mut self) -> Result<()> { + fn set_owner(&self) -> Result<()> { // This ioctl is called on a valid vhost fd and has its return value checked. let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) }; ioctl_result(ret, ()) } - fn reset_owner(&mut self) -> Result<()> { + fn reset_owner(&self) -> Result<()> { // This ioctl is called on a valid vhost fd and has its return value checked. let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) }; ioctl_result(ret, ()) } /// Get a bitmask of supported virtio/vhost features. - fn get_features(&mut self) -> Result { + fn get_features(&self) -> Result { let mut avail_features: u64 = 0; // This ioctl is called on a valid vhost fd and has its return value checked. let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) }; @@ -112,14 +112,14 @@ impl VhostBackend for T { /// /// # Arguments /// * `features` - Bitmask of features to set. - fn set_features(&mut self, features: u64) -> Result<()> { + fn set_features(&self, features: u64) -> Result<()> { // This ioctl is called on a valid vhost fd and has its return value checked. let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) }; ioctl_result(ret, ()) } /// Set the guest memory mappings for vhost to use. - fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS { return Err(Error::InvalidGuestMemory); } @@ -148,7 +148,7 @@ impl VhostBackend for T { /// /// # Arguments /// * `base` - Base address for page modification logging. - fn set_log_base(&mut self, base: u64, fd: Option) -> Result<()> { + fn set_log_base(&self, base: u64, fd: Option) -> Result<()> { if fd.is_some() { return Err(Error::LogAddress); } @@ -159,7 +159,7 @@ impl VhostBackend for T { } /// Specify an eventfd file descriptor to signal on log write. - fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { + fn set_log_fd(&self, fd: RawFd) -> Result<()> { // This ioctl is called on a valid vhost fd and has its return value checked. let val: i32 = fd; let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) }; @@ -171,7 +171,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to set descriptor count for. /// * `num` - Number of descriptors in the queue. - fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { let vring_state = vhost_vring_state { index: queue_index as u32, num: u32::from(num), @@ -187,7 +187,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to set addresses for. /// * `config_data` - Vring config data. - fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { if !self.is_valid(config_data) { return Err(Error::InvalidQueue); } @@ -212,7 +212,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to modify. /// * `num` - Index where available descriptors start. - fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { let vring_state = vhost_vring_state { index: queue_index as u32, num: u32::from(base), @@ -224,7 +224,7 @@ impl VhostBackend for T { } /// Get a bitmask of supported virtio/vhost features. - fn get_vring_base(&mut self, queue_index: usize) -> Result { + fn get_vring_base(&self, queue_index: usize) -> Result { let vring_state = vhost_vring_state { index: queue_index as u32, num: 0, @@ -239,7 +239,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to modify. /// * `fd` - EventFd to trigger. - fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let vring_file = vhost_vring_file { index: queue_index as u32, fd: fd.as_raw_fd(), @@ -256,7 +256,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to modify. /// * `fd` - EventFd that will be signaled from guest. - fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let vring_file = vhost_vring_file { index: queue_index as u32, fd: fd.as_raw_fd(), @@ -272,7 +272,7 @@ impl VhostBackend for T { /// # Arguments /// * `queue_index` - Index of the queue to modify. /// * `fd` - EventFd that will be signaled from the backend. - fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let vring_file = vhost_vring_file { index: queue_index as u32, fd: fd.as_raw_fd(), diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index ffed909..2651b84 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -115,7 +115,7 @@ impl Master { impl VhostBackend for Master { /// Get from the underlying vhost implementation the feature bitmask. - fn get_features(&mut self) -> Result { + fn get_features(&self) -> Result { let mut node = self.node.lock().unwrap(); let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; let val = node.recv_reply::(&hdr)?; @@ -124,7 +124,7 @@ impl VhostBackend for Master { } /// Enable features in the underlying vhost implementation using a bitmask. - fn set_features(&mut self, features: u64) -> Result<()> { + fn set_features(&self, features: u64) -> Result<()> { let mut node = self.node.lock().unwrap(); let val = VhostUserU64::new(features); let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; @@ -135,7 +135,7 @@ impl VhostBackend for Master { } /// Set the current Master as an owner of the session. - fn set_owner(&mut self) -> Result<()> { + fn set_owner(&self) -> Result<()> { // We unwrap() the return value to assert that we are not expecting threads to ever fail // while holding the lock. let mut node = self.node.lock().unwrap(); @@ -145,7 +145,7 @@ impl VhostBackend for Master { Ok(()) } - fn reset_owner(&mut self) -> Result<()> { + fn reset_owner(&self) -> Result<()> { let mut node = self.node.lock().unwrap(); let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been @@ -155,7 +155,7 @@ impl VhostBackend for Master { /// Set the memory map regions on the slave so it can translate the vring /// addresses. In the ancillary data there is an array of file descriptors - fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { return error_code(VhostUserError::InvalidParam); } @@ -187,7 +187,7 @@ impl VhostBackend for Master { // Clippy doesn't seem to know that if let with && is still experimental #[allow(clippy::unnecessary_unwrap)] - fn set_log_base(&mut self, base: u64, fd: Option) -> Result<()> { + fn set_log_base(&self, base: u64, fd: Option) -> Result<()> { let mut node = self.node.lock().unwrap(); let val = VhostUserU64::new(base); @@ -202,7 +202,7 @@ impl VhostBackend for Master { Ok(()) } - fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { + fn set_log_fd(&self, fd: RawFd) -> Result<()> { let mut node = self.node.lock().unwrap(); let fds = [fd]; node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; @@ -210,7 +210,7 @@ impl VhostBackend for Master { } /// Set the size of the queue. - fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -222,7 +222,7 @@ impl VhostBackend for Master { } /// Sets the addresses of the different aspects of the vring. - fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 @@ -236,7 +236,7 @@ impl VhostBackend for Master { } /// Sets the base offset in the available vring. - fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -247,7 +247,7 @@ impl VhostBackend for Master { node.wait_for_ack(&hdr).map_err(|e| e.into()) } - fn get_vring_base(&mut self, queue_index: usize) -> Result { + fn get_vring_base(&self, queue_index: usize) -> Result { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -263,7 +263,7 @@ impl VhostBackend for Master { /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. This signals that polling /// will be used instead of waiting for the call. - fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -276,7 +276,7 @@ impl VhostBackend for Master { /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. This signals that polling /// should be used instead of waiting for a kick. - fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -288,7 +288,7 @@ impl VhostBackend for Master { /// Set the event file descriptor to signal when error occurs. /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. - fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> { + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { let mut node = self.node.lock().unwrap(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); @@ -654,7 +654,7 @@ mod tests { let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap(); + let master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap(); let mut slave = Endpoint::::from_stream(listener.accept().unwrap().unwrap()); // Send two messages continuously @@ -692,7 +692,7 @@ mod tests { #[test] #[ignore] fn test_features() { - let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3); + let (master, mut peer) = create_pair(UNIX_SOCKET_MASTER3); master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 48a93ff..4259d0f 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -203,7 +203,7 @@ mod tests { #[test] fn test_set_owner() { let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); - let (mut master, mut slave) = + let (master, mut slave) = create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone()); assert_eq!(slave_be.lock().unwrap().owned, false); -- cgit v1.2.3 From a7847dcf7def2b750468297e060d0a55be2a7b62 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 19:11:17 +0800 Subject: Update .gitignore file Ignore build, build_kcov and .idea directories under the root. Signed-off-by: Liu Jiang --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 6936990..f738aa8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +/build +/kcov_build /target +.idea **/*.rs.bk Cargo.lock -- cgit v1.2.3 From 97421f754d8d38508687926b5045c4ed1bd6c107 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 19:21:17 +0800 Subject: Upgrade to rust 2018 edition Upgrade Cargo.toml to rust edition 2018. Also introduce a helper feature flag "vhost-user" to simplify code. Signed-off-by: Liu Jiang --- Cargo.toml | 9 +++++-- coverage_config_x86_64.json | 2 +- src/backend.rs | 2 +- src/lib.rs | 62 +++++++++++++++++++++++++++++++++++---------- 4 files changed, 58 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 8c676b3..b8609f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,13 +4,15 @@ version = "0.1.0" authors = ["Liu Jiang "] repository = "https://github.com/rust-vmm/vhost" license = "Apache-2.0 or BSD-3-Clause" +edition = "2018" [features] default = [] vhost-vsock = [] vhost-kern = ["vm-memory"] -vhost-user-master = [] -vhost-user-slave = [] +vhost-user = [] +vhost-user-master = ["vhost-user"] +vhost-user-slave = ["vhost-user"] [dependencies] bitflags = ">=1.0.1" @@ -18,3 +20,6 @@ libc = ">=0.2.39" vmm-sys-util = ">=0.3.1" vm-memory = { version = "0.2.0", optional = true } + +[dev-dependencies] +vm-memory = { version = "0.2.0", features=["backend-mmap"] } diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index ec91006..e9ac51c 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1 +1 @@ -{"coverage_score": 40.2, "exclude_path": "", "crate_features": "vhost-vsock,vhost-kern,vhost-user-master,vhost-user-slave"} +{"coverage_score": 73.3, "exclude_path": "src/vhost_kern/", "crate_features": ""} diff --git a/src/backend.rs b/src/backend.rs index 9dafef7..89fde50 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -364,7 +364,7 @@ impl VhostBackend for RefCell { } #[cfg(test)] mod tests { - use VringConfigData; + use super::*; struct MockBackend {} diff --git a/src/lib.rs b/src/lib.rs index e0cb2b8..a3852a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause //! Virtio Vhost Backend Drivers @@ -32,14 +32,8 @@ #![deny(missing_docs)] -#[cfg_attr( - any(feature = "vhost-user-master", feature = "vhost-user-slave"), - macro_use -)] +#[cfg_attr(feature = "vhost-user", macro_use)] extern crate bitflags; -extern crate libc; -#[cfg(feature = "vhost-kern")] -extern crate vm_memory; #[cfg_attr(feature = "vhost-kern", macro_use)] extern crate vmm_sys_util; @@ -48,7 +42,7 @@ pub use backend::*; #[cfg(feature = "vhost-kern")] pub mod vhost_kern; -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] +#[cfg(feature = "vhost-user")] pub mod vhost_user; #[cfg(feature = "vhost-vsock")] pub mod vsock; @@ -80,7 +74,7 @@ pub enum Error { IoctlError(std::io::Error), /// Error from IO subsystem. IOError(std::io::Error), - #[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] + #[cfg(feature = "vhost-user-master")] /// Error from the vhost-user subsystem. VhostUserProtocol(vhost_user::Error), } @@ -94,20 +88,22 @@ impl std::fmt::Display for Error { Error::InvalidQueue => write!(f, "invalid virtque"), Error::DescriptorTableAddress => write!(f, "invalid virtque descriptor talbe address"), Error::UsedAddress => write!(f, "invalid virtque used talbe address"), - Error::AvailAddress => write!(f, "invalid virtque available talbe address"), + Error::AvailAddress => write!(f, "invalid virtque available table address"), Error::LogAddress => write!(f, "invalid virtque log address"), Error::IOError(e) => write!(f, "IO error: {}", e), #[cfg(feature = "vhost-kern")] Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e), #[cfg(feature = "vhost-kern")] Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e), - #[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] + #[cfg(feature = "vhost-user-master")] Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e), } } } -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] +impl std::error::Error for Error {} + +#[cfg(feature = "vhost-user")] impl std::convert::From for Error { fn from(err: vhost_user::Error) -> Self { Error::VhostUserProtocol(err) @@ -116,3 +112,43 @@ impl std::convert::From for Error { /// Result of vhost operations pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error() { + assert_eq!( + format!("{}", Error::AvailAddress), + "invalid virtque available table address" + ); + assert_eq!( + format!("{}", Error::InvalidOperation), + "invalid vhost operations" + ); + assert_eq!( + format!("{}", Error::InvalidGuestMemory), + "invalid guest memory object" + ); + assert_eq!( + format!("{}", Error::InvalidGuestMemoryRegion), + "invalid guest memory region" + ); + assert_eq!(format!("{}", Error::InvalidQueue), "invalid virtque"); + assert_eq!( + format!("{}", Error::DescriptorTableAddress), + "invalid virtque descriptor talbe address" + ); + assert_eq!( + format!("{}", Error::UsedAddress), + "invalid virtque used talbe address" + ); + assert_eq!( + format!("{}", Error::LogAddress), + "invalid virtque log address" + ); + + assert_eq!(format!("{:?}", Error::AvailAddress), "AvailAddress"); + } +} -- cgit v1.2.3 From 07eee11be3aa7b82fbbb23a6bf4d933e2fa4ae97 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 22:07:48 +0800 Subject: vhost_kern: add more unit test cases Add more unit test cases for vhost-kern submodule. Signed-off-by: Liu Jiang --- src/vhost_kern/mod.rs | 48 ++++++++++----------- src/vhost_kern/vsock.rs | 112 ++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 130 insertions(+), 30 deletions(-) diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs index 248cbae..f263a39 100644 --- a/src/vhost_kern/mod.rs +++ b/src/vhost_kern/mod.rs @@ -13,7 +13,7 @@ use std::os::unix::io::{AsRawFd, RawFd}; -use vm_memory::GuestAddressSpace; +use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize}; use vmm_sys_util::eventfd::EventFd; use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref}; @@ -39,7 +39,7 @@ fn ioctl_result(rc: i32, res: T) -> Result { /// Represent an in-kernel vhost device backend. pub trait VhostKernBackend: AsRawFd { - /// Assoicated type to access guest memory. + /// Associated type to access guest memory. type AS: GuestAddressSpace; /// Get the object to access the guest's memory. @@ -55,50 +55,32 @@ pub trait VhostKernBackend: AsRawFd { return false; } - // TODO: the GuestMemory trait lacks of method to look up GPA by HVA, - // so there's no way to validate HVAs. Please extend vm-memory crate - // first. - /* + let m = self.mem().memory(); let desc_table_size = 16 * u64::from(queue_size) as GuestUsize; let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize; let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize; if GuestAddress(config_data.desc_table_addr) .checked_add(desc_table_size) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !m.address_in_range(v)) { false } else if GuestAddress(config_data.avail_ring_addr) .checked_add(avail_ring_size) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !m.address_in_range(v)) { false } else if GuestAddress(config_data.used_ring_addr) .checked_add(used_ring_size) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !m.address_in_range(v)) { false + } else { + config_data.is_log_addr_valid() } - */ - - config_data.is_log_addr_valid() } } impl VhostBackend for T { - /// Set the current process as the owner of this file descriptor. - /// This must be run before any other vhost ioctls. - fn set_owner(&self) -> Result<()> { - // This ioctl is called on a valid vhost fd and has its return value checked. - let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) }; - ioctl_result(ret, ()) - } - - fn reset_owner(&self) -> Result<()> { - // This ioctl is called on a valid vhost fd and has its return value checked. - let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) }; - ioctl_result(ret, ()) - } - /// Get a bitmask of supported virtio/vhost features. fn get_features(&self) -> Result { let mut avail_features: u64 = 0; @@ -118,6 +100,20 @@ impl VhostBackend for T { ioctl_result(ret, ()) } + /// Set the current process as the owner of this file descriptor. + /// This must be run before any other vhost ioctls. + fn set_owner(&self) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) }; + ioctl_result(ret, ()) + } + + fn reset_owner(&self) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) }; + ioctl_result(ret, ()) + } + /// Set the guest memory mappings for vhost to use. fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS { diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs index c4149bd..7ccf670 100644 --- a/src/vhost_kern/vsock.rs +++ b/src/vhost_kern/vsock.rs @@ -1,22 +1,23 @@ -// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause // // Copyright 2017 The Chromium OS Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE-BSD-Google file. -//! Kernel-based vsock vhost backend. +//! Kernel-based vhost-vsock backend. use std::fs::{File, OpenOptions}; use std::os::unix::fs::OpenOptionsExt; use std::os::unix::io::{AsRawFd, RawFd}; -use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; -use super::{ioctl_result, Error, Result, VhostKernBackend}; use libc; use vm_memory::GuestAddressSpace; use vmm_sys_util::ioctl::ioctl_with_ref; +use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; +use super::{ioctl_result, Error, Result, VhostKernBackend}; + const VHOST_PATH: &str = "/dev/vhost-vsock"; /// Handle for running VHOST_VSOCK ioctls. @@ -79,3 +80,106 @@ impl AsRawFd for Vsock { self.fd.as_raw_fd() } } + +#[cfg(test)] +mod tests { + use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap}; + use vmm_sys_util::eventfd::EventFd; + + use super::*; + use crate::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; + + #[test] + fn test_vsock_new_device() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + assert!(vsock.as_raw_fd() >= 0); + assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some()); + assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none()); + } + + #[test] + fn test_vsock_is_valid() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let mut config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + assert_eq!(vsock.is_valid(&config), true); + + config.queue_size = 0; + assert_eq!(vsock.is_valid(&config), false); + config.queue_size = 31; + assert_eq!(vsock.is_valid(&config), false); + config.queue_size = 33; + assert_eq!(vsock.is_valid(&config), false); + } + + #[test] + fn test_vsock_ioctls() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let features = vsock.get_features().unwrap(); + vsock.set_features(features).unwrap(); + + vsock.set_owner().unwrap(); + + vsock.set_mem_table(&[]).unwrap_err(); + + /* + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: -1, + }; + vsock.set_mem_table(&[region]).unwrap_err(); + */ + + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x0, + memory_size: 0x10_0000, + userspace_addr: m.get_host_address(GuestAddress(0x0)).unwrap() as u64, + mmap_offset: 0, + mmap_handle: -1, + }; + vsock.set_mem_table(&[region]).unwrap(); + + vsock.set_log_base(0x4000, Some(1)).unwrap_err(); + vsock.set_log_base(0x4000, None).unwrap(); + + let eventfd = EventFd::new(0).unwrap(); + vsock.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + vsock.set_vring_num(0, 32).unwrap(); + + let config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + vsock.set_vring_addr(0, &config).unwrap(); + vsock.set_vring_base(0, 1).unwrap(); + vsock.set_vring_call(0, &eventfd).unwrap(); + vsock.set_vring_kick(0, &eventfd).unwrap(); + vsock.set_vring_err(0, &eventfd).unwrap(); + assert_eq!(vsock.get_vring_base(0).unwrap(), 1); + vsock.set_guest_cid(0xdead).unwrap(); + //vsock.start().unwrap(); + //vsock.stop().unwrap(); + } +} -- cgit v1.2.3 From 0a8b2449345d87aca80d3dfcbf2fe183e292199e Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 22:28:52 +0800 Subject: vhost_kern/vsock: implemnt VhostVsock Refine VhostVsock trait definition, and implement it for kernel based vsock driver. Signed-off-by: Liu Jiang --- src/vhost_kern/vsock.rs | 28 ++++++++++++---------------- src/vsock.rs | 6 +++--- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs index 7ccf670..7cc1cf5 100644 --- a/src/vhost_kern/vsock.rs +++ b/src/vhost_kern/vsock.rs @@ -17,6 +17,7 @@ use vmm_sys_util::ioctl::ioctl_with_ref; use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; use super::{ioctl_result, Error, Result, VhostKernBackend}; +use crate::vsock::VhostVsock; const VHOST_PATH: &str = "/dev/vhost-vsock"; @@ -40,31 +41,26 @@ impl Vsock { }) } - /// Set the CID for the guest. This number is used for routing all data destined for - /// running in the guest. Each guest on a hypervisor must have an unique CID - /// - /// # Arguments - /// * `cid` - CID to assign to the guest - pub fn set_guest_cid(&self, cid: u64) -> Result<()> { + fn set_running(&self, running: bool) -> Result<()> { + let on: ::std::os::raw::c_int = if running { 1 } else { 0 }; + let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) }; + ioctl_result(ret, ()) + } +} + +impl VhostVsock for Vsock { + fn set_guest_cid(&self, cid: u64) -> Result<()> { let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) }; ioctl_result(ret, ()) } - /// Tell the VHOST driver to start performing data transfer. - pub fn start(&self) -> Result<()> { + fn start(&self) -> Result<()> { self.set_running(true) } - /// Tell the VHOST driver to stop performing data transfer. - pub fn stop(&self) -> Result<()> { + fn stop(&self) -> Result<()> { self.set_running(false) } - - fn set_running(&self, running: bool) -> Result<()> { - let on: ::std::os::raw::c_int = if running { 1 } else { 0 }; - let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) }; - ioctl_result(ret, ()) - } } impl VhostKernBackend for Vsock { diff --git a/src/vsock.rs b/src/vsock.rs index 4fb75f5..1e1b0b9 100644 --- a/src/vsock.rs +++ b/src/vsock.rs @@ -20,11 +20,11 @@ pub trait VhostVsock: VhostBackend { /// /// # Arguments /// * `cid` - CID to assign to the guest - fn set_guest_cid(&mut self, cid: u64) -> Result<()>; + fn set_guest_cid(&self, cid: u64) -> Result<()>; /// Tell the VHOST driver to start performing data transfer. - fn start(&mut self) -> Result<()>; + fn start(&self) -> Result<()>; /// Tell the VHOST driver to stop performing data transfer. - fn stop(&mut self) -> Result<()>; + fn stop(&self) -> Result<()>; } -- cgit v1.2.3 From fa089424f13db12ff85fbe5c1e8650d010da7f8f Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sat, 20 Feb 2021 01:42:17 +0800 Subject: vhost_user: use sock_ctrl_msg from vmm-sys-util Now the sock_ctrl_msg from the vmm-sys-util crate is ready for use, so replace the locally copied version. Signed-off-by: Liu Jiang --- src/vhost_user/connection.rs | 5 +- src/vhost_user/mod.rs | 2 - src/vhost_user/sock_ctrl_msg.rs | 499 ---------------------------------------- 3 files changed, 3 insertions(+), 503 deletions(-) delete mode 100644 src/vhost_user/sock_ctrl_msg.rs diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index deafdeb..5aa580b 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,15 +5,16 @@ #![allow(dead_code)] -use libc::{c_void, iovec}; use std::io::ErrorKind; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::{mem, slice}; +use libc::{c_void, iovec}; +use vmm_sys_util::sock_ctrl_msg::ScmSocket; + use super::message::*; -use super::sock_ctrl_msg::ScmSocket; use super::{Error, Result}; /// Unix domain socket listener for accepting incoming connections. diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 4259d0f..148a00e 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -46,8 +46,6 @@ mod slave_fs_cache; #[cfg(feature = "vhost-user-slave")] pub use self::slave_fs_cache::SlaveFsCacheReq; -pub mod sock_ctrl_msg; - /// Errors for vhost-user operations #[derive(Debug)] pub enum Error { diff --git a/src/vhost_user/sock_ctrl_msg.rs b/src/vhost_user/sock_ctrl_msg.rs deleted file mode 100644 index db3ec2e..0000000 --- a/src/vhost_user/sock_ctrl_msg.rs +++ /dev/null @@ -1,499 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -//! Used to send and receive messages with file descriptors on sockets that accept control messages -//! (e.g. Unix domain sockets). - -// TODO: move this file into the vmm-sys-util crate - -use std::fs::File; -use std::mem::size_of; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use std::os::unix::net::{UnixDatagram, UnixStream}; -use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; - -use libc::{ - c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET, -}; -use vmm_sys_util::errno::{Error, Result}; - -// Each of the following macros performs the same function as their C counterparts. They are each -// macros because they are used to size statically allocated arrays. - -macro_rules! CMSG_ALIGN { - ($len:expr) => { - (($len) + size_of::() - 1) & !(size_of::() - 1) - }; -} - -macro_rules! CMSG_SPACE { - ($len:expr) => { - size_of::() + CMSG_ALIGN!($len) - }; -} - -#[cfg(not(target_env = "musl"))] -macro_rules! CMSG_LEN { - ($len:expr) => { - size_of::() + ($len) - }; -} - -#[cfg(target_env = "musl")] -macro_rules! CMSG_LEN { - ($len:expr) => {{ - let sz = size_of::() + ($len); - assert!(sz <= (std::u32::MAX as usize)); - sz as u32 - }}; -} - -#[cfg(not(target_env = "musl"))] -fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { - msghdr { - msg_name: null_mut(), - msg_namelen: 0, - msg_iov: iovecs.as_mut_ptr(), - msg_iovlen: iovecs.len(), - msg_control: null_mut(), - msg_controllen: 0, - msg_flags: 0, - } -} - -#[cfg(target_env = "musl")] -fn new_msghdr(iovecs: &mut [iovec]) -> msghdr { - assert!(iovecs.len() <= (std::i32::MAX as usize)); - let mut msg: msghdr = unsafe { std::mem::zeroed() }; - msg.msg_name = null_mut(); - msg.msg_iov = iovecs.as_mut_ptr(); - msg.msg_iovlen = iovecs.len() as i32; - msg.msg_control = null_mut(); - msg -} - -#[cfg(not(target_env = "musl"))] -fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { - msg.msg_controllen = cmsg_capacity; -} - -#[cfg(target_env = "musl")] -fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) { - assert!(cmsg_capacity <= (std::u32::MAX as usize)); - msg.msg_controllen = cmsg_capacity as u32; -} - -// This function (macro in the C version) is not used in any compile time constant slots, so is just -// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this -// module supports. -#[allow(non_snake_case)] -#[inline(always)] -fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd { - // Essentially returns a pointer to just past the header. - cmsg_buffer.wrapping_offset(1) as *mut RawFd -} - -// This function is like CMSG_NEXT, but safer because it reads only from references, although it -// does some pointer arithmetic on cmsg_ptr. -#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))] -fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr { - let next_cmsg = - (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len as usize)) as *mut cmsghdr; - if next_cmsg - .wrapping_offset(1) - .wrapping_sub(msghdr.msg_control as usize) as usize - > msghdr.msg_controllen as usize - { - null_mut() - } else { - next_cmsg - } -} - -const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::() * 32); - -enum CmsgBuffer { - Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]), - Heap(Box<[cmsghdr]>), -} - -impl CmsgBuffer { - fn with_capacity(capacity: usize) -> CmsgBuffer { - let cap_in_cmsghdr_units = - (capacity.checked_add(size_of::()).unwrap() - 1) / size_of::(); - if capacity <= CMSG_BUFFER_INLINE_CAPACITY { - CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]) - } else { - CmsgBuffer::Heap( - vec![ - cmsghdr { - cmsg_len: 0, - cmsg_level: 0, - cmsg_type: 0, - #[cfg(target_env = "musl")] - __pad1: 0, - }; - cap_in_cmsghdr_units - ] - .into_boxed_slice(), - ) - } - } - - fn as_mut_ptr(&mut self) -> *mut cmsghdr { - match self { - CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr, - CmsgBuffer::Heap(a) => a.as_mut_ptr(), - } - } -} - -fn raw_sendmsg(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result { - let cmsg_capacity = CMSG_SPACE!(size_of::() * out_fds.len()); - let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - - let mut iovecs = Vec::with_capacity(out_data.len()); - for data in out_data { - iovecs.push(iovec { - iov_base: data.as_ptr() as *mut c_void, - iov_len: data.size(), - }); - } - - let mut msg = new_msghdr(&mut iovecs); - - if !out_fds.is_empty() { - let cmsg = cmsghdr { - cmsg_len: CMSG_LEN!(size_of::() * out_fds.len()), - cmsg_level: SOL_SOCKET, - cmsg_type: SCM_RIGHTS, - #[cfg(target_env = "musl")] - __pad1: 0, - }; - unsafe { - // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. - write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); - // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() - // file descriptors. - copy_nonoverlapping( - out_fds.as_ptr(), - CMSG_DATA(cmsg_buffer.as_mut_ptr()), - out_fds.len(), - ); - } - - msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; - set_msg_controllen(&mut msg, cmsg_capacity); - } - - // Safe because the msghdr was properly constructed from valid (or null) pointers of the - // indicated length and we check the return value. - let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; - - if write_count == -1 { - Err(Error::last()) - } else { - Ok(write_count as usize) - } -} - -fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { - let cmsg_capacity = CMSG_SPACE!(size_of::() * in_fds.len()); - let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - let mut msg = new_msghdr(iovecs); - - if !in_fds.is_empty() { - msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; - set_msg_controllen(&mut msg, cmsg_capacity); - } - - // Safe because the msghdr was properly constructed from valid (or null) pointers of the - // indicated length and we check the return value. - let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) }; - - if total_read == -1 { - return Err(Error::last()); - } - - // When the connection is closed recvmsg() doesn't give an explicit error - if total_read == 0 && (msg.msg_controllen as usize) < size_of::() { - return Err(Error::new(libc::ECONNRESET)); - } - - let mut cmsg_ptr = msg.msg_control as *mut cmsghdr; - let mut in_fds_count = 0; - while !cmsg_ptr.is_null() { - // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that - // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. - let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; - - if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { - let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) as usize / size_of::(); - unsafe { - copy_nonoverlapping( - CMSG_DATA(cmsg_ptr), - in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), - fd_count, - ); - } - in_fds_count += fd_count; - } - - cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); - } - - Ok((total_read as usize, in_fds_count)) -} - -/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and -/// `recvmsg`. -pub trait ScmSocket { - /// Gets the file descriptor of this socket. - fn socket_fd(&self) -> RawFd; - - /// Sends the given data and file descriptor over the socket. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `buf` - A buffer of data to send on the `socket`. - /// * `fd` - A file descriptors to be sent. - fn send_with_fd(&self, buf: D, fd: RawFd) -> Result { - self.send_with_fds(&[buf], &[fd]) - } - - /// Sends the given data and file descriptors over the socket. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `bufs` - A list of data buffer to send on the `socket`. - /// * `fds` - A list of file descriptors to be sent. - fn send_with_fds(&self, bufs: &[D], fds: &[RawFd]) -> Result { - raw_sendmsg(self.socket_fd(), bufs, fds) - } - - /// Receives data and potentially a file descriptor from the socket. - /// - /// On success, returns the number of bytes and an optional file descriptor. - /// - /// # Arguments - /// - /// * `buf` - A buffer to receive data from the socket. - fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option)> { - let mut fd = [0]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - - let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?; - let file = if fd_count == 0 { - None - } else { - // Safe because the first fd from recv_with_fds is owned by us and valid because this - // branch was taken. - Some(unsafe { File::from_raw_fd(fd[0]) }) - }; - Ok((read_count, file)) - } - - /// Receives data and file descriptors from the socket. - /// - /// On success, returns the number of bytes and file descriptors received as a tuple - /// `(bytes count, files count)`. - /// - /// # Arguments - /// - /// * `iovecs` - A list of iovec to receive data from the socket. - /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the - /// number of valid file descriptors is indicated by the second element of the - /// returned tuple. The caller owns these file descriptors, but they will not be - /// closed on drop like a `File`-like type would be. It is recommended that each valid - /// file descriptor gets wrapped in a drop type that closes it after this returns. - fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> { - raw_recvmsg(self.socket_fd(), iovecs, fds) - } -} - -impl ScmSocket for UnixDatagram { - fn socket_fd(&self) -> RawFd { - self.as_raw_fd() - } -} - -impl ScmSocket for UnixStream { - fn socket_fd(&self) -> RawFd { - self.as_raw_fd() - } -} - -/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for -/// the lifetime of this object. -/// -/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size -/// being accurate. -pub unsafe trait IntoIovec { - /// Gets the base pointer of this `iovec`. - fn as_ptr(&self) -> *const c_void; - - /// Gets the size in bytes of this `iovec`. - fn size(&self) -> usize; -} - -// Safe because this slice can not have another mutable reference and it's pointer and size are -// guaranteed to be valid. -unsafe impl<'a> IntoIovec for &'a [u8] { - // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480 - #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))] - fn as_ptr(&self) -> *const c_void { - self.as_ref().as_ptr() as *const c_void - } - - fn size(&self) -> usize { - self.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use std::io::Write; - use std::mem::size_of; - use std::os::raw::c_long; - use std::os::unix::net::UnixDatagram; - use std::slice::from_raw_parts; - - use libc::cmsghdr; - - use vmm_sys_util::eventfd::EventFd; - - #[test] - fn buffer_len() { - assert_eq!(CMSG_SPACE!(0 * size_of::()), size_of::()); - assert_eq!( - CMSG_SPACE!(1 * size_of::()), - size_of::() + size_of::() - ); - if size_of::() == 4 { - assert_eq!( - CMSG_SPACE!(2 * size_of::()), - size_of::() + size_of::() - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::()), - size_of::() + size_of::() * 2 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::()), - size_of::() + size_of::() * 2 - ); - } else if size_of::() == 8 { - assert_eq!( - CMSG_SPACE!(2 * size_of::()), - size_of::() + size_of::() * 2 - ); - assert_eq!( - CMSG_SPACE!(3 * size_of::()), - size_of::() + size_of::() * 3 - ); - assert_eq!( - CMSG_SPACE!(4 * size_of::()), - size_of::() + size_of::() * 4 - ); - } - } - - #[test] - fn send_recv_no_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let write_count = s1 - .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[]) - .expect("failed to send data"); - - assert_eq!(write_count, 6); - - let mut buf = [0u8; 6]; - let mut files = [0; 1]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - let (read_count, file_count) = s2 - .recv_with_fds(&mut iovecs[..], &mut files) - .expect("failed to recv data"); - - assert_eq!(read_count, 6); - assert_eq!(file_count, 0); - assert_eq!(buf, [1, 1, 2, 21, 34, 55]); - } - - #[test] - fn send_recv_only_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let evt = EventFd::new(0).expect("failed to create eventfd"); - let write_count = s1 - .send_with_fd([].as_ref(), evt.as_raw_fd()) - .expect("failed to send fd"); - - assert_eq!(write_count, 0); - - let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); - - let mut file = file_opt.unwrap(); - - assert_eq!(read_count, 0); - assert!(file.as_raw_fd() >= 0); - assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); - assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); - assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); - - file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) - .expect("failed to write to sent fd"); - - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); - } - - #[test] - fn send_recv_with_fd() { - let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - - let evt = EventFd::new(0).expect("failed to create eventfd"); - let write_count = s1 - .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()]) - .expect("failed to send fd"); - - assert_eq!(write_count, 1); - - let mut files = [0; 2]; - let mut buf = [0u8]; - let mut iovecs = [iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }]; - let (read_count, file_count) = s2 - .recv_with_fds(&mut iovecs[..], &mut files) - .expect("failed to recv fd"); - - assert_eq!(read_count, 1); - assert_eq!(buf[0], 237); - assert_eq!(file_count, 1); - assert!(files[0] >= 0); - assert_ne!(files[0], s1.as_raw_fd()); - assert_ne!(files[0], s2.as_raw_fd()); - assert_ne!(files[0], evt.as_raw_fd()); - - let mut file = unsafe { File::from_raw_fd(files[0]) }; - - file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) - .expect("failed to write to sent fd"); - - assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); - } -} -- cgit v1.2.3 From 3b497e1fbea8740b4c502684b6f219104ac78d7f Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 21:35:37 +0800 Subject: vhost-user: refine the MasterReqHandler trait Refine the MasterReqHandler trait with: 1) better documentation, 2) abstracting as MasterReqHandler and MasterReqHandlerMut, 3) honoring the negotiation state of VHOST_USER_PROTOCOL_F_REPLY_ACK, 4) enhancing set_failed() to clear the error state, 5) validating field `size` in the received message header, 6) reading struct by std::ptr::read_unaligned instead of directly access the underlying buffer, Signed-off-by: Liu Jiang --- src/vhost_user/master_req_handler.rs | 183 +++++++++++++++++++++++++++-------- 1 file changed, 141 insertions(+), 42 deletions(-) diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index aadfeee..02c2bb7 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -1,8 +1,6 @@ -// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Traits and Structs to handle vhost-user requests from the slave to the master. - use libc; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; @@ -13,83 +11,189 @@ use super::connection::Endpoint; use super::message::*; use super::{Error, HandlerResult, Result}; -/// Trait to handle vhost-user requests from the slave to the master. +/// Define services provided by masters for the slave communication channel. +/// +/// The vhost-user specification defines a slave communication channel, by which slaves could +/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided +/// by masters, and it's used both on the master side and slave side. +/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy +/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. +/// - on the master side, the [MasterReqHandler] will forward service requests to a handler +/// implementing [VhostUserMasterReqHandler]. +/// +/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html pub trait VhostUserMasterReqHandler { + /// Handle device configuration change notifications. + fn handle_config_change(&self) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} - /// Handle device configuration change notifications from the slave. +/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub trait VhostUserMasterReqHandlerMut { + /// Handle device configuration change notifications. fn handle_config_change(&mut self) -> HandlerResult { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs map file requests from the slave. + /// Handle virtio-fs map file requests. fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { // Safe because we have just received the rawfd from kernel. unsafe { libc::close(fd) }; Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs unmap file requests from the slave. + /// Handle virtio-fs unmap file requests. fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs sync file requests from the slave. + /// Handle virtio-fs sync file requests. fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } - /// Handle virtio-fs file IO requests from the slave. + /// Handle virtio-fs file IO requests. fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { // Safe because we have just received the rawfd from kernel. unsafe { libc::close(fd) }; Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +impl VhostUserMasterReqHandler for Mutex { + fn handle_config_change(&self) -> HandlerResult { + self.lock().unwrap().handle_config_change() + } + + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + self.lock().unwrap().fs_slave_map(fs, fd) + } + + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult { + self.lock().unwrap().fs_slave_unmap(fs) + } + + fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult { + self.lock().unwrap().fs_slave_sync(fs) + } + + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + self.lock().unwrap().fs_slave_io(fs, fd) + } } -/// A vhost-user master request endpoint which relays all received requests from the slave to the -/// provided request handler. +/// Server to handle service requests from slaves from the slave communication channel. +/// +/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from +/// slaves on the slave communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserMasterReqHandler] to do the real work. +/// +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html pub struct MasterReqHandler { // underlying Unix domain socket for communication sub_sock: Endpoint, tx_sock: UnixStream, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, // the VirtIO backend device object - backend: Arc>, + backend: Arc, // whether the endpoint has encountered any failure error: Option, } impl MasterReqHandler { - /// Create a vhost-user slave request handler. - /// This opens a pair of connected anonymous sockets. - /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD. - pub fn new(backend: Arc>) -> Result { + /// Create a server to handle service requests from slaves on the slave communication channel. + /// + /// This opens a pair of connected anonymous sockets to form the slave communication channel. + /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. + /// + /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn new(backend: Arc) -> Result { let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; Ok(MasterReqHandler { sub_sock: Endpoint::::from_stream(rx), tx_sock: tx, + reply_ack_negotiated: false, backend, error: None, }) } - /// Get the raw fd to send to the slave as slave communication channel. + /// Get the socket fd for the slave to communication with the master. + /// + /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd pub fn get_tx_raw_fd(&self) -> RawFd { self.tx_sock.as_raw_fd() } - /// Mark endpoint as failed or normal state. + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&mut self, enable: bool) { + self.reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed or in normal state. pub fn set_failed(&mut self, error: i32) { - self.error = Some(error); + if error == 0 { + self.error = None; + } else { + self.error = Some(error); + } } - /// Receive and handle one incoming request message from the slave. + /// Main entrance to server slave request from the slave communication channel. + /// /// The caller needs to: - /// . serialize calls to this function - /// . decide what to do when errer happens - /// . optional recover from failure + /// - serialize calls to this function + /// - decide what to do when errer happens + /// - optional recover from failure pub fn handle_request(&mut self) -> Result { // Return error if the endpoint is already in failed state. self.check_state()?; @@ -108,6 +212,9 @@ impl MasterReqHandler { let (size, buf) = match hdr.get_size() { 0 => (0, vec![0u8; 0]), len => { + if len as usize > MAX_MSG_SIZE { + return Err(Error::InvalidMessage); + } let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; if size2 != len as usize { return Err(Error::InvalidMessage); @@ -120,41 +227,33 @@ impl MasterReqHandler { SlaveReq::CONFIG_CHANGE_MSG => { self.check_msg_size(&hdr, size, 0)?; self.backend - .lock() - .unwrap() .handle_config_change() .map_err(Error::ReqHandlerError) } SlaveReq::FS_MAP => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds self.backend - .lock() - .unwrap() - .fs_slave_map(msg, rfds.unwrap()[0]) + .fs_slave_map(&msg, rfds.unwrap()[0]) .map_err(Error::ReqHandlerError) } SlaveReq::FS_UNMAP => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend - .lock() - .unwrap() - .fs_slave_unmap(msg) + .fs_slave_unmap(&msg) .map_err(Error::ReqHandlerError) } SlaveReq::FS_SYNC => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend - .lock() - .unwrap() - .fs_slave_sync(msg) + .fs_slave_sync(&msg) .map_err(Error::ReqHandlerError) } SlaveReq::FS_IO => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds self.backend - .lock() - .unwrap() - .fs_slave_io(msg, rfds.unwrap()[0]) + .fs_slave_io(&msg, rfds.unwrap()[0]) .map_err(Error::ReqHandlerError) } _ => Err(Error::InvalidMessage), @@ -219,14 +318,14 @@ impl MasterReqHandler { } } - fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>( + fn extract_msg_body( &self, hdr: &VhostUserMsgHeader, size: usize, - buf: &'a [u8], - ) -> Result<&'a T> { + buf: &[u8], + ) -> Result { self.check_msg_size(hdr, size, mem::size_of::())?; - let msg = unsafe { &*(buf.as_ptr() as *const T) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -253,7 +352,7 @@ impl MasterReqHandler { req: &VhostUserMsgHeader, res: &Result, ) -> Result<()> { - if req.is_need_reply() { + if self.reply_ack_negotiated && req.is_need_reply() { let hdr = self.new_reply_header::(req)?; let def_err = libc::EINVAL; let val = match res { -- cgit v1.2.3 From dcb79ab156dbe8e547cb4c61a873a9c4319fc028 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 21:45:10 +0800 Subject: vhost_user: refine the SlaveFsCacheReq struct Refine the SlaveFsCacheReq struct by: 1) honoring the negotiation result of VHOST_USER_PROTOCOL_F_REPLY_ACK, 2) better documentation, 3) adding unit test cases. Signed-off-by: Liu Jiang --- src/vhost_user/slave_fs_cache.rs | 161 ++++++++++++++++++++++++++++++--------- 1 file changed, 123 insertions(+), 38 deletions(-) diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index 1804c7a..32b2b8e 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -1,61 +1,59 @@ -// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved. +// Copyright (C) 2020 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use super::connection::Endpoint; -use super::message::*; -use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; use std::io; use std::mem; use std::os::unix::io::RawFd; use std::os::unix::net::UnixStream; use std::sync::{Arc, Mutex}; +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; + struct SlaveFsCacheReqInternal { sock: Endpoint, -} -/// A vhost-user slave endpoint which sends fs cache requests to the master -#[derive(Clone)] -pub struct SlaveFsCacheReq { - // underlying Unix domain socket for communication - node: Arc>, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, // whether the endpoint has encountered any failure error: Option, } -impl SlaveFsCacheReq { - fn new(ep: Endpoint) -> Self { - SlaveFsCacheReq { - node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })), - error: None, +impl SlaveFsCacheReqInternal { + fn check_state(&self) -> Result { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(0), } } - /// Create a new instance. - pub fn from_stream(sock: UnixStream) -> Self { - Self::new(Endpoint::::from_stream(sock)) - } - fn send_message( &mut self, - flags: SlaveReq, + request: SlaveReq, fs: &VhostUserFSSlaveMsg, fds: Option<&[RawFd]>, ) -> Result { self.check_state()?; let len = mem::size_of::(); - let mut hdr = VhostUserMsgHeader::new(flags, 0, len as u32); - hdr.set_need_reply(true); - self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?; + let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32); + if self.reply_ack_negotiated { + hdr.set_need_reply(true); + } + self.sock.send_message(&hdr, fs, fds)?; self.wait_for_ack(&hdr) } fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader) -> Result { self.check_state()?; - let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::()?; + if !self.reply_ack_negotiated { + return Ok(0); + } + + let (reply, body, rfds) = self.sock.recv_body::()?; if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { Endpoint::::close_rfds(rfds); return Err(Error::InvalidMessage); @@ -63,32 +61,119 @@ impl SlaveFsCacheReq { if body.value != 0 { return Err(Error::MasterInternalError); } - Ok(0) + + Ok(body.value) } +} - fn check_state(&self) -> Result { - match self.error { - Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), - None => Ok(0), +/// Request proxy to send vhost-user-fs slave requests to the master through the slave +/// communication channel. +/// +/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the +/// master through the vhost-user slave communication channel. The forwarded messages will be +/// handled by the [MasterReqHandler] server. +/// +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +#[derive(Clone)] +pub struct SlaveFsCacheReq { + // underlying Unix domain socket for communication + node: Arc>, +} + +impl SlaveFsCacheReq { + fn new(ep: Endpoint) -> Self { + SlaveFsCacheReq { + node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { + sock: ep, + reply_ack_negotiated: false, + error: None, + })), } } + fn send_message( + &self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> io::Result { + self.node + .lock() + .unwrap() + .send_message(request, fs, fds) + .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::::from_stream(sock)) + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&self, enable: bool) { + self.node.lock().unwrap().reply_ack_negotiated = enable; + } + /// Mark endpoint as failed with specified error code. - pub fn set_failed(&mut self, error: i32) { - self.error = Some(error); + pub fn set_failed(&self, error: i32) { + self.node.lock().unwrap().error = Some(error); } } impl VhostUserMasterReqHandler for SlaveFsCacheReq { - /// Handle virtio-fs map file requests from the slave. - fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + /// Forward vhost-user-fs map file requests to the slave. + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) - .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) } - /// Handle virtio-fs unmap file requests from the slave. - fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult { + /// Forward vhost-user-fs unmap file requests to the master. + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult { self.send_message(SlaveReq::FS_UNMAP, fs, None) - .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + + #[test] + fn test_slave_fs_cache_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + assert!(fs_cache.node.lock().unwrap().error.is_none()); + fs_cache.set_failed(libc::EAGAIN); + assert_eq!(fs_cache.node.lock().unwrap().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_slave_fs_cache_send_failure() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + fs_cache.set_failed(libc::ECONNRESET); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + fs_cache.node.lock().unwrap().error = None; + + drop(p2); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); } } -- cgit v1.2.3 From dc452e5aeb5d106c1f76210f8c6b788bc41cc571 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 21:57:31 +0800 Subject: vhost_user: fix bugs and refine message defintion Fix two bugs: 1) in VhostUserConfigFlags definition 2) in validating VhostUserConfig Refine vhost-user message definition by: 1) validate offset of VhostUserConfig 2) better documentation 3) more unit tests Signed-off-by: Liu Jiang --- src/vhost_user/message.rs | 162 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 155 insertions(+), 7 deletions(-) diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs index 4109b61..0f6e1a2 100644 --- a/src/vhost_user/message.rs +++ b/src/vhost_user/message.rs @@ -562,9 +562,9 @@ bitflags! { /// Flags for the device configuration message. pub struct VhostUserConfigFlags: u32 { /// Vhost master messages used for writeable fields. - const WRITABLE = 0x0; + const WRITABLE = 0x1; /// Vhost master messages used for live migration. - const LIVE_MIGRATION = 0x1; + const LIVE_MIGRATION = 0x2; } } @@ -596,9 +596,11 @@ impl VhostUserMsgValidator for VhostUserConfig { fn is_valid(&self) -> bool { if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { return false; + } else if self.offset < 0x100 { + return false; } else if self.size == 0 || self.size > VHOST_USER_CONFIG_SIZE - || self.size + self.offset >= VHOST_USER_CONFIG_SIZE + || self.size + self.offset > VHOST_USER_CONFIG_SIZE { return false; } @@ -656,9 +658,9 @@ pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; #[repr(packed)] #[derive(Default)] pub struct VhostUserFSSlaveMsg { - /// TODO: + /// File offset. pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], - /// TODO: + /// Offset into the DAX window. pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], /// Size of region to map. pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], @@ -686,13 +688,31 @@ mod tests { use std::mem; #[test] - fn check_request_code() { + fn check_master_request_code() { let code = MasterReq::NOOP; assert!(!code.is_valid()); let code = MasterReq::MAX_CMD; assert!(!code.is_valid()); + assert!(code > MasterReq::NOOP); let code = MasterReq::GET_FEATURES; assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: MasterReq = unsafe { std::mem::transmute::(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn check_slave_request_code() { + let code = SlaveReq::NOOP; + assert!(!code.is_valid()); + let code = SlaveReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > SlaveReq::NOOP); + let code = SlaveReq::CONFIG_CHANGE_MSG; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: SlaveReq = unsafe { std::mem::transmute::(10000u32) }; + assert!(!code.is_valid()); } #[test] @@ -741,6 +761,20 @@ mod tests { assert!(!hdr.is_valid()); hdr.set_version(0x1); assert!(hdr.is_valid()); + + assert_eq!(hdr, hdr.clone()); + } + + #[test] + fn test_vhost_user_message_u64() { + let val = VhostUserU64::default(); + let val1 = VhostUserU64::new(0); + + let a = val.value; + let b = val1.value; + assert_eq!(a, b); + let a = VhostUserU64::new(1).value; + assert_eq!(a, 1); } #[test] @@ -775,6 +809,104 @@ mod tests { msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; msg.memory_size = 0; assert!(!msg.is_valid()); + let a = msg.clone().guest_phys_addr; + let b = msg.guest_phys_addr; + assert_eq!(a, b); + + let msg = VhostUserMemoryRegion::default(); + let a = msg.guest_phys_addr; + assert_eq!(a, 0); + let a = msg.memory_size; + assert_eq!(a, 0); + let a = msg.user_addr; + assert_eq!(a, 0); + let a = msg.mmap_offset; + assert_eq!(a, 0); + } + + #[test] + fn test_vhost_user_state() { + let state = VhostUserVringState::new(5, 8); + + let a = state.index; + assert_eq!(a, 5); + let a = state.num; + assert_eq!(a, 8); + assert_eq!(state.is_valid(), true); + + let state = VhostUserVringState::default(); + let a = state.index; + assert_eq!(a, 0); + let a = state.num; + assert_eq!(a, 0); + assert_eq!(state.is_valid(), true); + } + + #[test] + fn test_vhost_user_addr() { + let mut addr = VhostUserVringAddr::new( + 2, + VhostUserVringAddrFlags::VHOST_VRING_F_LOG, + 0x1000, + 0x2000, + 0x3000, + 0x4000, + ); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + + addr.descriptor = 0x1001; + assert_eq!(addr.is_valid(), false); + addr.descriptor = 0x1000; + + addr.available = 0x3001; + assert_eq!(addr.is_valid(), false); + addr.available = 0x3000; + + addr.used = 0x2001; + assert_eq!(addr.is_valid(), false); + addr.used = 0x2000; + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn test_vhost_user_state_from_config() { + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + let addr = VhostUserVringAddr::from_config_data(2, &config); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); } #[test] @@ -801,7 +933,6 @@ mod tests { } #[test] - #[ignore] fn check_user_config_msg() { let mut msg = VhostUserConfig::new( VHOST_USER_CONFIG_OFFSET, @@ -828,4 +959,21 @@ mod tests { msg.flags |= 0x4; assert!(!msg.is_valid()); } + + #[test] + fn test_vhost_user_fs_slave() { + let mut fs_slave = VhostUserFSSlaveMsg::default(); + + assert_eq!(fs_slave.is_valid(), true); + + fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff; + fs_slave.len[0] = 0x1; + assert_eq!(fs_slave.is_valid(), false); + + assert_ne!( + VhostUserFSSlaveMsgFlags::MAP_R, + VhostUserFSSlaveMsgFlags::MAP_W + ); + assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0); + } } -- cgit v1.2.3 From 9e22e2fe2f0f22161eb292a538e5b712d1f7d9be Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 22:36:10 +0800 Subject: vhost_user: refine connection implementation Refine connection implementation by: 1) using "payload: &[u8]" for send_message_with_payload() 2) enabling all unit test cases Signed-off-by: Liu Jiang --- src/vhost_user/connection.rs | 19 ++++++++++++------- src/vhost_user/master.rs | 9 +++++---- src/vhost_user/slave_req_handler.rs | 10 +++------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index 5aa580b..d89f9c7 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -216,6 +216,9 @@ impl Endpoint { body: &T, fds: Option<&[RawFd]>, ) -> Result<()> { + if mem::size_of::() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } // Safe because there can't be other mutable referance to hdr and body. let iovs = unsafe { [ @@ -244,14 +247,17 @@ impl Endpoint { /// * - OversizedMsg: message size is too big. /// * - PartialMessage: received a partial message. /// * - IncorrectFds: wrong number of attached fds. - pub fn send_message_with_payload( + pub fn send_message_with_payload( &mut self, hdr: &VhostUserMsgHeader, body: &T, - payload: &[P], + payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()> { - let len = payload.len() * mem::size_of::

(); + let len = payload.len(); + if mem::size_of::() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } if len > MAX_MSG_SIZE - mem::size_of::() { return Err(Error::OversizedMsg); } @@ -615,7 +621,9 @@ mod tests { #[test] fn create_listener() { - let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + let listener = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); } #[test] @@ -629,7 +637,6 @@ mod tests { } #[test] - #[ignore] fn send_data() { let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap(); listener.set_nonblocking(true).unwrap(); @@ -655,7 +662,6 @@ mod tests { } #[test] - #[ignore] fn send_fd() { let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap(); listener.set_nonblocking(true).unwrap(); @@ -809,7 +815,6 @@ mod tests { } #[test] - #[ignore] fn send_recv() { let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap(); listener.set_nonblocking(true).unwrap(); diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 2651b84..906932c 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -176,10 +176,11 @@ impl VhostBackend for Master { let mut node = self.node.lock().unwrap(); let body = VhostUserMemory::new(ctx.regions.len() as u32); + let (_, payload, _) = unsafe { ctx.regions.align_to::() }; let hdr = node.send_request_with_payload( MasterReq::SET_MEM_TABLE, &body, - ctx.regions.as_slice(), + payload, Some(ctx.fds.as_slice()), )?; node.wait_for_ack(&hdr).map_err(|e| e.into()) @@ -503,14 +504,14 @@ impl MasterInternal { Ok(hdr) } - fn send_request_with_payload( + fn send_request_with_payload( &mut self, code: MasterReq, msg: &T, - payload: &[P], + payload: &[u8], fds: Option<&[RawFd]>, ) -> VhostUserResult> { - let len = mem::size_of::() + payload.len() * mem::size_of::

(); + let len = mem::size_of::() + payload.len(); if len > MAX_MSG_SIZE { return Err(VhostUserError::InvalidParam); } diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index f3b0770..95e23d1 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -590,16 +590,12 @@ impl SlaveReqHandler { Ok(()) } - fn send_reply_with_payload( + fn send_reply_with_payload( &mut self, req: &VhostUserMsgHeader, msg: &T, - payload: &[P], - ) -> Result<()> - where - T: Sized, - P: Sized, - { + payload: &[u8], + ) -> Result<()> { let hdr = self.new_reply_header::(req, payload.len())?; self.main_sock .send_message_with_payload(&hdr, msg, payload, None)?; -- cgit v1.2.3 From 21b89b2ff5c418144760d08c2000776d0f0792f0 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 23:10:50 +0800 Subject: vhost_user: fix a bug in SlaveReqHandler An acknowlege reply message should be sent iif: 1) the VHOST_USER_PROTOCOL_F_REPLY_ACK feature is nogotiated, 2) the NEED_REPLY in header.flags is set. Also enforce stricter validation for message size. Signed-off-by: Liu Jiang --- src/vhost_user/master.rs | 5 ++++- src/vhost_user/slave_req_handler.rs | 7 +++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 906932c..65c7960 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -569,7 +569,10 @@ impl MasterInternal { &mut self, hdr: &VhostUserMsgHeader, ) -> VhostUserResult<(T, Vec, Option>)> { - if mem::size_of::() > MAX_MSG_SIZE || hdr.is_reply() { + if mem::size_of::() > MAX_MSG_SIZE + || hdr.get_size() as usize <= mem::size_of::() + || hdr.is_reply() + { return Err(VhostUserError::InvalidParam); } self.check_state()?; diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 95e23d1..985693f 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -552,7 +552,10 @@ impl SlaveReqHandler { req: &VhostUserMsgHeader, payload_size: usize, ) -> Result> { - if mem::size_of::() > MAX_MSG_SIZE { + if mem::size_of::() > MAX_MSG_SIZE + || payload_size > MAX_MSG_SIZE + || mem::size_of::() + payload_size > MAX_MSG_SIZE + { return Err(Error::InvalidParam); } self.check_state()?; @@ -568,7 +571,7 @@ impl SlaveReqHandler { req: &VhostUserMsgHeader, res: Result<()>, ) -> Result<()> { - if self.reply_ack_enabled { + if self.reply_ack_enabled && req.is_need_reply() { let hdr = self.new_reply_header::(req, 0)?; let val = match res { Ok(_) => 0, -- cgit v1.2.3 From 74c353bb5fc003d95f8099a84bfdb8aa1895e2b4 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Fri, 19 Feb 2021 22:51:08 +0800 Subject: vhost_user: add VhostUserSlaveReqHandlerMut trait Rename the original VhostUserSlaveReqHandler trait as VhostUserSlaveReqHandlerMut, and add another VhostUserSlaveReqHandler trait with interior mutability. This also help to simplify caller implementations. Signed-off-by: Liu Jiang --- src/vhost_user/mod.rs | 19 ++- src/vhost_user/slave.rs | 6 +- src/vhost_user/slave_req_handler.rs | 234 +++++++++++++++++++++++++++--------- 3 files changed, 195 insertions(+), 64 deletions(-) diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 148a00e..91e4203 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -18,20 +18,23 @@ //! Most messages that can be sent via the Unix domain socket implementing vhost-user have an //! equivalent ioctl to the kernel implementation. -use libc; use std::io::Error as IOError; -mod connection; pub mod message; + +mod connection; pub use self::connection::Listener; + #[cfg(feature = "vhost-user-master")] mod master; #[cfg(feature = "vhost-user-master")] pub use self::master::{Master, VhostUserMaster}; -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] +#[cfg(feature = "vhost-user")] mod master_req_handler; -#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))] -pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler}; +#[cfg(feature = "vhost-user")] +pub use self::master_req_handler::{ + MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut, +}; #[cfg(feature = "vhost-user-slave")] mod slave; @@ -40,7 +43,9 @@ pub use self::slave::SlaveListener; #[cfg(feature = "vhost-user-slave")] mod slave_req_handler; #[cfg(feature = "vhost-user-slave")] -pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler}; +pub use self::slave_req_handler::{ + SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut, +}; #[cfg(feature = "vhost-user-slave")] mod slave_fs_cache; #[cfg(feature = "vhost-user-slave")] @@ -100,6 +105,8 @@ impl std::fmt::Display for Error { } } +impl std::error::Error for Error {} + impl Error { /// Determine whether to rebuild the underline communication channel. pub fn should_reconnect(&self) -> bool { diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs index 5ac99af..c167dce 100644 --- a/src/vhost_user/slave.rs +++ b/src/vhost_user/slave.rs @@ -3,7 +3,7 @@ //! Traits and Structs for vhost-user slave. -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use super::connection::{Endpoint, Listener}; use super::message::*; @@ -12,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; /// Vhost-user slave side connection listener. pub struct SlaveListener { listener: Listener, - backend: Option>>, + backend: Option>, } /// Sets up a listener for incoming master connections, and handles construction /// of a Slave on success. impl SlaveListener { /// Create a unix domain socket for incoming master connections. - pub fn new(listener: Listener, backend: Arc>) -> Result { + pub fn new(listener: Listener, backend: Arc) -> Result { Ok(SlaveListener { listener, backend: Some(backend), diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 985693f..190501b 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -1,8 +1,6 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! Traits and Structs to handle vhost-user requests from the master to the slave. - use std::mem; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -14,9 +12,63 @@ use super::message::*; use super::slave_fs_cache::SlaveFsCacheReq; use super::{Error, Result}; -/// Trait to handle vhost-user requests from the master to the slave. +/// Services provided to the master by the slave with interior mutability. +/// +/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. +/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], +/// but without interior mutability. +/// The vhost-user specification defines a master communication channel, by which masters could +/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by +/// slaves, and it's used both on the master side and slave side. +/// +/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy +/// service requests to slaves. +/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler +/// implementing [VhostUserSlaveReqHandler]. +/// +/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html #[allow(missing_docs)] pub trait VhostUserSlaveReqHandler { + fn set_owner(&self) -> Result<()>; + fn reset_owner(&self) -> Result<()>; + fn get_features(&self) -> Result; + fn set_features(&self, features: u64) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&self, index: u32) -> Result; + fn set_vring_kick(&self, index: u8, fd: Option) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option) -> Result<()>; + + fn get_protocol_features(&self) -> Result; + fn set_protocol_features(&self, features: u64) -> Result<()>; + fn get_queue_num(&self) -> Result; + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result>; + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} +} + +/// Services provided to the master by the slave without interior mutability. +/// +/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandlerMut { fn set_owner(&mut self) -> Result<()>; fn reset_owner(&mut self) -> Result<()>; fn get_features(&mut self) -> Result; @@ -52,16 +104,110 @@ pub trait VhostUserSlaveReqHandler { fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} } -/// A vhost-user slave endpoint which relays all received requests from the -/// master to the virtio backend device object. +impl VhostUserSlaveReqHandler for Mutex { + fn set_owner(&self) -> Result<()> { + self.lock().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.lock().unwrap().reset_owner() + } + + fn get_features(&self) -> Result { + self.lock().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_features(features) + } + + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, fds) + } + + fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { + self.lock().unwrap().set_vring_num(index, num) + } + + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()> { + self.lock() + .unwrap() + .set_vring_addr(index, flags, descriptor, used, available, log) + } + + fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { + self.lock().unwrap().set_vring_base(index, base) + } + + fn get_vring_base(&self, index: u32) -> Result { + self.lock().unwrap().get_vring_base(index) + } + + fn set_vring_kick(&self, index: u8, fd: Option) -> Result<()> { + self.lock().unwrap().set_vring_kick(index, fd) + } + + fn set_vring_call(&self, index: u8, fd: Option) -> Result<()> { + self.lock().unwrap().set_vring_call(index, fd) + } + + fn set_vring_err(&self, index: u8, fd: Option) -> Result<()> { + self.lock().unwrap().set_vring_err(index, fd) + } + + fn get_protocol_features(&self) -> Result { + self.lock().unwrap().get_protocol_features() + } + + fn set_protocol_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_protocol_features(features) + } + + fn get_queue_num(&self) -> Result { + self.lock().unwrap().get_queue_num() + } + + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { + self.lock().unwrap().set_vring_enable(index, enable) + } + + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result> { + self.lock().unwrap().get_config(offset, size, flags) + } + + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { + self.lock().unwrap().set_config(offset, buf, flags) + } + + fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + self.lock().unwrap().set_slave_req_fd(vu_req) + } +} + +/// Server to handle service requests from masters from the master communication channel. +/// +/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from +/// masters on the master communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserSlaveReqHandler] to do the real work. /// /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain /// Socket, so it gets simpler to recover from disconnect. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html pub struct SlaveReqHandler { // underlying Unix domain socket for communication main_sock: Endpoint, // the vhost-user backend device object - backend: Arc>, + backend: Arc, virtio_features: u64, acked_virtio_features: u64, @@ -76,7 +222,7 @@ pub struct SlaveReqHandler { impl SlaveReqHandler { /// Create a vhost-user slave endpoint. - pub(super) fn new(main_sock: Endpoint, backend: Arc>) -> Self { + pub(super) fn new(main_sock: Endpoint, backend: Arc) -> Self { SlaveReqHandler { main_sock, backend, @@ -94,7 +240,7 @@ impl SlaveReqHandler { /// # Arguments /// * - `path` - path of Unix domain socket listener to connect to /// * - `backend` - handler for requests from the master to the slave - pub fn connect(path: &str, backend: Arc>) -> Result { + pub fn connect(path: &str, backend: Arc) -> Result { Ok(Self::new(Endpoint::::connect(path)?, backend)) } @@ -103,11 +249,12 @@ impl SlaveReqHandler { self.error = Some(error); } - /// Receive and handle one incoming request message from the master. - /// The caller needs to: - /// . serialize calls to this function - /// . decide what to do when error happens - /// . optional recover from failure + /// Main entrance to server slave request from the slave communication channel. + /// + /// Receive and handle one incoming request message from the master. The caller needs to: + /// - serialize calls to this function + /// - decide what to do when error happens + /// - optional recover from failure pub fn handle_request(&mut self) -> Result<()> { // Return error if the endpoint is already in failed state. self.check_state()?; @@ -137,15 +284,15 @@ impl SlaveReqHandler { match hdr.get_code() { MasterReq::SET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.lock().unwrap().set_owner()?; + self.backend.set_owner()?; } MasterReq::RESET_OWNER => { self.check_request_size(&hdr, size, 0)?; - self.backend.lock().unwrap().reset_owner()?; + self.backend.reset_owner()?; } MasterReq::GET_FEATURES => { self.check_request_size(&hdr, size, 0)?; - let features = self.backend.lock().unwrap().get_features()?; + let features = self.backend.get_features()?; let msg = VhostUserU64::new(features); self.send_reply_message(&hdr, &msg)?; self.virtio_features = features; @@ -153,7 +300,7 @@ impl SlaveReqHandler { } MasterReq::SET_FEATURES => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - self.backend.lock().unwrap().set_features(msg.value)?; + self.backend.set_features(msg.value)?; self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); } @@ -163,11 +310,7 @@ impl SlaveReqHandler { } MasterReq::SET_VRING_NUM => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - let res = self - .backend - .lock() - .unwrap() - .set_vring_num(msg.index, msg.num); + let res = self.backend.set_vring_num(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ADDR => { @@ -176,7 +319,7 @@ impl SlaveReqHandler { Some(val) => val, None => return Err(Error::InvalidMessage), }; - let res = self.backend.lock().unwrap().set_vring_addr( + let res = self.backend.set_vring_addr( msg.index, flags, msg.descriptor, @@ -188,39 +331,35 @@ impl SlaveReqHandler { } MasterReq::SET_VRING_BASE => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - let res = self - .backend - .lock() - .unwrap() - .set_vring_base(msg.index, msg.num); + let res = self.backend.set_vring_base(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } MasterReq::GET_VRING_BASE => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?; + let reply = self.backend.get_vring_base(msg.index)?; self.send_reply_message(&hdr, &reply)?; } MasterReq::SET_VRING_CALL => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_call(index, rfds); + let res = self.backend.set_vring_call(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_KICK => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_kick(index, rfds); + let res = self.backend.set_vring_kick(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::SET_VRING_ERR => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; - let res = self.backend.lock().unwrap().set_vring_err(index, rfds); + let res = self.backend.set_vring_err(index, rfds); self.send_ack_message(&hdr, res)?; } MasterReq::GET_PROTOCOL_FEATURES => { self.check_request_size(&hdr, size, 0)?; - let features = self.backend.lock().unwrap().get_protocol_features()?; + let features = self.backend.get_protocol_features()?; let msg = VhostUserU64::new(features.bits()); self.send_reply_message(&hdr, &msg)?; self.protocol_features = features; @@ -228,10 +367,7 @@ impl SlaveReqHandler { } MasterReq::SET_PROTOCOL_FEATURES => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - self.backend - .lock() - .unwrap() - .set_protocol_features(msg.value)?; + self.backend.set_protocol_features(msg.value)?; self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); } @@ -240,7 +376,7 @@ impl SlaveReqHandler { return Err(Error::InvalidOperation); } self.check_request_size(&hdr, size, 0)?; - let num = self.backend.lock().unwrap().get_queue_num()?; + let num = self.backend.get_queue_num()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } @@ -257,11 +393,7 @@ impl SlaveReqHandler { _ => return Err(Error::InvalidParam), }; - let res = self - .backend - .lock() - .unwrap() - .set_vring_enable(msg.index, enable); + let res = self.backend.set_vring_enable(msg.index, enable); self.send_ack_message(&hdr, res)?; } MasterReq::GET_CONFIG => { @@ -341,7 +473,7 @@ impl SlaveReqHandler { } } - self.backend.lock().unwrap().set_mem_table(®ions, &fds) + self.backend.set_mem_table(®ions, &fds) } fn get_config(&mut self, hdr: &VhostUserMsgHeader, buf: &[u8]) -> Result<()> { @@ -357,11 +489,7 @@ impl SlaveReqHandler { Some(val) => val, None => return Err(Error::InvalidMessage), }; - let res = self - .backend - .lock() - .unwrap() - .get_config(msg.offset, msg.size, flags); + let res = self.backend.get_config(msg.offset, msg.size, flags); // vhost-user slave's payload size MUST match master's request // on success, uses zero length of payload to indicate an error @@ -405,11 +533,7 @@ impl SlaveReqHandler { None => return Err(Error::InvalidMessage), } - let res = self - .backend - .lock() - .unwrap() - .set_config(msg.offset, buf, flags); + let res = self.backend.set_config(msg.offset, buf, flags); self.send_ack_message(&hdr, res)?; Ok(()) } @@ -423,7 +547,7 @@ impl SlaveReqHandler { if fds.len() == 1 { let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; let vu_req = SlaveFsCacheReq::from_stream(sock); - self.backend.lock().unwrap().set_slave_req_fd(vu_req); + self.backend.set_slave_req_fd(vu_req); self.send_ack_message(&hdr, Ok(())) } else { Err(Error::InvalidMessage) -- cgit v1.2.3 From 56b823482b1f47392aeec9051ed3bcf526926864 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Mon, 22 Feb 2021 00:04:52 +0800 Subject: vhost_user: use read_aligned() to access data Use std::ptr::read_aligned() to safely access data buffer instead of directly accessing data struct in data buffer. Also enforce stricter message size validation. Signed-off-by: Liu Jiang --- src/vhost_user/slave_req_handler.rs | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 190501b..344afff 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -400,6 +400,7 @@ impl SlaveReqHandler { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } MasterReq::SET_CONFIG => { @@ -413,6 +414,7 @@ impl SlaveReqHandler { if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.set_slave_req_fd(&hdr, rfds)?; } _ => { @@ -477,11 +479,14 @@ impl SlaveReqHandler { } fn get_config(&mut self, hdr: &VhostUserMsgHeader, buf: &[u8]) -> Result<()> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let payload_offset = mem::size_of::(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } - let payload_offset = mem::size_of::(); if buf.len() - payload_offset != msg.size as usize { return Err(Error::InvalidMessage); } @@ -517,10 +522,10 @@ impl SlaveReqHandler { size: usize, buf: &[u8], ) -> Result<()> { - if size < mem::size_of::() { + if size > MAX_MSG_SIZE || size < mem::size_of::() { return Err(Error::InvalidMessage); } - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -562,7 +567,10 @@ impl SlaveReqHandler { buf: &[u8], rfds: Option>, ) -> Result<(u8, Option)> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) }; + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -643,14 +651,14 @@ impl SlaveReqHandler { } } - fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>( + fn extract_request_body( &self, hdr: &VhostUserMsgHeader, size: usize, - buf: &'a [u8], - ) -> Result<&'a T> { + buf: &[u8], + ) -> Result { self.check_request_size(hdr, size, mem::size_of::())?; - let msg = unsafe { &*(buf.as_ptr() as *const T) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } -- cgit v1.2.3 From ec6eae722ef7c6fcecbb8acddf1d9905083dc3de Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sun, 21 Feb 2021 22:56:35 +0800 Subject: vhost_user: add more unit test cases Add more unit test cases for vhost-user protocol. Signed-off-by: Liu Jiang --- coverage_config_x86_64.json | 2 +- src/lib.rs | 12 ++- src/vhost_user/connection.rs | 34 ++++--- src/vhost_user/dummy_slave.rs | 48 +++++----- src/vhost_user/master.rs | 62 ++++++------ src/vhost_user/master_req_handler.rs | 99 +++++++++++++++++++ src/vhost_user/mod.rs | 180 +++++++++++++++++++++++++++++++++-- src/vhost_user/slave.rs | 40 ++++++++ src/vhost_user/slave_req_handler.rs | 21 ++++ 9 files changed, 417 insertions(+), 81 deletions(-) diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index e9ac51c..075255b 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1 +1 @@ -{"coverage_score": 73.3, "exclude_path": "src/vhost_kern/", "crate_features": ""} +{"coverage_score": 78.9, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} diff --git a/src/lib.rs b/src/lib.rs index a3852a6..b7ed15c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,7 +74,7 @@ pub enum Error { IoctlError(std::io::Error), /// Error from IO subsystem. IOError(std::io::Error), - #[cfg(feature = "vhost-user-master")] + #[cfg(feature = "vhost-user")] /// Error from the vhost-user subsystem. VhostUserProtocol(vhost_user::Error), } @@ -95,7 +95,7 @@ impl std::fmt::Display for Error { Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e), #[cfg(feature = "vhost-kern")] Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e), - #[cfg(feature = "vhost-user-master")] + #[cfg(feature = "vhost-user")] Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e), } } @@ -151,4 +151,12 @@ mod tests { assert_eq!(format!("{:?}", Error::AvailAddress), "AvailAddress"); } + + #[cfg(feature = "vhost-user")] + #[test] + fn test_convert_from_vhost_user_error() { + let e: Error = vhost_user::Error::OversizedMsg.into(); + + assert_eq!(format!("{}", e), "vhost-user: oversized message"); + } } diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index d89f9c7..01bf124 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -606,29 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { - use super::*; use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::io::FromRawFd; + use vmm_sys_util::rand::rand_alphanumerics; use vmm_sys_util::tempfile::TempFile; - const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener"; - const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection"; - const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data"; - const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd"; - const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send"; + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } #[test] fn create_listener() { - let listener = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); assert!(listener.as_raw_fd() > 0); } #[test] fn accept_connection() { - let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); // accept on a fd without incoming connection @@ -638,9 +641,10 @@ mod tests { #[test] fn send_data() { - let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::::connect(UNIX_SOCKET_DATA).unwrap(); + let mut master = Endpoint::::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::::from_stream(sock); @@ -663,9 +667,10 @@ mod tests { #[test] fn send_fd() { - let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::::connect(UNIX_SOCKET_FD).unwrap(); + let mut master = Endpoint::::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::::from_stream(sock); @@ -816,9 +821,10 @@ mod tests { #[test] fn send_recv() { - let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::::connect(UNIX_SOCKET_SEND).unwrap(); + let mut master = Endpoint::::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::::from_stream(sock); diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index 53887e2..99f08e7 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -1,9 +1,10 @@ // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. // SPDX-License-Identifier: Apache-2.0 +use std::os::unix::io::RawFd; + use super::message::*; use super::*; -use std::os::unix::io::RawFd; pub const MAX_QUEUE_NUM: usize = 2; pub const MAX_VRING_NUM: usize = 256; @@ -34,7 +35,7 @@ impl DummySlaveReqHandler { } } -impl VhostUserSlaveReqHandler for DummySlaveReqHandler { +impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { fn set_owner(&mut self) -> Result<()> { if self.owned { return Err(Error::InvalidOperation); @@ -83,30 +84,10 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { Ok(()) } - fn get_protocol_features(&mut self) -> Result { - Ok(VhostUserProtocolFeatures::all()) - } - - fn set_protocol_features(&mut self, features: u64) -> Result<()> { - // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must - // support this message even before VHOST_USER_SET_FEATURES was - // called. - // What happens if the master calls set_features() with - // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this - // interface? - self.acked_protocol_features = features; - Ok(()) - } - fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { - // TODO Ok(()) } - fn get_queue_num(&mut self) -> Result { - Ok(MAX_QUEUE_NUM as u64) - } - fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { return Err(Error::InvalidParam); @@ -199,6 +180,25 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { Ok(()) } + fn get_protocol_features(&mut self) -> Result { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? + self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result { + Ok(MAX_QUEUE_NUM as u64) + } + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES // has been negotiated. @@ -222,7 +222,7 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { size: u32, _flags: VhostUserConfigFlags, ) -> Result> { - if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); } else if offset < VHOST_USER_CONFIG_OFFSET || offset >= VHOST_USER_CONFIG_SIZE @@ -236,7 +236,7 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler { fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { let size = buf.len() as u32; - if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); } else if offset < VHOST_USER_CONFIG_OFFSET || offset >= VHOST_USER_CONFIG_SIZE diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 65c7960..cc754b5 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -393,7 +393,10 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::SlaveInternalError); } else if body_reply.size != body.size || body_reply.size as usize != buf.len() { return error_code(VhostUserError::InvalidMessage); + } else if body_reply.offset != body.offset { + return error_code(VhostUserError::InvalidMessage); } + Ok((body_reply, buf_reply)) } @@ -571,6 +574,7 @@ impl MasterInternal { ) -> VhostUserResult<(T, Vec, Option>)> { if mem::size_of::() > MAX_MSG_SIZE || hdr.get_size() as usize <= mem::size_of::() + || hdr.get_size() as usize > MAX_MSG_SIZE || hdr.is_reply() { return Err(VhostUserError::InvalidParam); @@ -586,11 +590,8 @@ impl MasterInternal { { Endpoint::::close_rfds(rfds); return Err(VhostUserError::InvalidMessage); - } else if bytes > MAX_MSG_SIZE - mem::size_of::() { + } else if bytes != buf.len() { return Err(VhostUserError::InvalidMessage); - } else if bytes < buf.len() { - // It's safe because we have checked the buffer size - unsafe { buf.set_len(bytes) }; } Ok((body, buf, rfds)) } @@ -638,11 +639,14 @@ impl MasterInternal { mod tests { use super::super::connection::Listener; use super::*; + use vmm_sys_util::rand::rand_alphanumerics; - const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master"; - const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2"; - const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3"; - const UNIX_SOCKET_MASTER4: &'static str = "/tmp/vhost_user_test_rust_master4"; + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } fn create_pair(path: &str) -> (Master, Endpoint) { let listener = Listener::new(path, true).unwrap(); @@ -653,14 +657,15 @@ mod tests { } #[test] - #[ignore] fn create_master() { - let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap(); + let master = Master::connect(&path, 1).unwrap(); let mut slave = Endpoint::::from_stream(listener.accept().unwrap().unwrap()); + assert!(master.as_raw_fd() > 0); // Send two messages continuously master.set_owner().unwrap(); master.reset_owner().unwrap(); @@ -679,24 +684,24 @@ mod tests { } #[test] - #[ignore] fn test_create_failure() { - let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); - let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err(); - assert!(Master::connect(UNIX_SOCKET_MASTER2, 1).is_err()); + let path = temp_path(); + let _ = Listener::new(&path, true).unwrap(); + let _ = Listener::new(&path, false).is_err(); + assert!(Master::connect(&path, 1).is_err()); - let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap(); - assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err()); + let listener = Listener::new(&path, true).unwrap(); + assert!(Listener::new(&path, false).is_err()); listener.set_nonblocking(true).unwrap(); - let _master = Master::connect(UNIX_SOCKET_MASTER2, 1).unwrap(); + let _master = Master::connect(&path, 1).unwrap(); let _slave = listener.accept().unwrap().unwrap(); } #[test] - #[ignore] fn test_features() { - let (master, mut peer) = create_pair(UNIX_SOCKET_MASTER3); + let path = temp_path(); + let (master, mut peer) = create_pair(&path); master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); @@ -713,6 +718,9 @@ mod tests { let (_hdr, rfds) = peer.recv_header().unwrap(); assert!(rfds.is_none()); + let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8); + let msg = VhostUserU64::new(0x15); + peer.send_message(&hdr, &msg, None).unwrap(); master.set_features(0x15).unwrap(); let (_hdr, msg, rfds) = peer.recv_body::().unwrap(); assert!(rfds.is_none()); @@ -726,9 +734,9 @@ mod tests { } #[test] - #[ignore] fn test_protocol_features() { - let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4); + let path = temp_path(); + let (mut master, mut peer) = create_pair(&path); master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); @@ -775,14 +783,4 @@ mod tests { peer.send_message(&hdr, &msg, None).unwrap(); assert!(master.get_protocol_features().is_err()); } - - #[test] - fn test_set_mem_table() { - // TODO - } - - #[test] - fn test_get_ring_num() { - // TODO - } } diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index 02c2bb7..fb33f15 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -377,3 +377,102 @@ impl AsRawFd for MasterReqHandler { self.sub_sock.as_raw_fd() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "vhost-user-slave")] + use crate::vhost_user::SlaveFsCacheReq; + #[cfg(feature = "vhost-user-slave")] + use std::os::unix::io::FromRawFd; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + assert!(handler.get_tx_raw_fd() >= 0); + assert!(handler.as_raw_fd() >= 0); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + // When REPLY_ACK has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + handler.set_reply_ack_flag(true); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index 91e4203..bf0a261 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -175,21 +175,32 @@ pub type Result = std::result::Result; /// Result of request handler. pub type HandlerResult = std::result::Result; -#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +#[cfg(all(test, feature = "vhost-user-slave"))] mod dummy_slave; #[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] mod tests { + use std::os::unix::io::AsRawFd; + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + use vmm_sys_util::rand::rand_alphanumerics; + use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES}; use super::message::*; use super::*; use crate::backend::VhostBackend; - use std::sync::{Arc, Barrier, Mutex}; - use std::thread; + use crate::{VhostUserMemoryRegionInfo, VringConfigData}; + + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } fn create_slave( path: &str, - backend: Arc>, + backend: Arc, ) -> (Master, SlaveReqHandler) { let listener = Listener::new(path, true).unwrap(); let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); @@ -208,8 +219,8 @@ mod tests { #[test] fn test_set_owner() { let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); - let (master, mut slave) = - create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone()); + let path = temp_path(); + let (master, mut slave) = create_slave(&path, slave_be.clone()); assert_eq!(slave_be.lock().unwrap().owned, false); master.set_owner().unwrap(); @@ -224,14 +235,60 @@ mod tests { fn test_set_features() { let mbar = Arc::new(Barrier::new(2)); let sbar = mbar.clone(); + let path = temp_path(); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_master_slave_process() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let path = temp_path(); let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); - let (mut master, mut slave) = - create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone()); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); thread::spawn(move || { + // set_own() slave.handle_request().unwrap(); assert_eq!(slave_be.lock().unwrap().owned, true); + // get/set_features() slave.handle_request().unwrap(); slave.handle_request().unwrap(); assert_eq!( @@ -246,6 +303,36 @@ mod tests { VhostUserProtocolFeatures::all().bits() ); + // get_queue_num() + slave.handle_request().unwrap(); + + // set_mem_table() + slave.handle_request().unwrap(); + + // get/set_config() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // set_slave_request_fd + slave.handle_request().unwrap(); + + // set_vring_enable + slave.handle_request().unwrap(); + + /* + // set_log_base,set_log_fd() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + */ + + // set_vring_xxx + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + sbar.wait(); }); @@ -261,6 +348,83 @@ mod tests { assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); master.set_protocol_features(features).unwrap(); + let num = master.get_queue_num().unwrap(); + assert_eq!(num, 2); + + let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap(); + let mem = [VhostUserMemoryRegionInfo { + guest_phys_addr: 0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: eventfd.as_raw_fd(), + }]; + master.set_mem_table(&mem).unwrap(); + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8]) + .unwrap(); + let buf = [0x0u8; 4]; + let (reply_body, reply_payload) = master + .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + let offset = reply_body.offset; + assert_eq!(offset, 0x100); + assert_eq!(reply_payload[0], 0xa5); + + master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_vring_enable(0, true).unwrap(); + + /* + master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap(); + master.set_log_fd(eventfd.as_raw_fd()).unwrap(); + */ + + master.set_vring_num(0, 256).unwrap(); + master.set_vring_base(0, 0).unwrap(); + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(), + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + master.set_vring_addr(0, &config).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + mbar.wait(); } + + #[test] + fn test_error_display() { + assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters"); + assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation"); + } + + #[test] + fn test_should_reconnect() { + assert_eq!(Error::PartialMessage.should_reconnect(), true); + assert_eq!(Error::SlaveInternalError.should_reconnect(), true); + assert_eq!(Error::MasterInternalError.should_reconnect(), true); + assert_eq!(Error::InvalidParam.should_reconnect(), false); + assert_eq!(Error::InvalidOperation.should_reconnect(), false); + assert_eq!(Error::InvalidMessage.should_reconnect(), false); + assert_eq!(Error::IncorrectFds.should_reconnect(), false); + assert_eq!(Error::OversizedMsg.should_reconnect(), false); + assert_eq!(Error::FeatureMismatch.should_reconnect(), false); + } + + #[test] + fn test_error_from_sys_util_error() { + let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN.into()).into(); + if let Error::SocketRetry(e1) = e { + assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); + } else { + panic!("invalid error code conversion!"); + } + } } diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs index c167dce..fb65c41 100644 --- a/src/vhost_user/slave.rs +++ b/src/vhost_user/slave.rs @@ -44,3 +44,43 @@ impl SlaveListener { self.listener.set_nonblocking(block) } } + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_listener_set_nonblocking() { + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = + Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap(); + let slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + } + + #[cfg(feature = "vhost-user-master")] + #[test] + fn test_slave_listener_accept() { + use super::super::Master; + + let path = "/tmp/vhost_user_lib_unit_test_slave_accept"; + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = Listener::new(path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + assert!(slave_listener.accept().unwrap().is_none()); + assert!(slave_listener.accept().unwrap().is_none()); + + let _master = Master::connect(path, 1).unwrap(); + let _slave = slave_listener.accept().unwrap().unwrap(); + } +} diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 344afff..ff07304 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -743,3 +743,24 @@ impl AsRawFd for SlaveReqHandler { self.main_sock.as_raw_fd() } } + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_req_handler_new() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let endpoint = Endpoint::::from_stream(p1); + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let mut handler = SlaveReqHandler::new(endpoint, backend); + + handler.check_state().unwrap(); + handler.set_failed(libc::EAGAIN); + handler.check_state().unwrap_err(); + assert!(handler.as_raw_fd() >= 0); + } +} -- cgit v1.2.3 From e543bf2a8defadae534cfe757829a9377663e5ad Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Mon, 22 Feb 2021 14:08:04 +0800 Subject: vhost_user: add more negative unit test cases Add more negative unit test cases to improve code coverage. Also add two helper functions to simplify code. Signed-off-by: Liu Jiang --- coverage_config_x86_64.json | 2 +- src/vhost_user/master.rs | 257 +++++++++++++++++++++++++++++++++++---- src/vhost_user/mod.rs | 11 +- src/vhost_user/slave_fs_cache.rs | 65 ++++++++-- 4 files changed, 295 insertions(+), 40 deletions(-) diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 075255b..a4ed64f 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1 +1 @@ -{"coverage_score": 78.9, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} +{"coverage_score": 81.3, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} \ No newline at end of file diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index cc754b5..be2892e 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -6,7 +6,7 @@ use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use vmm_sys_util::eventfd::EventFd; @@ -78,6 +78,10 @@ impl Master { } } + fn node(&self) -> MutexGuard { + self.node.lock().unwrap() + } + /// Create a new instance from a Unix stream socket. pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { Self::new(Endpoint::::from_stream(sock), max_queue_num) @@ -116,7 +120,7 @@ impl Master { impl VhostBackend for Master { /// Get from the underlying vhost implementation the feature bitmask. fn get_features(&self) -> Result { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; let val = node.recv_reply::(&hdr)?; node.virtio_features = val.value; @@ -125,7 +129,7 @@ impl VhostBackend for Master { /// Enable features in the underlying vhost implementation using a bitmask. fn set_features(&self, features: u64) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let val = VhostUserU64::new(features); let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been @@ -138,7 +142,7 @@ impl VhostBackend for Master { fn set_owner(&self) -> Result<()> { // We unwrap() the return value to assert that we are not expecting threads to ever fail // while holding the lock. - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been // completed yet. @@ -146,7 +150,7 @@ impl VhostBackend for Master { } fn reset_owner(&self) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been // completed yet. @@ -174,7 +178,7 @@ impl VhostBackend for Master { ctx.append(®, region.mmap_handle); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let body = VhostUserMemory::new(ctx.regions.len() as u32); let (_, payload, _) = unsafe { ctx.regions.align_to::() }; let hdr = node.send_request_with_payload( @@ -189,7 +193,7 @@ impl VhostBackend for Master { // Clippy doesn't seem to know that if let with && is still experimental #[allow(clippy::unnecessary_unwrap)] fn set_log_base(&self, base: u64, fd: Option) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let val = VhostUserU64::new(base); if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 @@ -204,7 +208,7 @@ impl VhostBackend for Master { } fn set_log_fd(&self, fd: RawFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let fds = [fd]; node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; Ok(()) @@ -212,7 +216,7 @@ impl VhostBackend for Master { /// Set the size of the queue. fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -224,7 +228,7 @@ impl VhostBackend for Master { /// Sets the addresses of the different aspects of the vring. fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 { @@ -238,7 +242,7 @@ impl VhostBackend for Master { /// Sets the base offset in the available vring. fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -249,7 +253,7 @@ impl VhostBackend for Master { } fn get_vring_base(&self, queue_index: usize) -> Result { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -265,7 +269,7 @@ impl VhostBackend for Master { /// is set when there is no file descriptor in the ancillary data. This signals that polling /// will be used instead of waiting for the call. fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -278,7 +282,7 @@ impl VhostBackend for Master { /// is set when there is no file descriptor in the ancillary data. This signals that polling /// should be used instead of waiting for a kick. fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -290,7 +294,7 @@ impl VhostBackend for Master { /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag /// is set when there is no file descriptor in the ancillary data. fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if queue_index as u64 >= node.max_queue_num { return error_code(VhostUserError::InvalidParam); } @@ -301,7 +305,7 @@ impl VhostBackend for Master { impl VhostUserMaster for Master { fn get_protocol_features(&mut self) -> Result { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); @@ -318,7 +322,7 @@ impl VhostUserMaster for Master { } fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 { return error_code(VhostUserError::InvalidOperation); @@ -333,7 +337,7 @@ impl VhostUserMaster for Master { } fn get_queue_num(&mut self) -> Result { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if !node.is_feature_mq_available() { return error_code(VhostUserError::InvalidOperation); } @@ -348,7 +352,7 @@ impl VhostUserMaster for Master { } fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled. if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -374,7 +378,7 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidParam); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -409,7 +413,7 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidParam); } - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return error_code(VhostUserError::InvalidOperation); @@ -420,7 +424,7 @@ impl VhostUserMaster for Master { } fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> { - let mut node = self.node.lock().unwrap(); + let mut node = self.node(); if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return error_code(VhostUserError::InvalidOperation); } @@ -433,7 +437,7 @@ impl VhostUserMaster for Master { impl AsRawFd for Master { fn as_raw_fd(&self) -> RawFd { - let node = self.node.lock().unwrap(); + let node = self.node(); node.main_sock.as_raw_fd() } } @@ -783,4 +787,211 @@ mod tests { peer.send_message(&hdr, &msg, None).unwrap(); assert!(master.get_protocol_features().is_err()); } + + #[test] + fn test_master_set_config_negative() { + let path = temp_path(); + let (mut master, _peer) = create_pair(&path); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap(); + master + .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .unwrap_err(); + master + .set_config( + 0x100, + unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) }, + &buf[0..4], + ) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf) + .unwrap_err(); + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[]) + .unwrap_err(); + } + + fn create_pair2() -> (Master, Endpoint) { + let path = temp_path(); + let (master, peer) = create_pair(&path); + + { + let mut node = master.node(); + node.virtio_features = 0xffff_ffff; + node.acked_virtio_features = 0xffff_ffff; + node.protocol_features = 0xffff_ffff; + node.acked_protocol_features = 0xffff_ffff; + } + + (master, peer) + } + + #[test] + fn test_master_get_config_negative0() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_code(MasterReq::GET_FEATURES); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + hdr.set_code(MasterReq::GET_CONFIG); + } + + #[test] + fn test_master_get_config_negative1() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + hdr.set_reply(false); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative2() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + } + + #[test] + fn test_master_get_config_negative3() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative4() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = 0x101; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative5() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.offset = (MAX_MSG_SIZE + 1) as u32; + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_master_get_config_negative6() { + let (mut master, mut peer) = create_pair2(); + let buf = vec![0x0; MAX_MSG_SIZE + 1]; + + let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16); + let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty()); + peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_ok()); + + msg.size = 6; + peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None) + .unwrap(); + assert!(master + .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4]) + .is_err()); + } + + #[test] + fn test_maset_set_mem_table_failure() { + let (master, _peer) = create_pair2(); + + master.set_mem_table(&[]).unwrap_err(); + let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; + master.set_mem_table(&tables).unwrap_err(); + } } diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index bf0a261..bc21b44 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -210,7 +210,7 @@ mod tests { #[test] fn create_dummy_slave() { - let mut slave = DummySlaveReqHandler::new(); + let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new())); slave.set_owner().unwrap(); assert!(slave.set_owner().is_err()); @@ -319,11 +319,9 @@ mod tests { // set_vring_enable slave.handle_request().unwrap(); - /* // set_log_base,set_log_fd() - slave.handle_request().unwrap(); - slave.handle_request().unwrap(); - */ + slave.handle_request().unwrap_err(); + slave.handle_request().unwrap_err(); // set_vring_xxx slave.handle_request().unwrap(); @@ -375,10 +373,9 @@ mod tests { master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); master.set_vring_enable(0, true).unwrap(); - /* + // unimplemented yet master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap(); master.set_log_fd(eventfd.as_raw_fd()).unwrap(); - */ master.set_vring_num(0, 256).unwrap(); master.set_vring_base(0, 0).unwrap(); diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index 32b2b8e..1e2ef61 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -5,7 +5,7 @@ use std::io; use std::mem; use std::os::unix::io::RawFd; use std::os::unix::net::UnixStream; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard}; use super::connection::Endpoint; use super::message::*; @@ -92,15 +92,17 @@ impl SlaveFsCacheReq { } } + fn node(&self) -> MutexGuard { + self.node.lock().unwrap() + } + fn send_message( &self, request: SlaveReq, fs: &VhostUserFSSlaveMsg, fds: Option<&[RawFd]>, ) -> io::Result { - self.node - .lock() - .unwrap() + self.node() .send_message(request, fs, fds) .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) } @@ -116,12 +118,12 @@ impl SlaveFsCacheReq { /// the "REPLY_ACK" flag will be set in the message header for every slave to master request /// message. pub fn set_reply_ack_flag(&self, enable: bool) { - self.node.lock().unwrap().reply_ack_negotiated = enable; + self.node().reply_ack_negotiated = enable; } /// Mark endpoint as failed with specified error code. pub fn set_failed(&self, error: i32) { - self.node.lock().unwrap().error = Some(error); + self.node().error = Some(error); } } @@ -148,9 +150,9 @@ mod tests { let (p1, _p2) = UnixStream::pair().unwrap(); let fs_cache = SlaveFsCacheReq::from_stream(p1); - assert!(fs_cache.node.lock().unwrap().error.is_none()); + assert!(fs_cache.node().error.is_none()); fs_cache.set_failed(libc::EAGAIN); - assert_eq!(fs_cache.node.lock().unwrap().error, Some(libc::EAGAIN)); + assert_eq!(fs_cache.node().error, Some(libc::EAGAIN)); } #[test] @@ -166,7 +168,7 @@ mod tests { fs_cache .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) .unwrap_err(); - fs_cache.node.lock().unwrap().error = None; + fs_cache.node().error = None; drop(p2); fs_cache @@ -176,4 +178,49 @@ mod tests { .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) .unwrap_err(); } + + #[test] + fn test_slave_fs_cache_recv_negative() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + let mut master = Endpoint::::from_stream(p2); + + let len = mem::size_of::(); + let mut hdr = VhostUserMsgHeader::new( + SlaveReq::FS_MAP, + VhostUserHeaderFlag::REPLY.bits(), + len as u32, + ); + let body = VhostUserU64::new(0); + + master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + hdr.set_code(SlaveReq::FS_UNMAP); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + hdr.set_code(SlaveReq::FS_MAP); + + let body = VhostUserU64::new(1); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + let body = VhostUserU64::new(0); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + } } -- cgit v1.2.3 From 576694bcfb09e7c78d812ed07dbc5377d283852a Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Mon, 22 Feb 2021 13:25:59 +0800 Subject: Prepare for publishing to crates.io Prepare for publishing to crates.io, 1) update README.md 2) update Cargo.toml 3) set code owners It should be ready for publishing now. Signed-off-by: Liu Jiang --- CODEOWNERS | 2 +- Cargo.toml | 4 + README.md | 12 ++- docs/vhost_architecture.drawio | 171 +++++++++++++++++++++++++++++++++++++++++ docs/vhost_architecture.png | Bin 0 -> 146074 bytes 5 files changed, 186 insertions(+), 3 deletions(-) create mode 100644 docs/vhost_architecture.drawio create mode 100644 docs/vhost_architecture.png diff --git a/CODEOWNERS b/CODEOWNERS index 4d96c3f..7174a1b 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,2 @@ # Add the list of code owners here (using their GitHub username) -* gatekeeper-PullAssigner +* gatekeeper-PullAssigner @jiangliu @eryugey @sboeuf @slp diff --git a/Cargo.toml b/Cargo.toml index b8609f5..0cd15f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,12 @@ [package] name = "vhost" version = "0.1.0" +keywords = ["vhost", "vhost-user", "virtio", "vdpa"] +description = "a pure rust library for vdpa, vhost and vhost-user" authors = ["Liu Jiang "] repository = "https://github.com/rust-vmm/vhost" +documentation = "https://docs.rs/vhost" +readme = "README.md" license = "Apache-2.0 or BSD-3-Clause" edition = "2018" diff --git a/README.md b/README.md index c1c2ab6..b0f4dfa 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ # vHost -A crate to support vhost backend drivers for virtio devices. +A pure rust library for vDPA, vhost and vhost-user. +The `vhost` crate aims to help implementing dataplane for virtio backend drivers. It supports three different types of dataplane drivers: +- vhost: the dataplane is implemented by linux kernel +- vhost-user: the dataplane is implemented by dedicated vhost-user servers +- vDPA(vhost DataPath Accelerator): the dataplane is implemented by hardwares + +The main relationship among Traits and Structs exported by the `vhost` crate is as below: + +![vhost Architecture](/docs/vhost_architecture.png) ## Kernel-based vHost Backend Drivers The vhost drivers in Linux provide in-kernel virtio device emulation. Normally the hypervisor userspace process emulates I/O accesses from the guest. @@ -11,7 +19,7 @@ The hypervisor relies on ioctl based interfaces to control those in-kernel vhost drivers, such as vhost-net, vhost-scsi and vhost-vsock etc. ## vHost-user Backend Drivers -The vhost-user protocol is aiming to implement vhost backend drivers in +The [vhost-user protocol](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html#communication) aims to implement vhost backend drivers in userspace, which complements the ioctl interface used to control the vhost implementation in the Linux kernel. It implements the control plane needed to establish virtqueue sharing with a user space process on the same host. diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio new file mode 100644 index 0000000..5008d28 --- /dev/null +++ b/docs/vhost_architecture.drawio @@ -0,0 +1,171 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png new file mode 100644 index 0000000..4d1e2bc Binary files /dev/null and b/docs/vhost_architecture.png differ -- cgit v1.2.3 From 62fd4ec5a47d1b9cd1ea2c8702228ea216b92686 Mon Sep 17 00:00:00 2001 From: Keiichi Watanabe Date: Wed, 24 Feb 2021 21:21:32 +0900 Subject: Fix clippy erros and warnings Make `cargo clippy --all-features --all-targets` pass. Signed-off-by: Keiichi Watanabe --- src/vhost_kern/mod.rs | 16 +++++++++------- src/vhost_kern/vhost_binding.rs | 1 + src/vhost_kern/vsock.rs | 1 - src/vhost_user/dummy_slave.rs | 10 +++------- src/vhost_user/master.rs | 7 ++++--- src/vhost_user/master_req_handler.rs | 3 +-- src/vhost_user/message.rs | 2 +- src/vhost_user/mod.rs | 2 +- src/vhost_user/slave_fs_cache.rs | 2 +- src/vhost_user/slave_req_handler.rs | 5 +---- 10 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs index f263a39..f82cbfc 100644 --- a/src/vhost_kern/mod.rs +++ b/src/vhost_kern/mod.rs @@ -63,20 +63,22 @@ pub trait VhostKernBackend: AsRawFd { .checked_add(desc_table_size) .map_or(true, |v| !m.address_in_range(v)) { - false - } else if GuestAddress(config_data.avail_ring_addr) + return false; + } + if GuestAddress(config_data.avail_ring_addr) .checked_add(avail_ring_size) .map_or(true, |v| !m.address_in_range(v)) { - false - } else if GuestAddress(config_data.used_ring_addr) + return false; + } + if GuestAddress(config_data.used_ring_addr) .checked_add(used_ring_size) .map_or(true, |v| !m.address_in_range(v)) { - false - } else { - config_data.is_log_addr_valid() + return false; } + + config_data.is_log_addr_valid() } } diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs index fdc5225..57ae698 100644 --- a/src/vhost_kern/vhost_binding.rs +++ b/src/vhost_kern/vhost_binding.rs @@ -13,6 +13,7 @@ #![allow(non_camel_case_types)] #![allow(non_snake_case)] #![allow(missing_docs)] +#![allow(clippy::missing_safety_doc)] use crate::{Error, Result}; use std::os::raw; diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs index 7cc1cf5..65f89e4 100644 --- a/src/vhost_kern/vsock.rs +++ b/src/vhost_kern/vsock.rs @@ -11,7 +11,6 @@ use std::fs::{File, OpenOptions}; use std::os::unix::fs::OpenOptionsExt; use std::os::unix::io::{AsRawFd, RawFd}; -use libc; use vm_memory::GuestAddressSpace; use vmm_sys_util::ioctl::ioctl_with_ref; diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index 99f08e7..9eedcbb 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -57,9 +57,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { } fn set_features(&mut self, features: u64) -> Result<()> { - if !self.owned { - return Err(Error::InvalidOperation); - } else if self.features_acked { + if !self.owned || self.features_acked { return Err(Error::InvalidOperation); } else if (features & !VIRTIO_FEATURES) != 0 { return Err(Error::InvalidParam); @@ -224,8 +222,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { ) -> Result> { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); - } else if offset < VHOST_USER_CONFIG_OFFSET - || offset >= VHOST_USER_CONFIG_SIZE + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { @@ -238,8 +235,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { let size = buf.len() as u32; if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); - } else if offset < VHOST_USER_CONFIG_OFFSET - || offset >= VHOST_USER_CONFIG_SIZE + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index be2892e..35ca471 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -395,9 +395,10 @@ impl VhostUserMaster for Master { return error_code(VhostUserError::InvalidMessage); } else if body_reply.size == 0 { return error_code(VhostUserError::SlaveInternalError); - } else if body_reply.size != body.size || body_reply.size as usize != buf.len() { - return error_code(VhostUserError::InvalidMessage); - } else if body_reply.offset != body.offset { + } else if body_reply.size != body.size + || body_reply.size as usize != buf.len() + || body_reply.offset != body.offset + { return error_code(VhostUserError::InvalidMessage); } diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs index fb33f15..8cba188 100644 --- a/src/vhost_user/master_req_handler.rs +++ b/src/vhost_user/master_req_handler.rs @@ -1,7 +1,6 @@ // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -use libc; use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::UnixStream; @@ -310,7 +309,7 @@ impl MasterReqHandler { _ => { if rfds.is_some() { Endpoint::::close_rfds(rfds); - return Err(Error::InvalidMessage); + Err(Error::InvalidMessage) } else { Ok(rfds) } diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs index 0f6e1a2..8600410 100644 --- a/src/vhost_user/message.rs +++ b/src/vhost_user/message.rs @@ -809,7 +809,7 @@ mod tests { msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; msg.memory_size = 0; assert!(!msg.is_valid()); - let a = msg.clone().guest_phys_addr; + let a = msg.guest_phys_addr; let b = msg.guest_phys_addr; assert_eq!(a, b); diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs index bc21b44..6a5b6a1 100644 --- a/src/vhost_user/mod.rs +++ b/src/vhost_user/mod.rs @@ -417,7 +417,7 @@ mod tests { #[test] fn test_error_from_sys_util_error() { - let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN.into()).into(); + let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into(); if let Error::SocketRetry(e1) = e { assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); } else { diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs index 1e2ef61..a9c4ed2 100644 --- a/src/vhost_user/slave_fs_cache.rs +++ b/src/vhost_user/slave_fs_cache.rs @@ -104,7 +104,7 @@ impl SlaveFsCacheReq { ) -> io::Result { self.node() .send_message(request, fs, fds) - .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e)))) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) } /// Create a new instance from a `UnixStream` object. diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index ff07304..3b44e4c 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -579,10 +579,7 @@ impl SlaveReqHandler { // invalid FD flag. This flag is set when there is no file descriptor // in the ancillary data. This signals that polling will be used // instead of waiting for the call. - let nofd = match msg.value & 0x100u64 { - 0x100u64 => true, - _ => false, - }; + let nofd = (msg.value & 0x100u64) == 0x100u64; let mut rfd = None; match rfds { -- cgit v1.2.3