aboutsummaryrefslogtreecommitdiff
path: root/src/vsock.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/vsock.rs')
-rw-r--r--src/vsock.rs491
1 files changed, 491 insertions, 0 deletions
diff --git a/src/vsock.rs b/src/vsock.rs
new file mode 100644
index 0000000..634591e
--- /dev/null
+++ b/src/vsock.rs
@@ -0,0 +1,491 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/// Support for virtual sockets.
+use std::fmt;
+use std::io;
+use std::mem::{self, size_of};
+use std::num::ParseIntError;
+use std::os::raw::{c_uchar, c_uint, c_ushort};
+use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
+use std::result;
+use std::str::FromStr;
+
+use libc::{
+ self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
+ VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
+};
+
+// The domain for vsock sockets.
+const AF_VSOCK: sa_family_t = 40;
+
+// Vsock loopback address.
+const VMADDR_CID_LOCAL: c_uint = 1;
+
+/// Vsock equivalent of binding on port 0. Binds to a random port.
+pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value();
+
+// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
+// from linux/vm_sockets.h.
+const PADDING: usize = size_of::<sockaddr>()
+ - size_of::<sa_family_t>()
+ - size_of::<c_ushort>()
+ - (2 * size_of::<c_uint>());
+
+#[repr(C)]
+#[derive(Default)]
+struct sockaddr_vm {
+ svm_family: sa_family_t,
+ svm_reserved1: c_ushort,
+ svm_port: c_uint,
+ svm_cid: c_uint,
+ svm_zero: [c_uchar; PADDING],
+}
+
+#[derive(Debug)]
+pub struct AddrParseError;
+
+impl fmt::Display for AddrParseError {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "failed to parse vsock address")
+ }
+}
+
+/// The vsock equivalent of an IP address.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub enum VsockCid {
+ /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
+ Any,
+ /// An address that refers to the bare-metal machine that serves as the hypervisor.
+ Hypervisor,
+ /// The loopback address.
+ Local,
+ /// The parent machine. It may not be the hypervisor for nested VMs.
+ Host,
+ /// An assigned CID that serves as the address for VSOCK.
+ Cid(c_uint),
+}
+
+impl fmt::Display for VsockCid {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ match &self {
+ VsockCid::Any => write!(fmt, "Any"),
+ VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
+ VsockCid::Local => write!(fmt, "Local"),
+ VsockCid::Host => write!(fmt, "Host"),
+ VsockCid::Cid(c) => write!(fmt, "'{}'", c),
+ }
+ }
+}
+
+impl From<c_uint> for VsockCid {
+ fn from(c: c_uint) -> Self {
+ match c {
+ VMADDR_CID_ANY => VsockCid::Any,
+ VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
+ VMADDR_CID_LOCAL => VsockCid::Local,
+ VMADDR_CID_HOST => VsockCid::Host,
+ _ => VsockCid::Cid(c),
+ }
+ }
+}
+
+impl FromStr for VsockCid {
+ type Err = ParseIntError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let c: c_uint = s.parse()?;
+ Ok(c.into())
+ }
+}
+
+impl Into<c_uint> for VsockCid {
+ fn into(self) -> c_uint {
+ match self {
+ VsockCid::Any => VMADDR_CID_ANY,
+ VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
+ VsockCid::Local => VMADDR_CID_LOCAL,
+ VsockCid::Host => VMADDR_CID_HOST,
+ VsockCid::Cid(c) => c,
+ }
+ }
+}
+
+/// An address associated with a virtual socket.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub struct SocketAddr {
+ pub cid: VsockCid,
+ pub port: c_uint,
+}
+
+pub trait ToSocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
+}
+
+impl ToSocketAddr for SocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ Ok(*self)
+ }
+}
+
+impl ToSocketAddr for str {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ self.parse()
+ }
+}
+
+impl ToSocketAddr for (VsockCid, c_uint) {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ let (cid, port) = *self;
+ Ok(SocketAddr { cid, port })
+ }
+}
+
+impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ (**self).to_socket_addr()
+ }
+}
+
+impl FromStr for SocketAddr {
+ type Err = AddrParseError;
+
+ /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
+ /// "vsock:cid:port".
+ fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
+ let components: Vec<&str> = s.split(':').collect();
+ if components.len() != 3 || components[0] != "vsock" {
+ return Err(AddrParseError);
+ }
+
+ Ok(SocketAddr {
+ cid: components[1].parse().map_err(|_| AddrParseError)?,
+ port: components[2].parse().map_err(|_| AddrParseError)?,
+ })
+ }
+}
+
+impl fmt::Display for SocketAddr {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "{}:{}", self.cid, self.port)
+ }
+}
+
+/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
+/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
+unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
+ let flags = libc::fcntl(fd, F_GETFL, 0);
+ if flags < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ let flags = if nonblocking {
+ flags | O_NONBLOCK
+ } else {
+ flags & !O_NONBLOCK
+ };
+
+ let ret = libc::fcntl(fd, F_SETFL, flags);
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(())
+}
+
+/// A virtual socket.
+///
+/// Do not use this class unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
+/// VsockListener.
+#[derive(Debug)]
+pub struct VsockSocket {
+ fd: RawFd,
+}
+
+impl VsockSocket {
+ pub fn new() -> io::Result<Self> {
+ let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(VsockSocket { fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ // The compiler should optimize this out since these are both compile-time constants.
+ assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
+
+ let mut svm: sockaddr_vm = Default::default();
+ svm.svm_family = AF_VSOCK;
+ svm.svm_cid = sockaddr.cid.into();
+ svm.svm_port = sockaddr.port;
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::bind(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ let mut svm: sockaddr_vm = Default::default();
+ svm.svm_family = AF_VSOCK;
+ svm.svm_cid = sockaddr.cid.into();
+ svm.svm_port = sockaddr.port;
+
+ // Safe because this just connects a vsock socket, and the return value is checked.
+ let ret = unsafe {
+ libc::connect(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ Ok(VsockStream { sock: self })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<VsockListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ return Err(listen_err);
+ }
+ Ok(VsockListener { sock: self })
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u32> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
+
+ Ok(svm.svm_port)
+ }
+ }
+
+ pub fn try_clone(&self) -> io::Result<Self> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
+ if dup_fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(Self { fd: dup_fd })
+ }
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ // Safe because the fd is valid and owned by this stream.
+ unsafe { set_nonblocking(self.fd, nonblocking) }
+ }
+}
+
+impl IntoRawFd for VsockSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for VsockSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for VsockSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
+
+/// A virtual stream socket.
+#[derive(Debug)]
+pub struct VsockStream {
+ sock: VsockSocket,
+}
+
+impl VsockStream {
+ pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
+ let sock = VsockSocket::new()?;
+ sock.connect(addr)
+ }
+
+ /// Returns the port that this stream is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ pub fn try_clone(&self) -> io::Result<VsockStream> {
+ self.sock.try_clone().map(|f| VsockStream { sock: f })
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl io::Read for VsockStream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ // Safe because this will only modify the contents of |buf| and we check the return value.
+ let ret = unsafe {
+ libc::read(
+ self.sock.as_raw_fd(),
+ buf as *mut [u8] as *mut c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+}
+
+impl io::Write for VsockStream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::write(
+ self.sock.as_raw_fd(),
+ buf as *const [u8] as *const c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ // No buffered data so nothing to do.
+ Ok(())
+ }
+}
+
+impl AsRawFd for VsockStream {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+impl IntoRawFd for VsockStream {
+ fn into_raw_fd(self) -> RawFd {
+ self.sock.into_raw_fd()
+ }
+}
+
+/// Represents a virtual socket server.
+#[derive(Debug)]
+pub struct VsockListener {
+ sock: VsockSocket,
+}
+
+impl VsockListener {
+ /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
+ /// endpoint.
+ pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
+ let mut sock = VsockSocket::new()?;
+ sock.bind(addr)?;
+ sock.listen()
+ }
+
+ /// Returns the port that this listener is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
+ /// new connection is established. When established, returns the corresponding `VsockStream`
+ /// and the remote peer's address.
+ pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because this will only modify |svm| and we check the return value.
+ let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
+ let fd = unsafe {
+ libc::accept4(
+ self.sock.as_raw_fd(),
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut socklen as *mut socklen_t,
+ libc::SOCK_CLOEXEC,
+ )
+ };
+ if fd < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ if svm.svm_family != AF_VSOCK {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("unexpected address family: {}", svm.svm_family),
+ ));
+ }
+
+ Ok((
+ VsockStream {
+ sock: VsockSocket { fd },
+ },
+ SocketAddr {
+ cid: svm.svm_cid.into(),
+ port: svm.svm_port,
+ },
+ ))
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl AsRawFd for VsockListener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}