diff options
Diffstat (limited to 'src/vhost_user/connection.rs')
-rw-r--r-- | src/vhost_user/connection.rs | 56 |
1 files changed, 34 insertions, 22 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs index deafdeb..01bf124 100644 --- a/src/vhost_user/connection.rs +++ b/src/vhost_user/connection.rs @@ -5,15 +5,16 @@ #![allow(dead_code)] -use libc::{c_void, iovec}; use std::io::ErrorKind; use std::marker::PhantomData; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::{mem, slice}; +use libc::{c_void, iovec}; +use vmm_sys_util::sock_ctrl_msg::ScmSocket; + use super::message::*; -use super::sock_ctrl_msg::ScmSocket; use super::{Error, Result}; /// Unix domain socket listener for accepting incoming connections. @@ -215,6 +216,9 @@ impl<R: Req> Endpoint<R> { body: &T, fds: Option<&[RawFd]>, ) -> Result<()> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } // Safe because there can't be other mutable referance to hdr and body. let iovs = unsafe { [ @@ -243,14 +247,17 @@ impl<R: Req> Endpoint<R> { /// * - OversizedMsg: message size is too big. /// * - PartialMessage: received a partial message. /// * - IncorrectFds: wrong number of attached fds. - pub fn send_message_with_payload<T: Sized, P: Sized>( + pub fn send_message_with_payload<T: Sized>( &mut self, hdr: &VhostUserMsgHeader<R>, body: &T, - payload: &[P], + payload: &[u8], fds: Option<&[RawFd]>, ) -> Result<()> { - let len = payload.len() * mem::size_of::<P>(); + let len = payload.len(); + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } if len > MAX_MSG_SIZE - mem::size_of::<T>() { return Err(Error::OversizedMsg); } @@ -599,27 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { #[cfg(test)] mod tests { - use super::*; use std::fs::File; use std::io::{Read, Seek, SeekFrom, Write}; use std::os::unix::io::FromRawFd; + use vmm_sys_util::rand::rand_alphanumerics; use vmm_sys_util::tempfile::TempFile; - const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener"; - const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection"; - const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data"; - const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd"; - const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send"; + fn temp_path() -> String { + format!( + "/tmp/vhost_test_{}", + rand_alphanumerics(8).to_str().unwrap() + ) + } #[test] fn create_listener() { - let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); } #[test] fn accept_connection() { - let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); // accept on a fd without incoming connection @@ -628,11 +640,11 @@ mod tests { } #[test] - #[ignore] fn send_data() { - let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); @@ -654,11 +666,11 @@ mod tests { } #[test] - #[ignore] fn send_fd() { - let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); @@ -808,11 +820,11 @@ mod tests { } #[test] - #[ignore] fn send_recv() { - let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap(); + let path = temp_path(); + let listener = Listener::new(&path, true).unwrap(); listener.set_nonblocking(true).unwrap(); - let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); let sock = listener.accept().unwrap().unwrap(); let mut slave = Endpoint::<MasterReq>::from_stream(sock); |