summaryrefslogtreecommitdiff
path: root/src/vhost_user/connection.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/vhost_user/connection.rs')
-rw-r--r--src/vhost_user/connection.rs152
1 files changed, 74 insertions, 78 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index f92db45..ea8461a 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,9 +5,10 @@
#![allow(dead_code)]
+use std::fs::File;
use std::io::ErrorKind;
use std::marker::PhantomData;
-use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::path::{Path, PathBuf};
use std::{mem, slice};
@@ -301,7 +302,7 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into the given scatter/gather vectors with optional attached
- /// file descriptors.
+ /// file.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -311,29 +312,37 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
- /// * - (number of bytes received, [received fds]) on success
+ /// * - (number of bytes received, [received files]) on success
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
- pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> {
let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
- let rfds = match fds {
+
+ let files = match fds {
0 => None,
n => {
- let mut fds = Vec::with_capacity(n);
- fds.extend_from_slice(&fd_array[0..n]);
- Some(fds)
+ let files = fd_array
+ .iter()
+ .take(n)
+ .map(|fd| {
+ // Safe because we have the ownership of `fd`.
+ unsafe { File::from_raw_fd(*fd) }
+ })
+ .collect();
+ Some(files)
}
};
- Ok((bytes, rfds))
+ Ok((bytes, files))
}
/// Reads all bytes from the socket into the given scatter/gather vectors with optional
- /// attached file descriptors. Will loop until all data has been transfered.
+ /// attached files. Will loop until all data has been transferred.
///
/// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
/// tricky to pass file descriptors through such a communication channel. Let's assume that a
@@ -343,6 +352,7 @@ impl<R: Req> Endpoint<R> {
/// 2) message(packet) boundaries must be respected on the receive side.
/// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
/// attached file descriptors will get lost.
+ /// Note that this function wraps received file descriptors as `File`.
///
/// # Return:
/// * - (number of bytes received, [received fds]) on success
@@ -351,7 +361,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_into_iovec_all(
&mut self,
iovs: &mut [iovec],
- ) -> Result<(usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Option<Vec<File>>)> {
let mut data_read = 0;
let mut data_total = 0;
let mut rfds = None;
@@ -392,46 +402,46 @@ impl<R: Req> Endpoint<R> {
}
/// Reads bytes from the socket into a new buffer with optional attached
- /// file descriptors. Received file descriptors are set close-on-exec.
+ /// files. Received file descriptors are set close-on-exec and converted to `File`.
///
/// # Return:
- /// * - (number of bytes received, buf, [received fds]) on success.
+ /// * - (number of bytes received, buf, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
pub fn recv_into_buf(
&mut self,
buf_size: usize,
- ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
+ ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> {
let mut buf = vec![0u8; buf_size];
- let (bytes, rfds) = {
+ let (bytes, files) = {
let mut iovs = [iovec {
iov_base: buf.as_mut_ptr() as *mut c_void,
iov_len: buf_size,
}];
self.recv_into_iovec(&mut iovs)?
};
- Ok((bytes, buf, rfds))
+ Ok((bytes, buf, files))
}
- /// Receive a header-only message with optional attached file descriptors.
+ /// Receive a header-only message with optional attached files.
/// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, [received fds]) on success.
+ /// * - (message header, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
/// * - PartialMessage: received a partial message.
/// * - InvalidMessage: received a invalid message.
- pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [iovec {
iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
}];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -439,7 +449,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, rfds))
+ Ok((hdr, files))
}
/// Receive a message with optional attached file descriptors.
@@ -447,7 +457,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, [received fds]) on success.
+ /// * - (message header, message body, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -455,7 +465,7 @@ impl<R: Req> Endpoint<R> {
/// * - InvalidMessage: received a invalid message.
pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
- ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -468,7 +478,7 @@ impl<R: Req> Endpoint<R> {
iov_len: mem::size_of::<T>(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes != total {
@@ -477,7 +487,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, rfds))
+ Ok((hdr, body, files))
}
/// Receive a message with header and optional content. Callers need to
@@ -488,7 +498,7 @@ impl<R: Req> Endpoint<R> {
/// silently.
///
/// # Return:
- /// * - (message header, message size, [received fds]) on success.
+ /// * - (message header, message size, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -497,7 +507,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_body_into_buf(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut iovs = [
iovec {
@@ -509,7 +519,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
return Err(Error::PartialMessage);
@@ -517,7 +527,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files))
}
/// Receive a message with optional payload and attached file descriptors.
@@ -525,7 +535,7 @@ impl<R: Req> Endpoint<R> {
/// accepted and all other file descriptor will be discard silently.
///
/// # Return:
- /// * - (message header, message body, size of payload, [received fds]) on success.
+ /// * - (message header, message body, size of payload, [received files]) on success.
/// * - SocketRetry: temporary error caused by signals or short of resources.
/// * - SocketBroken: the underline socket is broken.
/// * - SocketError: other socket related errors.
@@ -535,7 +545,7 @@ impl<R: Req> Endpoint<R> {
pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
buf: &mut [u8],
- ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> {
let mut hdr = VhostUserMsgHeader::default();
let mut body: T = Default::default();
let mut iovs = [
@@ -552,7 +562,7 @@ impl<R: Req> Endpoint<R> {
iov_len: buf.len(),
},
];
- let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+ let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?;
let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
if bytes < total {
@@ -561,17 +571,7 @@ impl<R: Req> Endpoint<R> {
return Err(Error::InvalidMessage);
}
- Ok((hdr, body, bytes - total, rfds))
- }
-
- /// Close all raw file descriptors.
- pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
- if let Some(fds) = rfds {
- for fd in fds {
- // safe because the rawfds are valid and we don't care about the result.
- let _ = unsafe { libc::close(fd) };
- }
- }
+ Ok((hdr, body, bytes - total, files))
}
}
@@ -604,7 +604,6 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
use super::*;
- use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::io::FromRawFd;
use tempfile::{tempfile, Builder, TempDir};
@@ -685,14 +684,14 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 1);
- let mut file = unsafe { File::from_raw_fd(fds[0]) };
+ assert_eq!(files.len(), 1);
+ let mut file = &files[0];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
@@ -710,23 +709,23 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data(header, body) with fds
@@ -742,10 +741,10 @@ mod tests {
let (bytes, buf4) = slave.recv_data(2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf4[..]);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should work:
// Sending side: data, data with fds
@@ -760,28 +759,28 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
assert_eq!(&buf1[..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[..2], &buf2[..]);
- assert!(rfds.is_some());
- let fds = rfds.unwrap();
+ assert!(files.is_some());
+ let files = files.unwrap();
{
- assert_eq!(fds.len(), 3);
- let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ assert_eq!(files.len(), 3);
+ let mut file = &files[1];
let mut content = String::new();
file.seek(SeekFrom::Start(0)).unwrap();
file.read_to_string(&mut content).unwrap();
assert_eq!(content, "test");
}
- let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap();
assert_eq!(bytes, 2);
assert_eq!(&buf1[2..], &buf2[..]);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// Following communication pattern should not work:
// Sending side: data1, data2 with fds
@@ -799,9 +798,9 @@ mod tests {
let (bytes, _) = slave.recv_data(5).unwrap();
assert_eq!(bytes, 5);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 3);
- assert!(rfds.is_none());
+ assert!(files.is_none());
// If the target fd array is too small, extra file descriptors will get lost.
let len = master
@@ -812,12 +811,9 @@ mod tests {
.unwrap();
assert_eq!(len, 4);
- let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap();
assert_eq!(bytes, 4);
- assert!(rfds.is_some());
-
- Endpoint::<MasterReq>::close_rfds(rfds);
- Endpoint::<MasterReq>::close_rfds(None);
+ assert!(files.is_some());
}
#[test]
@@ -844,15 +840,15 @@ mod tests {
mem::size_of::<u64>(),
)
};
- let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
+ let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap();
assert_eq!(hdr1, hdr2);
assert_eq!(bytes, 8);
assert_eq!(features1, features2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
master.send_header(&hdr1, None).unwrap();
- let (hdr2, rfds) = slave.recv_header().unwrap();
+ let (hdr2, files) = slave.recv_header().unwrap();
assert_eq!(hdr1, hdr2);
- assert!(rfds.is_none());
+ assert!(files.is_none());
}
}