diff options
Diffstat (limited to 'src/vhost_user/connection.rs')
-rw-r--r-- | src/vhost_user/connection.rs | 152 |
1 files changed, 74 insertions, 78 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index f92db45..ea8461a 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,9 +5,10 @@ #![allow(dead_code)] +use std::fs::File; use std::io::ErrorKind; use std::marker::PhantomData; -use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::path::{Path, PathBuf}; use std::{mem, slice}; @@ -301,7 +302,7 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into the given scatter/gather vectors with optional attached - /// file descriptors. + /// file. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -311,29 +312,37 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: - /// * - (number of bytes received, [received fds]) on success + /// * - (number of bytes received, [received files]) on success /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. - pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<File>>)> { let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?; - let rfds = match fds { + + let files = match fds { 0 => None, n => { - let mut fds = Vec::with_capacity(n); - fds.extend_from_slice(&fd_array[0..n]); - Some(fds) + let files = fd_array + .iter() + .take(n) + .map(|fd| { + // Safe because we have the ownership of `fd`. + unsafe { File::from_raw_fd(*fd) } + }) + .collect(); + Some(files) } }; - Ok((bytes, rfds)) + Ok((bytes, files)) } /// Reads all bytes from the socket into the given scatter/gather vectors with optional - /// attached file descriptors. Will loop until all data has been transfered. + /// attached files. Will loop until all data has been transferred. /// /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little /// tricky to pass file descriptors through such a communication channel. Let's assume that a @@ -343,6 +352,7 @@ impl<R: Req> Endpoint<R> { /// 2) message(packet) boundaries must be respected on the receive side. /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the /// attached file descriptors will get lost. + /// Note that this function wraps received file descriptors as `File`. /// /// # Return: /// * - (number of bytes received, [received fds]) on success @@ -351,7 +361,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_into_iovec_all( &mut self, iovs: &mut [iovec], - ) -> Result<(usize, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Option<Vec<File>>)> { let mut data_read = 0; let mut data_total = 0; let mut rfds = None; @@ -392,46 +402,46 @@ impl<R: Req> Endpoint<R> { } /// Reads bytes from the socket into a new buffer with optional attached - /// file descriptors. Received file descriptors are set close-on-exec. + /// files. Received file descriptors are set close-on-exec and converted to `File`. /// /// # Return: - /// * - (number of bytes received, buf, [received fds]) on success. + /// * - (number of bytes received, buf, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. pub fn recv_into_buf( &mut self, buf_size: usize, - ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + ) -> Result<(usize, Vec<u8>, Option<Vec<File>>)> { let mut buf = vec![0u8; buf_size]; - let (bytes, rfds) = { + let (bytes, files) = { let mut iovs = [iovec { iov_base: buf.as_mut_ptr() as *mut c_void, iov_len: buf_size, }]; self.recv_into_iovec(&mut iovs)? }; - Ok((bytes, buf, rfds)) + Ok((bytes, buf, files)) } - /// Receive a header-only message with optional attached file descriptors. + /// Receive a header-only message with optional attached files. /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, [received fds]) on success. + /// * - (message header, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. /// * - PartialMessage: received a partial message. /// * - InvalidMessage: received a invalid message. - pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [iovec { iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), }]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -439,7 +449,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, rfds)) + Ok((hdr, files)) } /// Receive a message with optional attached file descriptors. @@ -447,7 +457,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, [received fds]) on success. + /// * - (message header, message body, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -455,7 +465,7 @@ impl<R: Req> Endpoint<R> { /// * - InvalidMessage: received a invalid message. pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( &mut self, - ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -468,7 +478,7 @@ impl<R: Req> Endpoint<R> { iov_len: mem::size_of::<T>(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes != total { @@ -477,7 +487,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, rfds)) + Ok((hdr, body, files)) } /// Receive a message with header and optional content. Callers need to @@ -488,7 +498,7 @@ impl<R: Req> Endpoint<R> { /// silently. /// /// # Return: - /// * - (message header, message size, [received fds]) on success. + /// * - (message header, message size, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -497,7 +507,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_body_into_buf( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut iovs = [ iovec { @@ -509,7 +519,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { return Err(Error::PartialMessage); @@ -517,7 +527,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), files)) } /// Receive a message with optional payload and attached file descriptors. @@ -525,7 +535,7 @@ impl<R: Req> Endpoint<R> { /// accepted and all other file descriptor will be discard silently. /// /// # Return: - /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - (message header, message body, size of payload, [received files]) on success. /// * - SocketRetry: temporary error caused by signals or short of resources. /// * - SocketBroken: the underline socket is broken. /// * - SocketError: other socket related errors. @@ -535,7 +545,7 @@ impl<R: Req> Endpoint<R> { pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( &mut self, buf: &mut [u8], - ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<File>>)> { let mut hdr = VhostUserMsgHeader::default(); let mut body: T = Default::default(); let mut iovs = [ @@ -552,7 +562,7 @@ impl<R: Req> Endpoint<R> { iov_len: buf.len(), }, ]; - let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + let (bytes, files) = self.recv_into_iovec_all(&mut iovs[..])?; let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); if bytes < total { @@ -561,17 +571,7 @@ impl<R: Req> Endpoint<R> { return Err(Error::InvalidMessage); } - Ok((hdr, body, bytes - total, rfds)) - } - - /// Close all raw file descriptors. - pub fn close_rfds(rfds: Option<Vec<RawFd>>) { - if let Some(fds) = rfds { - for fd in fds { - // safe because the rawfds are valid and we don't care about the result. - let _ = unsafe { libc::close(fd) }; - } - } + Ok((hdr, body, bytes - total, files)) } } @@ -604,7 +604,6 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { use super::*; - use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::io::FromRawFd; use tempfile::{tempfile, Builder, TempDir}; @@ -685,14 +684,14 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 1); - let mut file = unsafe { File::from_raw_fd(fds[0]) }; + assert_eq!(files.len(), 1); + let mut file = &files[0]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); @@ -710,23 +709,23 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data(header, body) with fds @@ -742,10 +741,10 @@ mod tests { let (bytes, buf4) = slave.recv_data(2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf4[..]); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should work: // Sending side: data, data with fds @@ -760,28 +759,28 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); assert_eq!(&buf1[..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[..2], &buf2[..]); - assert!(rfds.is_some()); - let fds = rfds.unwrap(); + assert!(files.is_some()); + let files = files.unwrap(); { - assert_eq!(fds.len(), 3); - let mut file = unsafe { File::from_raw_fd(fds[1]) }; + assert_eq!(files.len(), 3); + let mut file = &files[1]; let mut content = String::new(); file.seek(SeekFrom::Start(0)).unwrap(); file.read_to_string(&mut content).unwrap(); assert_eq!(content, "test"); } - let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + let (bytes, buf2, files) = slave.recv_into_buf(0x2).unwrap(); assert_eq!(bytes, 2); assert_eq!(&buf1[2..], &buf2[..]); - assert!(rfds.is_none()); + assert!(files.is_none()); // Following communication pattern should not work: // Sending side: data1, data2 with fds @@ -799,9 +798,9 @@ mod tests { let (bytes, _) = slave.recv_data(5).unwrap(); assert_eq!(bytes, 5); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 3); - assert!(rfds.is_none()); + assert!(files.is_none()); // If the target fd array is too small, extra file descriptors will get lost. let len = master @@ -812,12 +811,9 @@ mod tests { .unwrap(); assert_eq!(len, 4); - let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + let (bytes, _, files) = slave.recv_into_buf(0x4).unwrap(); assert_eq!(bytes, 4); - assert!(rfds.is_some()); - - Endpoint::<MasterReq>::close_rfds(rfds); - Endpoint::<MasterReq>::close_rfds(None); + assert!(files.is_some()); } #[test] @@ -844,15 +840,15 @@ mod tests { mem::size_of::<u64>(), ) }; - let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + let (hdr2, bytes, files) = slave.recv_body_into_buf(slice).unwrap(); assert_eq!(hdr1, hdr2); assert_eq!(bytes, 8); assert_eq!(features1, features2); - assert!(rfds.is_none()); + assert!(files.is_none()); master.send_header(&hdr1, None).unwrap(); - let (hdr2, rfds) = slave.recv_header().unwrap(); + let (hdr2, files) = slave.recv_header().unwrap(); assert_eq!(hdr1, hdr2); - assert!(rfds.is_none()); + assert!(files.is_none()); } } |