aboutsummaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs287
1 files changed, 69 insertions, 218 deletions
diff --git a/src/lib.rs b/src/lib.rs
index aaeb393..e482e9f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -17,9 +17,19 @@
//! Virtio socket support for Rust.
-use libc::*;
-use nix::ioctl_read_bad;
-use std::ffi::c_void;
+use libc::{
+ accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
+ FIONBIO, SOCK_CLOEXEC,
+};
+use nix::{
+ ioctl_read_bad,
+ sys::socket::{
+ self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
+ sockopt::{ReceiveTimeout, SendTimeout, SocketError},
+ AddressFamily, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
+ },
+ unistd::close,
+};
use std::fs::File;
use std::io::{Error, ErrorKind, Read, Result, Write};
use std::mem::{self, size_of};
@@ -28,10 +38,15 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::time::Duration;
pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
-pub use nix::sys::socket::{SockAddr, VsockAddr};
-
-fn new_socket() -> libc::c_int {
- unsafe { socket(AF_VSOCK, SOCK_STREAM | SOCK_CLOEXEC, 0) }
+pub use nix::sys::socket::{SockaddrLike, VsockAddr};
+
+fn new_socket() -> Result<RawFd> {
+ Ok(socket(
+ AddressFamily::Vsock,
+ SockType::Stream,
+ SockFlag::SOCK_CLOEXEC,
+ None,
+ )?)
}
/// An iterator that infinitely accepts connections on a VsockListener.
@@ -56,68 +71,32 @@ pub struct VsockListener {
impl VsockListener {
/// Create a new VsockListener which is bound and listening on the socket address.
- pub fn bind(addr: &SockAddr) -> Result<VsockListener> {
- let mut vsock_addr = if let SockAddr::Vsock(addr) = addr {
- addr.0
- } else {
+ pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
+ if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::new(
ErrorKind::Other,
"requires a virtio socket address",
));
- };
-
- let socket = new_socket();
- if socket < 0 {
- return Err(Error::last_os_error());
}
- let res = unsafe {
- bind(
- socket,
- &mut vsock_addr as *mut _ as *mut sockaddr,
- size_of::<sockaddr_vm>() as socklen_t,
- )
- };
- if res < 0 {
- return Err(Error::last_os_error());
- }
+ let socket = new_socket()?;
+
+ bind(socket, addr)?;
// rust stdlib uses a 128 connection backlog
- let res = unsafe { listen(socket, 128) };
- if res < 0 {
- return Err(Error::last_os_error());
- }
+ listen(socket, 128)?;
Ok(Self { socket })
}
/// Create a new VsockListener with specified cid and port.
pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
- Self::bind(&SockAddr::Vsock(VsockAddr::new(cid, port)))
+ Self::bind(&VsockAddr::new(cid, port))
}
/// The local socket address of the listener.
- pub fn local_addr(&self) -> Result<SockAddr> {
- let mut vsock_addr = sockaddr_vm {
- svm_family: AF_VSOCK as sa_family_t,
- svm_reserved1: 0,
- svm_port: 0,
- svm_cid: 0,
- svm_zero: [0u8; 4],
- };
- let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
- if unsafe {
- getsockname(
- self.socket,
- &mut vsock_addr as *mut _ as *mut sockaddr,
- &mut vsock_addr_len,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
- }
+ pub fn local_addr(&self) -> Result<VsockAddr> {
+ Ok(getsockname(self.socket)?)
}
/// Create a new independently owned handle to the underlying socket.
@@ -126,7 +105,7 @@ impl VsockListener {
}
/// Accept a new incoming connection from this listener.
- pub fn accept(&self) -> Result<(VsockStream, SockAddr)> {
+ pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
let mut vsock_addr = sockaddr_vm {
svm_family: AF_VSOCK as sa_family_t,
svm_reserved1: 0,
@@ -148,7 +127,7 @@ impl VsockListener {
} else {
Ok((
unsafe { VsockStream::from_raw_fd(socket as RawFd) },
- SockAddr::Vsock(VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port)),
+ VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
))
}
}
@@ -160,26 +139,12 @@ impl VsockListener {
/// Retrieve the latest error associated with the underlying socket.
pub fn take_error(&self) -> Result<Option<Error>> {
- let mut error: i32 = 0;
- let mut error_len: socklen_t = 0;
- if unsafe {
- getsockopt(
- self.socket,
- SOL_SOCKET,
- SO_ERROR,
- &mut error as *mut _ as *mut c_void,
- &mut error_len,
- )
- } < 0
- {
- Err(Error::last_os_error())
+ let error = SocketError.get(self.socket)?;
+ Ok(if error == 0 {
+ None
} else {
- Ok(if error == 0 {
- None
- } else {
- Some(Error::from_raw_os_error(error))
- })
- }
+ Some(Error::from_raw_os_error(error))
+ })
}
/// Move this stream in and out of nonblocking mode.
@@ -215,7 +180,7 @@ impl IntoRawFd for VsockListener {
impl Drop for VsockListener {
fn drop(&mut self) {
- unsafe { close(self.socket) };
+ let _ = close(self.socket);
}
}
@@ -227,99 +192,42 @@ pub struct VsockStream {
impl VsockStream {
/// Open a connection to a remote host.
- pub fn connect(addr: &SockAddr) -> Result<Self> {
- let vsock_addr = if let SockAddr::Vsock(addr) = addr {
- addr.0
- } else {
+ pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
+ if addr.family() != Some(AddressFamily::Vsock) {
return Err(Error::new(
ErrorKind::Other,
"requires a virtio socket address",
));
- };
-
- let sock = new_socket();
- if sock < 0 {
- return Err(Error::last_os_error());
- }
- if unsafe {
- connect(
- sock,
- &vsock_addr as *const _ as *const sockaddr,
- size_of::<sockaddr_vm>() as socklen_t,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(unsafe { VsockStream::from_raw_fd(sock) })
}
+
+ let sock = new_socket()?;
+ connect(sock, addr)?;
+ Ok(unsafe { Self::from_raw_fd(sock) })
}
/// Open a connection to a remote host with specified cid and port.
pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
- Self::connect(&SockAddr::Vsock(VsockAddr::new(cid, port)))
+ Self::connect(&VsockAddr::new(cid, port))
}
/// Virtio socket address of the remote peer associated with this connection.
- pub fn peer_addr(&self) -> Result<SockAddr> {
- let mut vsock_addr = sockaddr_vm {
- svm_family: AF_VSOCK as sa_family_t,
- svm_reserved1: 0,
- svm_port: 0,
- svm_cid: 0,
- svm_zero: [0u8; 4],
- };
- let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
- if unsafe {
- getpeername(
- self.socket,
- &mut vsock_addr as *mut _ as *mut sockaddr,
- &mut vsock_addr_len,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
- }
+ pub fn peer_addr(&self) -> Result<VsockAddr> {
+ Ok(getpeername(self.socket)?)
}
/// Virtio socket address of the local address associated with this connection.
- pub fn local_addr(&self) -> Result<SockAddr> {
- let mut vsock_addr = sockaddr_vm {
- svm_family: AF_VSOCK as sa_family_t,
- svm_reserved1: 0,
- svm_port: 0,
- svm_cid: 0,
- svm_zero: [0u8; 4],
- };
- let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
- if unsafe {
- getsockname(
- self.socket,
- &mut vsock_addr as *mut _ as *mut sockaddr,
- &mut vsock_addr_len,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(SockAddr::Vsock(VsockAddr(vsock_addr)))
- }
+ pub fn local_addr(&self) -> Result<VsockAddr> {
+ Ok(getsockname(self.socket)?)
}
/// Shutdown the read, write, or both halves of this connection.
pub fn shutdown(&self, how: Shutdown) -> Result<()> {
let how = match how {
- Shutdown::Write => SHUT_WR,
- Shutdown::Read => SHUT_RD,
- Shutdown::Both => SHUT_RDWR,
+ Shutdown::Write => socket::Shutdown::Write,
+ Shutdown::Read => socket::Shutdown::Read,
+ Shutdown::Both => socket::Shutdown::Both,
};
- if unsafe { shutdown(self.socket, how) } < 0 {
- Err(Error::last_os_error())
- } else {
- Ok(())
- }
+ Ok(shutdown(self.socket, how)?)
}
/// Create a new independently owned handle to the underlying socket.
@@ -329,64 +237,24 @@ impl VsockStream {
/// Set the timeout on read operations.
pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
- let timeout = Self::timeval_from_duration(dur)?;
- if unsafe {
- setsockopt(
- self.socket,
- SOL_SOCKET,
- SO_SNDTIMEO,
- &timeout as *const _ as *const c_void,
- size_of::<timeval>() as socklen_t,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(())
- }
+ let timeout = Self::timeval_from_duration(dur)?.into();
+ Ok(SendTimeout.set(self.socket, &timeout)?)
}
/// Set the timeout on write operations.
pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
- let timeout = Self::timeval_from_duration(dur)?;
- if unsafe {
- setsockopt(
- self.socket,
- SOL_SOCKET,
- SO_RCVTIMEO,
- &timeout as *const _ as *const c_void,
- size_of::<timeval>() as socklen_t,
- )
- } < 0
- {
- Err(Error::last_os_error())
- } else {
- Ok(())
- }
+ let timeout = Self::timeval_from_duration(dur)?.into();
+ Ok(ReceiveTimeout.set(self.socket, &timeout)?)
}
/// Retrieve the latest error associated with the underlying socket.
pub fn take_error(&self) -> Result<Option<Error>> {
- let mut error: i32 = 0;
- let mut error_len: socklen_t = 0;
- if unsafe {
- getsockopt(
- self.socket,
- SOL_SOCKET,
- SO_ERROR,
- &mut error as *mut _ as *mut c_void,
- &mut error_len,
- )
- } < 0
- {
- Err(Error::last_os_error())
+ let error = SocketError.get(self.socket)?;
+ Ok(if error == 0 {
+ None
} else {
- Ok(if error == 0 {
- None
- } else {
- Some(Error::from_raw_os_error(error))
- })
- }
+ Some(Error::from_raw_os_error(error))
+ })
}
/// Move this stream in and out of nonblocking mode.
@@ -411,10 +279,10 @@ impl VsockStream {
// https://github.com/rust-lang/libc/issues/1848
#[cfg_attr(target_env = "musl", allow(deprecated))]
- let secs = if dur.as_secs() > time_t::max_value() as u64 {
- time_t::max_value()
+ let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
+ libc::time_t::max_value()
} else {
- dur.as_secs() as time_t
+ dur.as_secs() as libc::time_t
};
let mut timeout = timeval {
tv_sec: secs,
@@ -451,30 +319,13 @@ impl Write for VsockStream {
impl Read for &VsockStream {
fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
- let ret = unsafe { recv(self.socket, buf.as_mut_ptr() as *mut c_void, buf.len(), 0) };
- if ret < 0 {
- Err(Error::last_os_error())
- } else {
- Ok(ret as usize)
- }
+ Ok(recv(self.socket, buf, MsgFlags::empty())?)
}
}
impl Write for &VsockStream {
fn write(&mut self, buf: &[u8]) -> Result<usize> {
- let ret = unsafe {
- send(
- self.socket,
- buf.as_ptr() as *const c_void,
- buf.len(),
- MSG_NOSIGNAL,
- )
- };
- if ret < 0 {
- Err(Error::last_os_error())
- } else {
- Ok(ret as usize)
- }
+ Ok(send(self.socket, buf, MsgFlags::MSG_NOSIGNAL)?)
}
fn flush(&mut self) -> Result<()> {
@@ -504,7 +355,7 @@ impl IntoRawFd for VsockStream {
impl Drop for VsockStream {
fn drop(&mut self) {
- unsafe { close(self.socket) };
+ let _ = close(self.socket);
}
}