summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2021-11-06 01:06:09 +0000
committerAndroid Build Coastguard Worker <android-build-coastguard-worker@google.com>2021-11-06 01:06:09 +0000
commitf3150451588772687f16a18fc4c8aabb13f523da (patch)
treed61c753f3a2deb9f1c334c0a593d7a81c76709fa
parentc1e5548579e09a0de2905ea16395fb0f70624c22 (diff)
parent09ae401741a8178327a1bc5b6188efef18a25600 (diff)
downloadvmm_vhost-android12L-d2-release.tar.gz
Change-Id: I0f9098e79735efbd80cb27b634a46ae571881ad1
-rw-r--r--.buildkite/pipeline.yml17
-rw-r--r--.github/dependabot.yml7
-rw-r--r--.gitmodules3
-rw-r--r--Android.bp6
-rw-r--r--Cargo.toml9
-rw-r--r--METADATA6
-rw-r--r--coverage_config_x86_64.json2
m---------rust-vmm-ci0
-rw-r--r--src/vhost_user/connection.rs152
-rw-r--r--src/vhost_user/dummy_slave.rs52
-rw-r--r--src/vhost_user/master.rs150
-rw-r--r--src/vhost_user/master_req_handler.rs78
-rw-r--r--src/vhost_user/message.rs210
-rw-r--r--src/vhost_user/mod.rs31
-rw-r--r--src/vhost_user/slave_fs_cache.rs33
-rw-r--r--src/vhost_user/slave_req_handler.rs262
16 files changed, 649 insertions, 369 deletions
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
new file mode 100644
index 0000000..0e77e1f
--- /dev/null
+++ b/.buildkite/pipeline.yml
@@ -0,0 +1,17 @@
+# Copyright 2021 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE-BSD-Google file.
+
+steps:
+ - label: "clippy-x86-custom"
+ commands:
+ - cargo clippy --all-features --all-targets --workspace -- -D warnings
+ retry:
+ automatic: false
+ agents:
+ platform: x86_64.metal
+ os: linux
+ plugins:
+ - docker#v3.0.1:
+ image: "rustvmm/dev:v12"
+ always-pull: true
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..4fcd556
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,7 @@
+version: 2
+updates:
+- package-ecosystem: gitsubmodule
+ directory: "/"
+ schedule:
+ interval: daily
+ open-pull-requests-limit: 10
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index bda97eb..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "rust-vmm-ci"]
- path = rust-vmm-ci
- url = https://github.com/rust-vmm/rust-vmm-ci.git
diff --git a/Android.bp b/Android.bp
index 3ca35d8..2f38b2b 100644
--- a/Android.bp
+++ b/Android.bp
@@ -1,4 +1,4 @@
-// This file is generated by cargo2android.py --run.
+// This file is generated by cargo2android.py --run --device --features default,vhost-user,vhost-user-master,vhost-user-slave --global_defaults crosvm_defaults.
// Do not modify this file as changes will be overridden on upgrade.
package {
@@ -40,8 +40,9 @@ license {
rust_library {
name: "libvmm_vhost",
- crate_name: "vmm_vhost",
+ defaults: ["crosvm_defaults"],
host_supported: true,
+ crate_name: "vmm_vhost",
srcs: ["src/lib.rs"],
edition: "2018",
features: [
@@ -56,5 +57,4 @@ rust_library {
"libsys_util",
"libtempfile",
],
- defaults: ["crosvm_defaults"],
}
diff --git a/Cargo.toml b/Cargo.toml
index 64bfb5b..94a7f45 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -21,10 +21,13 @@ vhost-user-slave = ["vhost-user"]
[dependencies]
bitflags = ">=1.0.1"
libc = ">=0.2.39"
-
-sys_util = { path = "../../../external/crosvm/sys_util" } # provided by ebuild
-tempfile = { path = "../../../external/crosvm/tempfile" } # provided by ebuild
+sys_util = "*"
+tempfile = "*"
vm-memory = { version = "0.2.0", optional = true }
[dev-dependencies]
vm-memory = { version = "0.2.0", features=["backend-mmap"] }
+
+[patch.crates-io]
+sys_util = { path = "../../../external/crosvm/sys_util" } # ignored by ebuild
+tempfile = { path = "../../../external/crosvm/tempfile" } # ignored by ebuild
diff --git a/METADATA b/METADATA
index 798081c..f939dcd 100644
--- a/METADATA
+++ b/METADATA
@@ -9,11 +9,11 @@ third_party {
type: GIT
value: "https://chromium.googlesource.com/chromiumos/third_party/rust-vmm/vhost"
}
- version: "eaca5d36a2701c99b354ab5bc0954a78dfc9ff4f"
+ version: "d65bd280d9f4e192a884f1761e4b097c11aae6de"
license_type: NOTICE
last_upgrade_date {
year: 2021
- month: 5
- day: 19
+ month: 9
+ day: 22
}
}
diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json
index 2b2c164..c3e6939 100644
--- a/coverage_config_x86_64.json
+++ b/coverage_config_x86_64.json
@@ -1 +1 @@
-{"coverage_score": 81.2, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"}
+{"coverage_score": 82.3, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"}
diff --git a/rust-vmm-ci b/rust-vmm-ci
-Subproject 24d66cdae63d4aa7f8de01b616c015b97604a11
+Subproject d2ab3c090833aec72eee7da1e3884032206b00e
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index f92db45..ea8461a 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,9 +5,10 @@
#![allow(dead_code)]
+use std::fs::File;
use std::io::ErrorKind;
use std::marker::PhantomData;
-use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::{Path, PathBuf};
use std::{mem, slice};
@@ -301,7 +302,7 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into the given scatter/gather vectors with optional attached
- /// file descriptors.
+ /// file.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -311,29 +312,37 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
- /// * - (number of bytes received, [received fds]) on success
+ /// * - (number of bytes received, [received files]) on success
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
- pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> {
let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
- let rfds = match fds {
+
+ let files = match fds {
0 => None,
n => {
- let mut fds = Vec::with_capacity(n);
- fds.extend_from_slice(&fd_array[0..n]);
- Some(fds)
+ let files = fd_array
+ .iter()
+ .take(n)
+ .map(|fd| {
+ // Safe because we have the ownership of `fd`.
+ unsafe { File::from_raw_fd(*fd) }
+ })
+ .collect();
+ Some(files)
}
};
- Ok((bytes, rfds))
+ Ok((bytes, files))
}
/// Reads all bytes from the socket into the given scatter/gather vectors with optional
- /// attached file descriptors. Will loop until all data has been transfered.
+ /// attached files. Will loop until all data has been transferred.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -343,6 +352,7 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
/// * - (number of bytes received, [received fds]) on success
@@ -351,7 +361,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_into_iovec_all(
&mut self,
iovs: &mut [iovec],
- ) -> Result<(usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Option<Vec<File>>)> {
let mut data_read = 0;
let mut data_total = 0;
let mut rfds = None;
@@ -392,46 +402,46 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into a new buffer with optional attached
- /// file descriptors. Received file descriptors are set close-on-exec.
+ /// files. Received file descriptors are set close-on-exec and converted to `File`.
///
/// # Return:
- /// * - (number of bytes received, buf, [received fds]) on success.
+ /// * - (number of bytes received, buf, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
pub fn recv_into_buf(
&mut self,
buf_size: usize,
- ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
let mut buf = vec![0u8; buf_size];
- let (bytes, rfds) = {
+ let (bytes, files) = {
let mut iovs = [iovec {
iov_base: buf.as_mut_ptr() as *mut c_void,
iov_len: buf_size,
}];
self.recv_into_iovec(&mut iovs)?
};
- Ok((bytes, buf, rfds))
+ Ok((bytes, buf, files))
}
- /// Receive a header-only message with optional attached file descriptors.
+ /// Receive a header-only message with optional attached files.
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, [received fds]) on success.
+ /// * - (message header, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
/// * - PartialMessage: received a partial message.
/// * - InvalidMessage: received a invalid message.
- pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [iovec {
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
}];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -439,7 +449,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, rfds))
+ Ok((hdr, files))
}
/// Receive a message with optional attached file descriptors.
@@ -447,7 +457,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, [received fds]) on success.
+ /// * - (message header, message body, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -455,7 +465,7 @@ impl<R: Req> Endpoint<R> {
/// * - InvalidMessage: received a invalid message.
pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
- ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -468,7 +478,7 @@ impl<R: Req> Endpoint<R> {
iov_len: mem::size_of::<T>(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes != total {
@@ -477,7 +487,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, rfds))
+ Ok((hdr, body, files))
}
/// Receive a message with header and optional content. Callers need to
@@ -488,7 +498,7 @@ impl<R: Req> Endpoint<R> {
/// silently.
///
/// # Return:
- /// * - (message header, message size, [received fds]) on success.
+ /// * - (message header, message size, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -497,7 +507,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_body_into_buf(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [
iovec {
@@ -509,7 +519,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -517,7 +527,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
}
/// Receive a message with optional payload and attached file descriptors.
@@ -525,7 +535,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, size of payload, [received fds]) on success.
+ /// * - (message header, message body, size of payload, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -535,7 +545,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -552,7 +562,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes < total {
@@ -561,17 +571,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, bytes - total, rfds))
- }
-
- /// Close all raw file descriptors.
- pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
- if let Some(fds) = rfds {
- for fd in fds {
- // safe because the rawfds are valid and we don't care about the result.
- let _ = unsafe { libc::close(fd) };
- }
- }
+ Ok((hdr, body, bytes - total, files))
}
}
@@ -604,7 +604,6 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
use super::*;
- use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::io::FromRawFd;
use tempfile::{tempfile, Builder, TempDir};
@@ -685,14 +684,14 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 1);
- let mut file = unsafe { File::from_raw_fd(fds[0]) };
+ assert_eq!(files.len(), 1);
+ let mut file = &files[0];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
@@ -710,23 +709,23 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data(header, body) with fds
@@ -742,10 +741,10 @@ mod tests {
let (bytes, buf4) = slave.recv_data(2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf4[..]);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should work:
// Sending side: data, data with fds
@@ -760,28 +759,28 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data1, data2 with fds
@@ -799,9 +798,9 @@ mod tests {
let (bytes, _) = slave.recv_data(5).unwrap();
assert_eq!(bytes, 5);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 3);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// If the target fd array is too small, extra file descriptors will get lost.
let len = master
@@ -812,12 +811,9 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
- assert!(rfds.is_some());
-
- Endpoint::<MasterReq>::close_rfds(rfds);
- Endpoint::<MasterReq>::close_rfds(None);
+ assert!(files.is_some());
}
#[test]
@@ -844,15 +840,15 @@ mod tests {
mem::size_of::<u64>(),
)
};
- let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
+ let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
assert_eq!(hdr1, hdr2);
assert_eq!(bytes, 8);
assert_eq!(features1, features2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
master.send_header(&hdr1, None).unwrap();
- let (hdr2, rfds) = slave.recv_header().unwrap();
+ let (hdr2, files) = slave.recv_header().unwrap();
assert_eq!(hdr1, hdr2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
}
}
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
index b2b83d2..cc9a9fb 100644
--- a/src/vhost_user/dummy_slave.rs
+++ b/src/vhost_user/dummy_slave.rs
@@ -1,7 +1,7 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-use std::os::unix::io::RawFd;
+use std::fs::File;
use super::message::*;
use super::*;
@@ -20,11 +20,12 @@ pub struct DummySlaveReqHandler {
pub queue_num: usize,
pub vring_num: [u32; MAX_QUEUE_NUM],
pub vring_base: [u32; MAX_QUEUE_NUM],
- pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM],
- pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM],
- pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub call_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub kick_fd: [Option<File>; MAX_QUEUE_NUM],
+ pub err_fd: [Option<File>; MAX_QUEUE_NUM],
pub vring_started: [bool; MAX_QUEUE_NUM],
pub vring_enabled: [bool; MAX_QUEUE_NUM],
+ pub inflight_file: Option<File>,
}
impl DummySlaveReqHandler {
@@ -83,7 +84,7 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(())
}
- fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
+ fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _files: Vec<File>) -> Result<()> {
Ok(())
}
@@ -134,14 +135,10 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
))
}
- fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.kick_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) };
- }
self.kick_fd[index as usize] = fd;
// Quotation from vhost-user spec:
@@ -155,26 +152,18 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(())
}
- fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.call_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) };
- }
self.call_fd[index as usize] = fd;
Ok(())
}
- fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()> {
if index as usize >= self.queue_num || index as usize > self.queue_num {
return Err(Error::InvalidParam);
}
- if self.err_fd[index as usize].is_some() {
- // Close file descriptor set by previous operations.
- let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) };
- }
self.err_fd[index as usize] = fd;
Ok(())
}
@@ -245,11 +234,32 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
Ok(())
}
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)> {
+ let file = tempfile::tempfile().unwrap();
+ self.inflight_file = Some(file.try_clone().unwrap());
+ Ok((
+ VhostUserInflight {
+ mmap_size: 0x1000,
+ mmap_offset: 0,
+ num_queues: inflight.num_queues,
+ queue_size: inflight.queue_size,
+ },
+ file,
+ ))
+ }
+
+ fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> Result<()> {
+ Ok(())
+ }
+
fn get_max_mem_slots(&mut self) -> Result<u64> {
Ok(MAX_MEM_SLOTS as u64)
}
- fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> {
+ fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: File) -> Result<()> {
Ok(())
}
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
index 16f0e02..9a65fbe 100644
--- a/src/vhost_user/master.rs
+++ b/src/vhost_user/master.rs
@@ -3,6 +3,7 @@
//! Traits and Struct for vhost-user master.
+use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -13,7 +14,7 @@ use sys_util::EventFd;
use super::connection::Endpoint;
use super::message::*;
-use super::{Error as VhostUserError, Result as VhostUserResult};
+use super::{take_single_file, Error as VhostUserError, Result as VhostUserResult};
use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
use crate::{Error, Result};
@@ -49,7 +50,16 @@ pub trait VhostUserMaster: VhostBackend {
fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
/// Setup slave communication channel.
- fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>;
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()>;
+
+ /// Retrieve shared buffer for inflight I/O tracking.
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)>;
+
+ /// Set shared buffer for inflight I/O tracking.
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()>;
/// Query the maximum amount of memory slots supported by the backend.
fn get_max_mem_slots(&mut self) -> Result<u64>;
@@ -84,6 +94,7 @@ impl Master {
protocol_features_ready: false,
max_queue_num,
error: None,
+ hdr_flags: VhostUserHeaderFlag::empty(),
})),
}
}
@@ -125,6 +136,12 @@ impl Master {
Ok(Self::new(endpoint, max_queue_num))
}
+
+ /// Set the header flags that should be applied to all following messages.
+ pub fn set_hdr_flags(&self, flags: VhostUserHeaderFlag) {
+ let mut node = self.node();
+ node.hdr_flags = flags;
+ }
}
impl VhostBackend for Master {
@@ -141,11 +158,9 @@ impl VhostBackend for Master {
fn set_features(&self, features: u64) -> Result<()> {
let mut node = self.node();
let val = VhostUserU64::new(features);
- let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
- // Don't wait for ACK here because the protocol feature negotiation process hasn't been
- // completed yet.
+ let hdr = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
node.acked_virtio_features = features & node.virtio_features;
- Ok(())
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
/// Set the current Master as an owner of the session.
@@ -153,18 +168,14 @@ impl VhostBackend for Master {
// We unwrap() the return value to assert that we are not expecting threads to ever fail
// while holding the lock.
let mut node = self.node();
- let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
- // Don't wait for ACK here because the protocol feature negotiation process hasn't been
- // completed yet.
- Ok(())
+ let hdr = node.send_request_header(MasterReq::SET_OWNER, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
fn reset_owner(&self) -> Result<()> {
let mut node = self.node();
- let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
- // Don't wait for ACK here because the protocol feature negotiation process hasn't been
- // completed yet.
- Ok(())
+ let hdr = node.send_request_header(MasterReq::RESET_OWNER, None)?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
/// Set the memory map regions on the slave so it can translate the vring
@@ -220,8 +231,8 @@ impl VhostBackend for Master {
fn set_log_fd(&self, fd: RawFd) -> Result<()> {
let mut node = self.node();
let fds = [fd];
- node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
- Ok(())
+ let hdr = node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
/// Set the size of the queue.
@@ -283,8 +294,8 @@ impl VhostBackend for Master {
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
- node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
- Ok(())
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
/// Set the event file descriptor for adding buffers to the vring.
@@ -296,8 +307,8 @@ impl VhostBackend for Master {
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
- node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
- Ok(())
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
/// Set the event file descriptor to signal when error occurs.
@@ -308,8 +319,8 @@ impl VhostBackend for Master {
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
- node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
- Ok(())
+ let hdr = node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
}
@@ -317,7 +328,7 @@ impl VhostUserMaster for Master {
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
- if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+ if node.virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
}
let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
@@ -334,16 +345,16 @@ impl VhostUserMaster for Master {
fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
- if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+ if node.virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
}
let val = VhostUserU64::new(features.bits());
- let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
+ let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
// completed yet.
node.acked_protocol_features = features.bits();
node.protocol_features_ready = true;
- Ok(())
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
fn get_queue_num(&mut self) -> Result<u64> {
@@ -401,7 +412,6 @@ impl VhostUserMaster for Master {
let (body_reply, buf_reply, rfds) =
node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
if rfds.is_some() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return error_code(VhostUserError::InvalidMessage);
} else if body_reply.size == 0 {
return error_code(VhostUserError::SlaveInternalError);
@@ -434,15 +444,47 @@ impl VhostUserMaster for Master {
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
- fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
+ fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> {
let mut node = self.node();
if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
}
+ let fds = [fd.as_raw_fd()];
+ let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
+ }
- let fds = [fd];
- node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
- Ok(())
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?;
+ let (inflight, files) = node.recv_reply_with_files::<VhostUserInflight>(&hdr)?;
+
+ match take_single_file(files) {
+ Some(file) => Ok((inflight, file)),
+ None => error_code(VhostUserError::IncorrectFds),
+ }
+ }
+
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
+ if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
+ return error_code(VhostUserError::InvalidOperation);
+ }
+
+ if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0
+ {
+ return error_code(VhostUserError::InvalidParam);
+ }
+
+ let hdr = node.send_request_with_body(MasterReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?;
+ node.wait_for_ack(&hdr).map_err(|e| e.into())
}
fn get_max_mem_slots(&mut self) -> Result<u64> {
@@ -546,6 +588,8 @@ struct MasterInternal {
max_queue_num: u64,
// Internal flag to mark failure state.
error: Option<i32>,
+ // List of header flags.
+ hdr_flags: VhostUserHeaderFlag,
}
impl MasterInternal {
@@ -555,7 +599,7 @@ impl MasterInternal {
fds: Option<&[RawFd]>,
) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
self.check_state()?;
- let hdr = Self::new_request_header(code, 0);
+ let hdr = self.new_request_header(code, 0);
self.main_sock.send_header(&hdr, fds)?;
Ok(hdr)
}
@@ -571,7 +615,7 @@ impl MasterInternal {
}
self.check_state()?;
- let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32);
+ let hdr = self.new_request_header(code, mem::size_of::<T>() as u32);
self.main_sock.send_message(&hdr, msg, fds)?;
Ok(hdr)
}
@@ -594,7 +638,7 @@ impl MasterInternal {
}
self.check_state()?;
- let hdr = Self::new_request_header(code, len as u32);
+ let hdr = self.new_request_header(code, len as u32);
self.main_sock
.send_message_with_payload(&hdr, msg, payload, fds)?;
Ok(hdr)
@@ -615,7 +659,7 @@ impl MasterInternal {
// This flag is set when there is no file descriptor in the ancillary data. This signals
// that polling will be used instead of waiting for the call.
let msg = VhostUserU64::new(queue_index as u64);
- let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
+ let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
Ok(hdr)
}
@@ -631,16 +675,31 @@ impl MasterInternal {
let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
Ok(body)
}
+ fn recv_reply_with_files<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ ) -> VhostUserResult<(T, Option<Vec<File>>)> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ return Err(VhostUserError::InvalidParam);
+ }
+ self.check_state()?;
+
+ let (reply, body, files) = self.main_sock.recv_body::<T>()?;
+ if !reply.is_reply_for(&hdr) || files.is_none() || !body.is_valid() {
+ return Err(VhostUserError::InvalidMessage);
+ }
+ Ok((body, files))
+ }
+
fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
- ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
+ ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<File>>)> {
if mem::size_of::<T>() > MAX_MSG_SIZE
|| hdr.get_size() as usize <= mem::size_of::<T>()
|| hdr.get_size() as usize > MAX_MSG_SIZE
@@ -651,18 +710,17 @@ impl MasterInternal {
self.check_state()?;
let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
- let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
+ let (reply, body, bytes, files) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
if !reply.is_reply_for(hdr)
|| reply.get_size() as usize != mem::size_of::<T>() + bytes
- || rfds.is_some()
+ || files.is_some()
|| !body.is_valid()
+ || bytes != buf.len()
{
- Endpoint::<MasterReq>::close_rfds(rfds);
- return Err(VhostUserError::InvalidMessage);
- } else if bytes != buf.len() {
return Err(VhostUserError::InvalidMessage);
}
- Ok((body, buf, rfds))
+
+ Ok((body, buf, files))
}
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
@@ -675,7 +733,6 @@ impl MasterInternal {
let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
if body.value != 0 {
@@ -698,9 +755,8 @@ impl MasterInternal {
}
#[inline]
- fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
- // TODO: handle NEED_REPLY flag
- VhostUserMsgHeader::new(request, 0x1, size)
+ fn new_request_header(&self, request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
+ VhostUserMsgHeader::new(request, self.hdr_flags.bits() | 0x1, size)
}
}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
index 8cba188..0ecda4e 100644
--- a/src/vhost_user/master_req_handler.rs
+++ b/src/vhost_user/master_req_handler.rs
@@ -1,6 +1,7 @@
// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
+use std::fs::File;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -33,9 +34,7 @@ pub trait VhostUserMasterReqHandler {
}
/// Handle virtio-fs map file requests.
- fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -50,14 +49,12 @@ pub trait VhostUserMasterReqHandler {
}
/// Handle virtio-fs file IO requests.
- fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
- // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd);
}
/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
@@ -70,9 +67,7 @@ pub trait VhostUserMasterReqHandlerMut {
}
/// Handle virtio-fs map file requests.
- fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -87,9 +82,7 @@ pub trait VhostUserMasterReqHandlerMut {
}
/// Handle virtio-fs file IO requests.
- fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
@@ -102,7 +95,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
self.lock().unwrap().handle_config_change()
}
- fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.lock().unwrap().fs_slave_map(fs, fd)
}
@@ -114,7 +107,7 @@ impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
self.lock().unwrap().fs_slave_sync(fs)
}
- fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
self.lock().unwrap().fs_slave_io(fs, fd)
}
}
@@ -206,8 +199,8 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
// . recv optional message body and payload according size field in
// message header
// . validate message body and optional payload
- let (hdr, rfds) = self.sub_sock.recv_header()?;
- let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (hdr, files) = self.sub_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
@@ -231,9 +224,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
SlaveReq::FS_MAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
- // check_attached_rfds() has validated rfds
+ // check_attached_files() has validated files
self.backend
- .fs_slave_map(&msg, rfds.unwrap()[0])
+ .fs_slave_map(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_UNMAP => {
@@ -250,9 +243,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
SlaveReq::FS_IO => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
- // check_attached_rfds() has validated rfds
+ // check_attached_files() has validated files
self.backend
- .fs_slave_io(&msg, rfds.unwrap()[0])
+ .fs_slave_io(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
_ => Err(Error::InvalidMessage),
@@ -286,34 +279,21 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
Ok(())
}
- fn check_attached_rfds(
+ fn check_attached_files(
&self,
hdr: &VhostUserMsgHeader<SlaveReq>,
- rfds: Option<Vec<RawFd>>,
- ) -> Result<Option<Vec<RawFd>>> {
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
match hdr.get_code() {
SlaveReq::FS_MAP | SlaveReq::FS_IO => {
- // Expect an fd set with a single fd.
- match rfds {
- None => Err(Error::InvalidMessage),
- Some(fds) => {
- if fds.len() != 1 {
- Endpoint::<SlaveReq>::close_rfds(Some(fds));
- Err(Error::InvalidMessage)
- } else {
- Ok(Some(fds))
- }
- }
- }
- }
- _ => {
- if rfds.is_some() {
- Endpoint::<SlaveReq>::close_rfds(rfds);
- Err(Error::InvalidMessage)
- } else {
- Ok(rfds)
+ // Expect a single file is passed.
+ match files {
+ Some(files) if files.len() == 1 => Ok(()),
+ _ => Err(Error::InvalidMessage),
}
}
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
}
}
@@ -390,9 +370,11 @@ mod tests {
impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
/// Handle virtio-fs map file requests from the slave.
- fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- // Safe because we have just received the rawfd from kernel.
- unsafe { libc::close(fd) };
+ fn fs_slave_map(
+ &mut self,
+ _fs: &VhostUserFSSlaveMsg,
+ _fd: &dyn AsRawFd,
+ ) -> HandlerResult<u64> {
Ok(0)
}
@@ -437,7 +419,7 @@ mod tests {
});
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
.unwrap();
// When REPLY_ACK has not been negotiated, the master has no way to detect failure from
// slave side.
@@ -468,7 +450,7 @@ mod tests {
fs_cache.set_reply_ack_flag(true);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
.unwrap();
fs_cache
.fs_slave_unmap(&VhostUserFSSlaveMsg::default())
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
index 32b2f8c..fc33e1b 100644
--- a/src/vhost_user/message.rs
+++ b/src/vhost_user/message.rs
@@ -7,6 +7,7 @@
#![allow(dead_code)]
#![allow(non_camel_case_types)]
+#![allow(clippy::upper_case_acronyms)]
use std::fmt::Debug;
use std::marker::PhantomData;
@@ -140,9 +141,9 @@ pub enum MasterReq {
MAX_CMD = 41,
}
-impl Into<u32> for MasterReq {
- fn into(self) -> u32 {
- self as u32
+impl From<MasterReq> for u32 {
+ fn from(req: MasterReq) -> u32 {
+ req as u32
}
}
@@ -180,9 +181,9 @@ pub enum SlaveReq {
MAX_CMD = 10,
}
-impl Into<u32> for SlaveReq {
- fn into(self) -> u32 {
- self as u32
+impl From<SlaveReq> for u32 {
+ fn from(req: SlaveReq) -> u32 {
+ req as u32
}
}
@@ -222,9 +223,8 @@ bitflags! {
/// Common message header for vhost-user requests and replies.
/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
/// machine native byte order.
-#[allow(safe_packed_borrows)]
#[repr(packed)]
-#[derive(Debug, Clone, Copy, PartialEq)]
+#[derive(Copy)]
pub(super) struct VhostUserMsgHeader<R: Req> {
request: u32,
flags: u32,
@@ -232,6 +232,28 @@ pub(super) struct VhostUserMsgHeader<R: Req> {
_r: PhantomData<R>,
}
+impl<R: Req> Debug for VhostUserMsgHeader<R> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("Point")
+ .field("request", &{ self.request })
+ .field("flags", &{ self.flags })
+ .field("size", &{ self.size })
+ .finish()
+ }
+}
+
+impl<R: Req> Clone for VhostUserMsgHeader<R> {
+ fn clone(&self) -> VhostUserMsgHeader<R> {
+ *self
+ }
+}
+
+impl<R: Req> PartialEq for VhostUserMsgHeader<R> {
+ fn eq(&self, other: &Self) -> bool {
+ self.request == other.request && self.flags == other.flags && self.size == other.size
+ }
+}
+
impl<R: Req> VhostUserMsgHeader<R> {
/// Create a new instance of `VhostUserMsgHeader`.
pub fn new(request: R, flags: u32, size: u32) -> Self {
@@ -248,7 +270,7 @@ impl<R: Req> VhostUserMsgHeader<R> {
/// Get message type.
pub fn get_code(&self) -> R {
// It's safe because R is marked as repr(u32).
- unsafe { std::mem::transmute_copy::<u32, R>(&self.request) }
+ unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) }
}
/// Set message type.
@@ -673,6 +695,42 @@ impl VhostUserMsgValidator for VhostUserConfig {
/// Payload for the VhostUserConfig message.
pub type VhostUserConfigPayload = Vec<u8>;
+/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG
+/// requests.
+#[repr(C)]
+#[derive(Default, Clone)]
+pub struct VhostUserInflight {
+ /// Size of the area to track inflight I/O.
+ pub mmap_size: u64,
+ /// Offset of this area from the start of the supplied file descriptor.
+ pub mmap_offset: u64,
+ /// Number of virtqueues.
+ pub num_queues: u16,
+ /// Size of virtqueues.
+ pub queue_size: u16,
+}
+
+impl VhostUserInflight {
+ /// Create a new instance.
+ pub fn new(mmap_size: u64, mmap_offset: u64, num_queues: u16, queue_size: u16) -> Self {
+ VhostUserInflight {
+ mmap_size,
+ mmap_offset,
+ num_queues,
+ queue_size,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserInflight {
+ fn is_valid(&self) -> bool {
+ if self.num_queues == 0 || self.queue_size == 0 {
+ return false;
+ }
+ true
+ }
+}
+
/*
* TODO: support dirty log, live migration and IOTLB operations.
#[repr(packed)]
@@ -744,6 +802,137 @@ impl VhostUserMsgValidator for VhostUserFSSlaveMsg {
}
}
+/// Inflight I/O descriptor state for split virtqueues
+#[repr(packed)]
+#[derive(Clone, Copy, Default)]
+pub struct DescStateSplit {
+ /// Indicate whether this descriptor (only head) is inflight or not.
+ pub inflight: u8,
+ /// Padding
+ padding: [u8; 5],
+ /// List of last batch of used descriptors, only when batching is used for submitting
+ pub next: u16,
+ /// Preserve order of fetching available descriptors, only for head descriptor
+ pub counter: u64,
+}
+
+impl DescStateSplit {
+ /// New instance of DescStateSplit struct
+ pub fn new() -> Self {
+ Self::default()
+ }
+}
+
+/// Inflight I/O queue region for split virtqueues
+#[repr(packed)]
+pub struct QueueRegionSplit {
+ /// Features flags of this region
+ pub features: u64,
+ /// Version of this region
+ pub version: u16,
+ /// Number of DescStateSplit entries
+ pub desc_num: u16,
+ /// List to track last batch of used descriptors
+ pub last_batch_head: u16,
+ /// Idx value of used ring
+ pub used_idx: u16,
+ /// Pointer to an array of DescStateSplit entries
+ pub desc: u64,
+}
+
+impl QueueRegionSplit {
+ /// New instance of QueueRegionSplit struct
+ pub fn new(features: u64, queue_size: u16) -> Self {
+ QueueRegionSplit {
+ features,
+ version: 1,
+ desc_num: queue_size,
+ last_batch_head: 0,
+ used_idx: 0,
+ desc: 0,
+ }
+ }
+}
+
+/// Inflight I/O descriptor state for packed virtqueues
+#[repr(packed)]
+#[derive(Clone, Copy, Default)]
+pub struct DescStatePacked {
+ /// Indicate whether this descriptor (only head) is inflight or not.
+ pub inflight: u8,
+ /// Padding
+ padding: u8,
+ /// Link to next free entry
+ pub next: u16,
+ /// Link to last entry of descriptor list, only for head
+ pub last: u16,
+ /// Length of descriptor list, only for head
+ pub num: u16,
+ /// Preserve order of fetching avail descriptors, only for head
+ pub counter: u64,
+ /// Buffer ID
+ pub id: u16,
+ /// Descriptor flags
+ pub flags: u16,
+ /// Buffer length
+ pub len: u32,
+ /// Buffer address
+ pub addr: u64,
+}
+
+impl DescStatePacked {
+ /// New instance of DescStatePacked struct
+ pub fn new() -> Self {
+ Self::default()
+ }
+}
+
+/// Inflight I/O queue region for packed virtqueues
+#[repr(packed)]
+pub struct QueueRegionPacked {
+ /// Features flags of this region
+ pub features: u64,
+ /// version of this region
+ pub version: u16,
+ /// size of descriptor state array
+ pub desc_num: u16,
+ /// head of free DescStatePacked entry list
+ pub free_head: u16,
+ /// old head of free DescStatePacked entry list
+ pub old_free_head: u16,
+ /// used idx of descriptor ring
+ pub used_idx: u16,
+ /// old used idx of descriptor ring
+ pub old_used_idx: u16,
+ /// device ring wrap counter
+ pub used_wrap_counter: u8,
+ /// old device ring wrap counter
+ pub old_used_wrap_counter: u8,
+ /// Padding
+ padding: [u8; 7],
+ /// Pointer to array tracking state of each descriptor from descriptor ring
+ pub desc: u64,
+}
+
+impl QueueRegionPacked {
+ /// New instance of QueueRegionPacked struct
+ pub fn new(features: u64, queue_size: u16) -> Self {
+ QueueRegionPacked {
+ features,
+ version: 1,
+ desc_num: queue_size,
+ free_head: 0,
+ old_free_head: 0,
+ used_idx: 0,
+ old_used_idx: 0,
+ used_wrap_counter: 0,
+ old_used_wrap_counter: 0,
+ padding: [0; 7],
+ desc: 0,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -824,7 +1013,10 @@ mod tests {
hdr.set_version(0x1);
assert!(hdr.is_valid());
+ // Test Debug, Clone, PartiaEq trait
assert_eq!(hdr, hdr.clone());
+ assert_eq!(hdr.clone().get_code(), hdr.get_code());
+ assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr));
}
#[test]
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
index 9ef6453..5d8ce31 100644
--- a/src/vhost_user/mod.rs
+++ b/src/vhost_user/mod.rs
@@ -18,6 +18,7 @@
//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
//! equivalent ioctl to the kernel implementation.
+use std::fs::File;
use std::io::Error as IOError;
pub mod message;
@@ -175,6 +176,16 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Result of request handler.
pub type HandlerResult<T> = std::result::Result<T, IOError>;
+/// Utility function to take the first element from option of a vector of files.
+/// Returns `None` if the vector contains no file or more than one file.
+pub(crate) fn take_single_file(files: Option<Vec<File>>) -> Option<File> {
+ let mut files = files?;
+ if files.len() != 1 {
+ return None;
+ }
+ Some(files.swap_remove(0))
+}
+
#[cfg(all(test, feature = "vhost-user-slave"))]
mod dummy_slave;
@@ -308,6 +319,11 @@ mod tests {
VhostUserProtocolFeatures::all().bits()
);
+ // get_inflight_fd()
+ slave.handle_request().unwrap();
+ // set_inflight_fd()
+ slave.handle_request().unwrap();
+
// get_queue_num()
slave.handle_request().unwrap();
@@ -360,6 +376,19 @@ mod tests {
assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
master.set_protocol_features(features).unwrap();
+ // Retrieve inflight I/O tracking information
+ let (inflight_info, inflight_file) = master
+ .get_inflight_fd(&VhostUserInflight {
+ num_queues: 2,
+ queue_size: 256,
+ ..Default::default()
+ })
+ .unwrap();
+ // Set the buffer back to the backend
+ master
+ .set_inflight_fd(&inflight_info, inflight_file.as_raw_fd())
+ .unwrap();
+
let num = master.get_queue_num().unwrap();
assert_eq!(num, 2);
@@ -384,7 +413,7 @@ mod tests {
assert_eq!(offset, 0x100);
assert_eq!(reply_payload[0], 0xa5);
- master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_slave_request_fd(&eventfd).unwrap();
master.set_vring_enable(0, true).unwrap();
// unimplemented yet
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
index a9c4ed2..ee5fd9b 100644
--- a/src/vhost_user/slave_fs_cache.rs
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -3,7 +3,7 @@
use std::io;
use std::mem;
-use std::os::unix::io::RawFd;
+use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex, MutexGuard};
@@ -55,7 +55,6 @@ impl SlaveFsCacheReqInternal {
let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
- Endpoint::<SlaveReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
if body.value != 0 {
@@ -129,8 +128,8 @@ impl SlaveFsCacheReq {
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
/// Forward vhost-user-fs map file requests to the slave.
- fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
- self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
+ self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd.as_raw_fd()]))
}
/// Forward vhost-user-fs unmap file requests to the master.
@@ -158,31 +157,21 @@ mod tests {
#[test]
fn test_slave_fs_cache_send_failure() {
let (p1, p2) = UnixStream::pair().unwrap();
- let fd = p2.as_raw_fd();
let fs_cache = SlaveFsCacheReq::from_stream(p1);
fs_cache.set_failed(libc::ECONNRESET);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &p2)
.unwrap_err();
fs_cache
.fs_slave_unmap(&VhostUserFSSlaveMsg::default())
.unwrap_err();
fs_cache.node().error = None;
-
- drop(p2);
- fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
- .unwrap_err();
- fs_cache
- .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
- .unwrap_err();
}
#[test]
fn test_slave_fs_cache_recv_negative() {
let (p1, p2) = UnixStream::pair().unwrap();
- let fd = p2.as_raw_fd();
let fs_cache = SlaveFsCacheReq::from_stream(p1);
let mut master = Endpoint::<SlaveReq>::from_stream(p2);
@@ -194,33 +183,35 @@ mod tests {
);
let body = VhostUserU64::new(0);
- master.send_message(&hdr, &body, Some(&[fd])).unwrap();
+ master
+ .send_message(&hdr, &body, Some(&[master.as_raw_fd()]))
+ .unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap();
fs_cache.set_reply_ack_flag(true);
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
hdr.set_code(SlaveReq::FS_UNMAP);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
hdr.set_code(SlaveReq::FS_MAP);
let body = VhostUserU64::new(1);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap_err();
let body = VhostUserU64::new(0);
master.send_message(&hdr, &body, None).unwrap();
fs_cache
- .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), &master)
.unwrap();
}
}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
index 9d7ea10..402030c 100644
--- a/src/vhost_user/slave_req_handler.rs
+++ b/src/vhost_user/slave_req_handler.rs
@@ -1,16 +1,16 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
+use std::fs::File;
use std::mem;
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;
use std::sync::{Arc, Mutex};
use super::connection::Endpoint;
use super::message::*;
-use super::slave_fs_cache::SlaveFsCacheReq;
-use super::{Error, Result};
+use super::{take_single_file, Error, Result};
/// Services provided to the master by the slave with interior mutability.
///
@@ -38,7 +38,7 @@ pub trait VhostUserSlaveReqHandler {
fn reset_owner(&self) -> Result<()>;
fn get_features(&self) -> Result<u64>;
fn set_features(&self, features: u64) -> Result<()>;
- fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
fn set_vring_addr(
&self,
@@ -51,9 +51,9 @@ pub trait VhostUserSlaveReqHandler {
) -> Result<()>;
fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
- fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
fn set_protocol_features(&self, features: u64) -> Result<()>;
@@ -61,9 +61,11 @@ pub trait VhostUserSlaveReqHandler {
fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
- fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
+ fn set_slave_req_fd(&self, _vu_req: File) {}
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
+ fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
fn get_max_mem_slots(&self) -> Result<u64>;
- fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
@@ -76,7 +78,7 @@ pub trait VhostUserSlaveReqHandlerMut {
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
fn set_features(&mut self, features: u64) -> Result<()>;
- fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
fn set_vring_addr(
&mut self,
@@ -89,9 +91,9 @@ pub trait VhostUserSlaveReqHandlerMut {
) -> Result<()>;
fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
- fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
- fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
+ fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
fn set_protocol_features(&mut self, features: u64) -> Result<()>;
@@ -104,9 +106,14 @@ pub trait VhostUserSlaveReqHandlerMut {
flags: VhostUserConfigFlags,
) -> Result<Vec<u8>>;
fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
- fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
+ fn set_slave_req_fd(&mut self, _vu_req: File) {}
+ fn get_inflight_fd(
+ &mut self,
+ inflight: &VhostUserInflight,
+ ) -> Result<(VhostUserInflight, File)>;
+ fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
fn get_max_mem_slots(&mut self) -> Result<u64>;
- fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
+ fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
@@ -127,8 +134,8 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().set_features(features)
}
- fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
- self.lock().unwrap().set_mem_table(ctx, fds)
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, files)
}
fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
@@ -157,15 +164,15 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().get_vring_base(index)
}
- fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_kick(index, fd)
}
- fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_call(index, fd)
}
- fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
self.lock().unwrap().set_vring_err(index, fd)
}
@@ -193,15 +200,23 @@ impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
self.lock().unwrap().set_config(offset, buf, flags)
}
- fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
+ fn set_slave_req_fd(&self, vu_req: File) {
self.lock().unwrap().set_slave_req_fd(vu_req)
}
+ fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
+ self.lock().unwrap().get_inflight_fd(inflight)
+ }
+
+ fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
+ self.lock().unwrap().set_inflight_fd(inflight, file)
+ }
+
fn get_max_mem_slots(&self) -> Result<u64> {
self.lock().unwrap().get_max_mem_slots()
}
- fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> {
+ fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
self.lock().unwrap().add_mem_region(region, fd)
}
@@ -253,6 +268,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
+ /// Create a vhost-user slave endpoint from a connected socket.
+ pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self {
+ Self::new(Endpoint::from_stream(socket), backend)
+ }
+
/// Create a new vhost-user slave endpoint.
///
/// # Arguments
@@ -286,8 +306,9 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
// . recv optional message body and payload according size field in
// message header
// . validate message body and optional payload
- let (hdr, rfds) = self.main_sock.recv_header()?;
- let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (hdr, files) = self.main_sock.recv_header()?;
+ self.check_attached_files(&hdr, &files)?;
+
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
@@ -302,11 +323,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
match hdr.get_code() {
MasterReq::SET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.set_owner()?;
+ let res = self.backend.set_owner();
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::RESET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.reset_owner()?;
+ let res = self.backend.reset_owner();
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
@@ -318,12 +341,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend.set_features(msg.value)?;
+ let res = self.backend.set_features(msg.value);
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_MEM_TABLE => {
- let res = self.set_mem_table(&hdr, size, &buf, rfds);
+ let res = self.set_mem_table(&hdr, size, &buf, files);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_NUM => {
@@ -359,20 +383,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_call(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_call(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_kick(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_kick(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
- let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.set_vring_err(index, rfds);
+ let (index, file) = self.handle_vring_fd_request(&buf, files)?;
+ let res = self.backend.set_vring_err(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
@@ -385,9 +409,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_PROTOCOL_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend.set_protocol_features(msg.value)?;
+ let res = self.backend.set_protocol_features(msg.value);
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_QUEUE_NUM => {
if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
@@ -426,14 +451,40 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
- self.set_config(&hdr, size, &buf)?;
+ let res = self.set_config(size, &buf);
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_SLAVE_REQ_FD => {
if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
- self.set_slave_req_fd(&hdr, rfds)?;
+ let res = self.set_slave_req_fd(files);
+ self.send_ack_message(&hdr, res)?;
+ }
+ MasterReq::GET_INFLIGHT_FD => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
+ == 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+
+ let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
+ let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
+ let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
+ self.main_sock
+ .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
+ }
+ MasterReq::SET_INFLIGHT_FD => {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits()
+ == 0
+ {
+ return Err(Error::InvalidOperation);
+ }
+ let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
+ let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
+ let res = self.backend.set_inflight_fd(&msg, file);
+ self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_MAX_MEM_SLOTS => {
if self.acked_protocol_features
@@ -454,18 +505,13 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
{
return Err(Error::InvalidOperation);
}
- let fd = if let Some(fds) = &rfds {
- if fds.len() != 1 {
- return Err(Error::InvalidParam);
- }
- fds[0]
- } else {
+ let mut files = files.ok_or(Error::InvalidParam)?;
+ if files.len() != 1 {
return Err(Error::InvalidParam);
- };
-
+ }
let msg =
self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
- let res = self.backend.add_mem_region(&msg, fd);
+ let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
self.send_ack_message(&hdr, res)?;
}
MasterReq::REM_MEM_REG => {
@@ -493,37 +539,28 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
buf: &[u8],
- rfds: Option<Vec<RawFd>>,
+ files: Option<Vec<File>>,
) -> Result<()> {
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
// check message size is consistent
let hdrsize = mem::size_of::<VhostUserMemory>();
if size < hdrsize {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
if !msg.is_valid() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
- Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
// validate number of fds matching number of memory regions
- let fds = match rfds {
- None => return Err(Error::InvalidMessage),
- Some(fds) => {
- if fds.len() != msg.num_regions as usize {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
- return Err(Error::InvalidMessage);
- }
- fds
- }
- };
+ let files = files.ok_or(Error::InvalidMessage)?;
+ if files.len() != msg.num_regions as usize {
+ return Err(Error::InvalidMessage);
+ }
// Validate memory regions
let regions = unsafe {
@@ -534,12 +571,11 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
};
for region in regions.iter() {
if !region.is_valid() {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
return Err(Error::InvalidMessage);
}
}
- self.backend.set_mem_table(&regions, &fds)
+ self.backend.set_mem_table(&regions, files)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
@@ -580,12 +616,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Ok(())
}
- fn set_config(
- &mut self,
- hdr: &VhostUserMsgHeader<MasterReq>,
- size: usize,
- buf: &[u8],
- ) -> Result<()> {
+ fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
return Err(Error::InvalidMessage);
}
@@ -602,35 +633,20 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
None => return Err(Error::InvalidMessage),
}
- let res = self.backend.set_config(msg.offset, buf, flags);
- self.send_ack_message(&hdr, res)?;
- Ok(())
+ self.backend.set_config(msg.offset, buf, flags)
}
- fn set_slave_req_fd(
- &mut self,
- hdr: &VhostUserMsgHeader<MasterReq>,
- rfds: Option<Vec<RawFd>>,
- ) -> Result<()> {
- if let Some(fds) = rfds {
- if fds.len() == 1 {
- let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
- let vu_req = SlaveFsCacheReq::from_stream(sock);
- self.backend.set_slave_req_fd(vu_req);
- self.send_ack_message(&hdr, Ok(()))
- } else {
- Err(Error::InvalidMessage)
- }
- } else {
- Err(Error::InvalidMessage)
- }
+ fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
+ let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
+ self.backend.set_slave_req_fd(file);
+ Ok(())
}
fn handle_vring_fd_request(
&mut self,
buf: &[u8],
- rfds: Option<Vec<RawFd>>,
- ) -> Result<(u8, Option<RawFd>)> {
+ files: Option<Vec<File>>,
+ ) -> Result<(u8, Option<File>)> {
if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
return Err(Error::InvalidMessage);
}
@@ -640,28 +656,19 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
// Bits (0-7) of the payload contain the vring index. Bit 8 is the
- // invalid FD flag. This flag is set when there is no file descriptor
+ // invalid FD flag. This bit is set when there is no file descriptor
// in the ancillary data. This signals that polling will be used
// instead of waiting for the call.
- let nofd = (msg.value & 0x100u64) == 0x100u64;
-
- let mut rfd = None;
- match rfds {
- Some(fds) => {
- if !nofd && fds.len() == 1 {
- rfd = Some(fds[0]);
- } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) {
- Endpoint::<MasterReq>::close_rfds(Some(fds));
- return Err(Error::InvalidMessage);
- }
- }
- None => {
- if !nofd {
- return Err(Error::InvalidMessage);
- }
- }
+ // If Bit 8 is unset, the data must contain a file descriptor.
+ let has_fd = (msg.value & 0x100u64) == 0;
+
+ let file = take_single_file(files);
+
+ if has_fd && file.is_none() || !has_fd && file.is_some() {
+ return Err(Error::InvalidMessage);
}
- Ok((msg.value as u8, rfd))
+
+ Ok((msg.value as u8, file))
}
fn check_state(&self) -> Result<()> {
@@ -687,29 +694,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Ok(())
}
- fn check_attached_rfds(
+ fn check_attached_files(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
- rfds: Option<Vec<RawFd>>,
- ) -> Result<Option<Vec<RawFd>>> {
+ files: &Option<Vec<File>>,
+ ) -> Result<()> {
match hdr.get_code() {
- MasterReq::SET_MEM_TABLE => Ok(rfds),
- MasterReq::SET_VRING_CALL => Ok(rfds),
- MasterReq::SET_VRING_KICK => Ok(rfds),
- MasterReq::SET_VRING_ERR => Ok(rfds),
- MasterReq::SET_LOG_BASE => Ok(rfds),
- MasterReq::SET_LOG_FD => Ok(rfds),
- MasterReq::SET_SLAVE_REQ_FD => Ok(rfds),
- MasterReq::SET_INFLIGHT_FD => Ok(rfds),
- MasterReq::ADD_MEM_REG => Ok(rfds),
- _ => {
- if rfds.is_some() {
- Endpoint::<MasterReq>::close_rfds(rfds);
- Err(Error::InvalidMessage)
- } else {
- Ok(rfds)
- }
- }
+ MasterReq::SET_MEM_TABLE
+ | MasterReq::SET_VRING_CALL
+ | MasterReq::SET_VRING_KICK
+ | MasterReq::SET_VRING_ERR
+ | MasterReq::SET_LOG_BASE
+ | MasterReq::SET_LOG_FD
+ | MasterReq::SET_SLAVE_REQ_FD
+ | MasterReq::SET_INFLIGHT_FD
+ | MasterReq::ADD_MEM_REG => Ok(()),
+ _ if files.is_some() => Err(Error::InvalidMessage),
+ _ => Ok(()),
}
}
@@ -731,7 +732,6 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
let pflag = VhostUserProtocolFeatures::REPLY_ACK;
if (self.virtio_features & vflag) != 0
- && (self.acked_virtio_features & vflag) != 0
&& self.protocol_features.contains(pflag)
&& (self.acked_protocol_features & pflag.bits()) != 0
{
@@ -774,7 +774,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
let msg = VhostUserU64::new(val);
self.main_sock.send_message(&hdr, &msg, None)?;
}
- Ok(())
+ res
}
fn send_reply_message<T>(