summaryrefslogtreecommitdiff
path: root/src/vhost_user/connection.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/vhost_user/connection.rs')
-rw-r--r--src/vhost_user/connection.rs56
1 files changed, 34 insertions, 22 deletions
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index deafdeb..01bf124 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,15 +5,16 @@
#![allow(dead_code)]
-use libc::{c_void, iovec};
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::{mem, slice};
+use libc::{c_void, iovec};
+use vmm_sys_util::sock_ctrl_msg::ScmSocket;
+
use super::message::*;
-use super::sock_ctrl_msg::ScmSocket;
use super::{Error, Result};
/// Unix domain socket listener for accepting incoming connections.
@@ -215,6 +216,9 @@ impl<R: Req> Endpoint<R> {
body: &T,
fds: Option<&[RawFd]>,
) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
// Safe because there can't be other mutable referance to hdr and body.
let iovs = unsafe {
[
@@ -243,14 +247,17 @@ impl<R: Req> Endpoint<R> {
/// * - OversizedMsg: message size is too big.
/// * - PartialMessage: received a partial message.
/// * - IncorrectFds: wrong number of attached fds.
- pub fn send_message_with_payload<T: Sized, P: Sized>(
+ pub fn send_message_with_payload<T: Sized>(
&mut self,
hdr: &VhostUserMsgHeader<R>,
body: &T,
- payload: &[P],
+ payload: &[u8],
fds: Option<&[RawFd]>,
) -> Result<()> {
- let len = payload.len() * mem::size_of::<P>();
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
if len > MAX_MSG_SIZE - mem::size_of::<T>() {
return Err(Error::OversizedMsg);
}
@@ -599,27 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
-
use super::*;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::io::FromRawFd;
+ use vmm_sys_util::rand::rand_alphanumerics;
use vmm_sys_util::tempfile::TempFile;
- const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener";
- const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection";
- const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data";
- const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd";
- const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send";
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
#[test]
fn create_listener() {
- let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
}
#[test]
fn accept_connection() {
- let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
// accept on a fd without incoming connection
@@ -628,11 +640,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_data() {
- let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -654,11 +666,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_fd() {
- let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -808,11 +820,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_recv() {
- let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);