diff options
Diffstat (limited to 'src/vsock.rs')
-rw-r--r-- | src/vsock.rs | 491 |
1 files changed, 491 insertions, 0 deletions
diff --git a/src/vsock.rs b/src/vsock.rs new file mode 100644 index 0000000..634591e --- /dev/null +++ b/src/vsock.rs @@ -0,0 +1,491 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Support for virtual sockets. +use std::fmt; +use std::io; +use std::mem::{self, size_of}; +use std::num::ParseIntError; +use std::os::raw::{c_uchar, c_uint, c_ushort}; +use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd}; +use std::result; +use std::str::FromStr; + +use libc::{ + self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, + VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, +}; + +// The domain for vsock sockets. +const AF_VSOCK: sa_family_t = 40; + +// Vsock loopback address. +const VMADDR_CID_LOCAL: c_uint = 1; + +/// Vsock equivalent of binding on port 0. Binds to a random port. +pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value(); + +// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly +// from linux/vm_sockets.h. +const PADDING: usize = size_of::<sockaddr>() + - size_of::<sa_family_t>() + - size_of::<c_ushort>() + - (2 * size_of::<c_uint>()); + +#[repr(C)] +#[derive(Default)] +struct sockaddr_vm { + svm_family: sa_family_t, + svm_reserved1: c_ushort, + svm_port: c_uint, + svm_cid: c_uint, + svm_zero: [c_uchar; PADDING], +} + +#[derive(Debug)] +pub struct AddrParseError; + +impl fmt::Display for AddrParseError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "failed to parse vsock address") + } +} + +/// The vsock equivalent of an IP address. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub enum VsockCid { + /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint. + Any, + /// An address that refers to the bare-metal machine that serves as the hypervisor. + Hypervisor, + /// The loopback address. + Local, + /// The parent machine. It may not be the hypervisor for nested VMs. + Host, + /// An assigned CID that serves as the address for VSOCK. + Cid(c_uint), +} + +impl fmt::Display for VsockCid { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match &self { + VsockCid::Any => write!(fmt, "Any"), + VsockCid::Hypervisor => write!(fmt, "Hypervisor"), + VsockCid::Local => write!(fmt, "Local"), + VsockCid::Host => write!(fmt, "Host"), + VsockCid::Cid(c) => write!(fmt, "'{}'", c), + } + } +} + +impl From<c_uint> for VsockCid { + fn from(c: c_uint) -> Self { + match c { + VMADDR_CID_ANY => VsockCid::Any, + VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor, + VMADDR_CID_LOCAL => VsockCid::Local, + VMADDR_CID_HOST => VsockCid::Host, + _ => VsockCid::Cid(c), + } + } +} + +impl FromStr for VsockCid { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let c: c_uint = s.parse()?; + Ok(c.into()) + } +} + +impl Into<c_uint> for VsockCid { + fn into(self) -> c_uint { + match self { + VsockCid::Any => VMADDR_CID_ANY, + VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR, + VsockCid::Local => VMADDR_CID_LOCAL, + VsockCid::Host => VMADDR_CID_HOST, + VsockCid::Cid(c) => c, + } + } +} + +/// An address associated with a virtual socket. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub struct SocketAddr { + pub cid: VsockCid, + pub port: c_uint, +} + +pub trait ToSocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>; +} + +impl ToSocketAddr for SocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + Ok(*self) + } +} + +impl ToSocketAddr for str { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + self.parse() + } +} + +impl ToSocketAddr for (VsockCid, c_uint) { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + let (cid, port) = *self; + Ok(SocketAddr { cid, port }) + } +} + +impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + (**self).to_socket_addr() + } +} + +impl FromStr for SocketAddr { + type Err = AddrParseError; + + /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form + /// "vsock:cid:port". + fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> { + let components: Vec<&str> = s.split(':').collect(); + if components.len() != 3 || components[0] != "vsock" { + return Err(AddrParseError); + } + + Ok(SocketAddr { + cid: components[1].parse().map_err(|_| AddrParseError)?, + port: components[2].parse().map_err(|_| AddrParseError)?, + }) + } +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}:{}", self.cid, self.port) + } +} + +/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the +/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets. +unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> { + let flags = libc::fcntl(fd, F_GETFL, 0); + if flags < 0 { + return Err(io::Error::last_os_error()); + } + + let flags = if nonblocking { + flags | O_NONBLOCK + } else { + flags & !O_NONBLOCK + }; + + let ret = libc::fcntl(fd, F_SETFL, flags); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +/// A virtual socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either VsockStream or +/// VsockListener. +#[derive(Debug)] +pub struct VsockSocket { + fd: RawFd, +} + +impl VsockSocket { + pub fn new() -> io::Result<Self> { + let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(VsockSocket { fd }) + } + } + + pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + // The compiler should optimize this out since these are both compile-time constants. + assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>()); + + let mut svm: sockaddr_vm = Default::default(); + svm.svm_family = AF_VSOCK; + svm.svm_cid = sockaddr.cid.into(); + svm.svm_port = sockaddr.port; + + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::bind( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + let mut svm: sockaddr_vm = Default::default(); + svm.svm_family = AF_VSOCK; + svm.svm_cid = sockaddr.cid.into(); + svm.svm_port = sockaddr.port; + + // Safe because this just connects a vsock socket, and the return value is checked. + let ret = unsafe { + libc::connect( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + Ok(VsockStream { sock: self }) + } + } + + pub fn listen(self) -> io::Result<VsockListener> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + return Err(listen_err); + } + Ok(VsockListener { sock: self }) + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result<u32> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_vm>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::<sockaddr_vm>()); + + Ok(svm.svm_port) + } + } + + pub fn try_clone(&self) -> io::Result<Self> { + // Safe because this doesn't modify any memory and we check the return value. + let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) }; + if dup_fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { fd: dup_fd }) + } + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + // Safe because the fd is valid and owned by this stream. + unsafe { set_nonblocking(self.fd, nonblocking) } + } +} + +impl IntoRawFd for VsockSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for VsockSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for VsockSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + +/// A virtual stream socket. +#[derive(Debug)] +pub struct VsockStream { + sock: VsockSocket, +} + +impl VsockStream { + pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> { + let sock = VsockSocket::new()?; + sock.connect(addr) + } + + /// Returns the port that this stream is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + pub fn try_clone(&self) -> io::Result<VsockStream> { + self.sock.try_clone().map(|f| VsockStream { sock: f }) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl io::Read for VsockStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + // Safe because this will only modify the contents of |buf| and we check the return value. + let ret = unsafe { + libc::read( + self.sock.as_raw_fd(), + buf as *mut [u8] as *mut c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } +} + +impl io::Write for VsockStream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::write( + self.sock.as_raw_fd(), + buf as *const [u8] as *const c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } + + fn flush(&mut self) -> io::Result<()> { + // No buffered data so nothing to do. + Ok(()) + } +} + +impl AsRawFd for VsockStream { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +impl IntoRawFd for VsockStream { + fn into_raw_fd(self) -> RawFd { + self.sock.into_raw_fd() + } +} + +/// Represents a virtual socket server. +#[derive(Debug)] +pub struct VsockListener { + sock: VsockSocket, +} + +impl VsockListener { + /// Creates a new `VsockListener` bound to the specified port on the current virtual socket + /// endpoint. + pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> { + let mut sock = VsockSocket::new()?; + sock.bind(addr)?; + sock.listen() + } + + /// Returns the port that this listener is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + /// Accepts a new incoming connection on this listener. Blocks the calling thread until a + /// new connection is established. When established, returns the corresponding `VsockStream` + /// and the remote peer's address. + pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because this will only modify |svm| and we check the return value. + let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t; + let fd = unsafe { + libc::accept4( + self.sock.as_raw_fd(), + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut socklen as *mut socklen_t, + libc::SOCK_CLOEXEC, + ) + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + if svm.svm_family != AF_VSOCK { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected address family: {}", svm.svm_family), + )); + } + + Ok(( + VsockStream { + sock: VsockSocket { fd }, + }, + SocketAddr { + cid: svm.svm_cid.into(), + port: svm.svm_port, + }, + )) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl AsRawFd for VsockListener { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} |