diff options
Diffstat (limited to 'src/lib.rs')
-rw-r--r-- | src/lib.rs | 287 |
1 files changed, 69 insertions, 218 deletions
@@ -17,9 +17,19 @@ //! Virtio socket support for Rust. -use libc::*; -use nix::ioctl_read_bad; -use std::ffi::c_void; +use libc::{ + accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK, + FIONBIO, SOCK_CLOEXEC, +}; +use nix::{ + ioctl_read_bad, + sys::socket::{ + self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket, + sockopt::{ReceiveTimeout, SendTimeout, SocketError}, + AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType, + }, + unistd::close, +}; use std::fs::File; use std::io::{Error, ErrorKind, Read, Result, Write}; use std::mem::{self, size_of}; @@ -28,10 +38,15 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::time::Duration; pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL}; -pub use nix::sys::socket::{SockAddr, VsockAddr}; - -fn new_socket() -> libc::c_int { - unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) } +pub use nix::sys::socket::{SockaddrLike, VsockAddr}; + +fn new_socket() -> Result<RawFd> { + Ok(socket( + AddressFamily::Vsock, + SockType::Stream, + SockFlag::SOCK_CLOEXEC, + None, + )?) } /// An iterator that infinitely accepts connections on a VsockListener. @@ -56,68 +71,32 @@ pub struct VsockListener { impl VsockListener { /// Create a new VsockListener which is bound and listening on the socket address. - pub fn bind(addr: &SockAddr) -> Result<VsockListener> { - let mut vsock_addr = if let SockAddr::Vsock(addr) = addr { - addr.0 - } else { + pub fn bind(addr: &impl SockaddrLike) -> Result<Self> { + if addr.family() != Some(AddressFamily::Vsock) { return Err(Error::new( ErrorKind::Other, "requires a virtio socket address", )); - }; - - let socket = new_socket(); - if socket < 0 { - return Err(Error::last_os_error()); } - let res = unsafe { - bind( - socket, - &mut vsock_addr as *mut _ as *mut sockaddr, - size_of::<sockaddr_vm>() as socklen_t, - ) - }; - if res < 0 { - return Err(Error::last_os_error()); - } + let socket = new_socket()?; + + bind(socket, addr)?; // rust stdlib uses a 128 connection backlog - let res = unsafe { listen(socket, 128) }; - if res < 0 { - return Err(Error::last_os_error()); - } + listen(socket, 128)?; Ok(Self { socket }) } /// Create a new VsockListener with specified cid and port. pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> { - Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port))) + Self::bind(&VsockAddr::new(cid, port)) } /// The local socket address of the listener. - pub fn local_addr(&self) -> Result<SockAddr> { - let mut vsock_addr = sockaddr_vm { - svm_family: AF_VSOCK as sa_family_t, - svm_reserved1: 0, - svm_port: 0, - svm_cid: 0, - svm_zero: [0u8; 4], - }; - let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t; - if unsafe { - getsockname( - self.socket, - &mut vsock_addr as *mut _ as *mut sockaddr, - &mut vsock_addr_len, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(SockAddr::Vsock(VsockAddr(vsock_addr))) - } + pub fn local_addr(&self) -> Result<VsockAddr> { + Ok(getsockname(self.socket)?) } /// Create a new independently owned handle to the underlying socket. @@ -126,7 +105,7 @@ impl VsockListener { } /// Accept a new incoming connection from this listener. - pub fn accept(&self) -> Result<(VsockStream, SockAddr)> { + pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> { let mut vsock_addr = sockaddr_vm { svm_family: AF_VSOCK as sa_family_t, svm_reserved1: 0, @@ -148,7 +127,7 @@ impl VsockListener { } else { Ok(( unsafe { VsockStream::from_raw_fd(socket as RawFd) }, - SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)), + VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port), )) } } @@ -160,26 +139,12 @@ impl VsockListener { /// Retrieve the latest error associated with the underlying socket. pub fn take_error(&self) -> Result<Option<Error>> { - let mut error: i32 = 0; - let mut error_len: socklen_t = 0; - if unsafe { - getsockopt( - self.socket, - SOL_SOCKET, - SO_ERROR, - &mut error as *mut _ as *mut c_void, - &mut error_len, - ) - } < 0 - { - Err(Error::last_os_error()) + let error = SocketError.get(self.socket)?; + Ok(if error == 0 { + None } else { - Ok(if error == 0 { - None - } else { - Some(Error::from_raw_os_error(error)) - }) - } + Some(Error::from_raw_os_error(error)) + }) } /// Move this stream in and out of nonblocking mode. @@ -215,7 +180,7 @@ impl IntoRawFd for VsockListener { impl Drop for VsockListener { fn drop(&mut self) { - unsafe { close(self.socket) }; + let _ = close(self.socket); } } @@ -227,99 +192,42 @@ pub struct VsockStream { impl VsockStream { /// Open a connection to a remote host. - pub fn connect(addr: &SockAddr) -> Result<Self> { - let vsock_addr = if let SockAddr::Vsock(addr) = addr { - addr.0 - } else { + pub fn connect(addr: &impl SockaddrLike) -> Result<Self> { + if addr.family() != Some(AddressFamily::Vsock) { return Err(Error::new( ErrorKind::Other, "requires a virtio socket address", )); - }; - - let sock = new_socket(); - if sock < 0 { - return Err(Error::last_os_error()); - } - if unsafe { - connect( - sock, - &vsock_addr as *const _ as *const sockaddr, - size_of::<sockaddr_vm>() as socklen_t, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(unsafe { VsockStream::from_raw_fd(sock) }) } + + let sock = new_socket()?; + connect(sock, addr)?; + Ok(unsafe { Self::from_raw_fd(sock) }) } /// Open a connection to a remote host with specified cid and port. pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> { - Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port))) + Self::connect(&VsockAddr::new(cid, port)) } /// Virtio socket address of the remote peer associated with this connection. - pub fn peer_addr(&self) -> Result<SockAddr> { - let mut vsock_addr = sockaddr_vm { - svm_family: AF_VSOCK as sa_family_t, - svm_reserved1: 0, - svm_port: 0, - svm_cid: 0, - svm_zero: [0u8; 4], - }; - let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t; - if unsafe { - getpeername( - self.socket, - &mut vsock_addr as *mut _ as *mut sockaddr, - &mut vsock_addr_len, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(SockAddr::Vsock(VsockAddr(vsock_addr))) - } + pub fn peer_addr(&self) -> Result<VsockAddr> { + Ok(getpeername(self.socket)?) } /// Virtio socket address of the local address associated with this connection. - pub fn local_addr(&self) -> Result<SockAddr> { - let mut vsock_addr = sockaddr_vm { - svm_family: AF_VSOCK as sa_family_t, - svm_reserved1: 0, - svm_port: 0, - svm_cid: 0, - svm_zero: [0u8; 4], - }; - let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t; - if unsafe { - getsockname( - self.socket, - &mut vsock_addr as *mut _ as *mut sockaddr, - &mut vsock_addr_len, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(SockAddr::Vsock(VsockAddr(vsock_addr))) - } + pub fn local_addr(&self) -> Result<VsockAddr> { + Ok(getsockname(self.socket)?) } /// Shutdown the read, write, or both halves of this connection. pub fn shutdown(&self, how: Shutdown) -> Result<()> { let how = match how { - Shutdown::Write => SHUT_WR, - Shutdown::Read => SHUT_RD, - Shutdown::Both => SHUT_RDWR, + Shutdown::Write => socket::Shutdown::Write, + Shutdown::Read => socket::Shutdown::Read, + Shutdown::Both => socket::Shutdown::Both, }; - if unsafe { shutdown(self.socket, how) } < 0 { - Err(Error::last_os_error()) - } else { - Ok(()) - } + Ok(shutdown(self.socket, how)?) } /// Create a new independently owned handle to the underlying socket. @@ -329,64 +237,24 @@ impl VsockStream { /// Set the timeout on read operations. pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> { - let timeout = Self::timeval_from_duration(dur)?; - if unsafe { - setsockopt( - self.socket, - SOL_SOCKET, - SO_SNDTIMEO, - &timeout as *const _ as *const c_void, - size_of::<timeval>() as socklen_t, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(()) - } + let timeout = Self::timeval_from_duration(dur)?.into(); + Ok(SendTimeout.set(self.socket, &timeout)?) } /// Set the timeout on write operations. pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> { - let timeout = Self::timeval_from_duration(dur)?; - if unsafe { - setsockopt( - self.socket, - SOL_SOCKET, - SO_RCVTIMEO, - &timeout as *const _ as *const c_void, - size_of::<timeval>() as socklen_t, - ) - } < 0 - { - Err(Error::last_os_error()) - } else { - Ok(()) - } + let timeout = Self::timeval_from_duration(dur)?.into(); + Ok(ReceiveTimeout.set(self.socket, &timeout)?) } /// Retrieve the latest error associated with the underlying socket. pub fn take_error(&self) -> Result<Option<Error>> { - let mut error: i32 = 0; - let mut error_len: socklen_t = 0; - if unsafe { - getsockopt( - self.socket, - SOL_SOCKET, - SO_ERROR, - &mut error as *mut _ as *mut c_void, - &mut error_len, - ) - } < 0 - { - Err(Error::last_os_error()) + let error = SocketError.get(self.socket)?; + Ok(if error == 0 { + None } else { - Ok(if error == 0 { - None - } else { - Some(Error::from_raw_os_error(error)) - }) - } + Some(Error::from_raw_os_error(error)) + }) } /// Move this stream in and out of nonblocking mode. @@ -411,10 +279,10 @@ impl VsockStream { // https://github.com/rust-lang/libc/issues/1848 #[cfg_attr(target_env = "musl", allow(deprecated))] - let secs = if dur.as_secs() > time_t::max_value() as u64 { - time_t::max_value() + let secs = if dur.as_secs() > libc::time_t::max_value() as u64 { + libc::time_t::max_value() } else { - dur.as_secs() as time_t + dur.as_secs() as libc::time_t }; let mut timeout = timeval { tv_sec: secs, @@ -451,30 +319,13 @@ impl Write for VsockStream { impl Read for &VsockStream { fn read(&mut self, buf: &mut [u8]) -> Result<usize> { - let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) }; - if ret < 0 { - Err(Error::last_os_error()) - } else { - Ok(ret as usize) - } + Ok(recv(self.socket, buf, MsgFlags::empty())?) } } impl Write for &VsockStream { fn write(&mut self, buf: &[u8]) -> Result<usize> { - let ret = unsafe { - send( - self.socket, - buf.as_ptr() as *const c_void, - buf.len(), - MSG_NOSIGNAL, - ) - }; - if ret < 0 { - Err(Error::last_os_error()) - } else { - Ok(ret as usize) - } + Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?) } fn flush(&mut self) -> Result<()> { @@ -504,7 +355,7 @@ impl IntoRawFd for VsockStream { impl Drop for VsockStream { fn drop(&mut self) { - unsafe { close(self.socket) }; + let _ = close(self.socket); } } |