aboutsummaryrefslogtreecommitdiff
path: root/src/net.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/net.rs')
-rw-r--r--src/net.rs269
1 files changed, 269 insertions, 0 deletions
diff --git a/src/net.rs b/src/net.rs
new file mode 100644
index 0000000..be1cd25
--- /dev/null
+++ b/src/net.rs
@@ -0,0 +1,269 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/// Structs to supplement std::net.
+use std::io;
+use std::mem::{self, size_of};
+use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+
+use libc::{
+ c_int, in6_addr, in_addr, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET,
+ AF_INET6, SOCK_CLOEXEC, SOCK_STREAM,
+};
+
+/// Assist in handling both IP version 4 and IP version 6.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum InetVersion {
+ V4,
+ V6,
+}
+
+impl InetVersion {
+ pub fn from_sockaddr(s: &SocketAddr) -> Self {
+ match s {
+ SocketAddr::V4(_) => InetVersion::V4,
+ SocketAddr::V6(_) => InetVersion::V6,
+ }
+ }
+}
+
+impl Into<sa_family_t> for InetVersion {
+ fn into(self) -> sa_family_t {
+ match self {
+ InetVersion::V4 => AF_INET as sa_family_t,
+ InetVersion::V6 => AF_INET6 as sa_family_t,
+ }
+ }
+}
+
+fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
+ sockaddr_in {
+ sin_family: AF_INET as sa_family_t,
+ sin_port: s.port().to_be(),
+ sin_addr: in_addr {
+ s_addr: u32::from_ne_bytes(s.ip().octets()),
+ },
+ sin_zero: [0; 8],
+ }
+}
+
+fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
+ sockaddr_in6 {
+ sin6_family: AF_INET6 as sa_family_t,
+ sin6_port: s.port().to_be(),
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr {
+ s6_addr: s.ip().octets(),
+ },
+ sin6_scope_id: 0,
+ }
+}
+
+/// A TCP socket.
+///
+/// Do not use this class unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
+/// TcpListener.
+#[derive(Debug)]
+pub struct TcpSocket {
+ inet_version: InetVersion,
+ fd: RawFd,
+}
+
+impl TcpSocket {
+ pub fn new(inet_version: InetVersion) -> io::Result<Self> {
+ let fd = unsafe {
+ libc::socket(
+ Into::<sa_family_t>::into(inet_version) as c_int,
+ SOCK_STREAM | SOCK_CLOEXEC,
+ 0,
+ )
+ };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(TcpSocket { inet_version, fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+ size_of::<sockaddr_in6>() as socklen_t,
+ )
+ }
+ }
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ };
+
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpStream.
+ Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<TcpListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ Err(listen_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpListener.
+ Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u16> {
+ match self.inet_version {
+ InetVersion::V4 => {
+ let mut sin = sockaddr_in {
+ sin_family: 0,
+ sin_port: 0,
+ sin_addr: in_addr { s_addr: 0 },
+ sin_zero: [0; 8],
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin as *mut sockaddr_in as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
+
+ Ok(sin.sin_port)
+ }
+ }
+ InetVersion::V6 => {
+ let mut sin6 = sockaddr_in6 {
+ sin6_family: 0,
+ sin6_port: 0,
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr { s6_addr: [0; 16] },
+ sin6_scope_id: 0,
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
+
+ Ok(sin6.sin6_port)
+ }
+ }
+ }
+ }
+}
+
+impl IntoRawFd for TcpSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for TcpSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for TcpSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}