aboutsummaryrefslogtreecommitdiff
path: root/src/net/tcp/stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/tcp/stream.rs')
-rw-r--r--src/net/tcp/stream.rs156
1 files changed, 139 insertions, 17 deletions
diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs
index cdbd46a..a7a9aa1 100644
--- a/src/net/tcp/stream.rs
+++ b/src/net/tcp/stream.rs
@@ -3,11 +3,14 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write};
use std::net::{self, Shutdown, SocketAddr};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+#[cfg(target_os = "wasi")]
+use std::os::wasi::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket};
use crate::io_source::IoSource;
-use crate::net::TcpSocket;
+#[cfg(not(target_os = "wasi"))]
+use crate::sys::tcp::{connect, new_for_addr};
use crate::{event, Interest, Registry, Token};
/// A non-blocking TCP stream between a local socket and a remote socket.
@@ -49,9 +52,43 @@ pub struct TcpStream {
impl TcpStream {
/// Create a new TCP stream and issue a non-blocking connect to the
/// specified address.
+ ///
+ /// # Notes
+ ///
+ /// The returned `TcpStream` may not be connected (and thus usable), unlike
+ /// the API found in `std::net::TcpStream`. Because Mio issues a
+ /// *non-blocking* connect it will not block the thread and instead return
+ /// an unconnected `TcpStream`.
+ ///
+ /// Ensuring the returned stream is connected is surprisingly complex when
+ /// considering cross-platform support. Doing this properly should follow
+ /// the steps below, an example implementation can be found
+ /// [here](https://github.com/Thomasdezeeuw/heph/blob/0c4f1ab3eaf08bea1d65776528bfd6114c9f8374/src/net/tcp/stream.rs#L560-L622).
+ ///
+ /// 1. Call `TcpStream::connect`
+ /// 2. Register the returned stream with at least [write interest].
+ /// 3. Wait for a (writable) event.
+ /// 4. Check `TcpStream::peer_addr`. If it returns `libc::EINPROGRESS` or
+ /// `ErrorKind::NotConnected` it means the stream is not yet connected,
+ /// go back to step 3. If it returns an address it means the stream is
+ /// connected, go to step 5. If another error is returned something
+ /// whent wrong.
+ /// 5. Now the stream can be used.
+ ///
+ /// This may return a `WouldBlock` in which case the socket connection
+ /// cannot be completed immediately, it usually means there are insufficient
+ /// entries in the routing cache.
+ ///
+ /// [write interest]: Interest::WRITABLE
+ #[cfg(not(target_os = "wasi"))]
pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
- let socket = TcpSocket::new_for_addr(addr)?;
- socket.connect(addr)
+ let socket = new_for_addr(addr)?;
+ #[cfg(unix)]
+ let stream = unsafe { TcpStream::from_raw_fd(socket) };
+ #[cfg(windows)]
+ let stream = unsafe { TcpStream::from_raw_socket(socket as _) };
+ connect(&stream.inner, addr)?;
+ Ok(stream)
}
/// Creates a new `TcpStream` from a standard `net::TcpStream`.
@@ -103,7 +140,7 @@ impl TcpStream {
///
/// On Windows make sure the stream is connected before calling this method,
/// by receiving an (writable) event. Trying to set `nodelay` on an
- /// unconnected `TcpStream` is undefined behavior.
+ /// unconnected `TcpStream` is unspecified behavior.
pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
self.inner.set_nodelay(nodelay)
}
@@ -118,7 +155,7 @@ impl TcpStream {
///
/// On Windows make sure the stream is connected before calling this method,
/// by receiving an (writable) event. Trying to get `nodelay` on an
- /// unconnected `TcpStream` is undefined behavior.
+ /// unconnected `TcpStream` is unspecified behavior.
pub fn nodelay(&self) -> io::Result<bool> {
self.inner.nodelay()
}
@@ -132,7 +169,7 @@ impl TcpStream {
///
/// On Windows make sure the stream is connected before calling this method,
/// by receiving an (writable) event. Trying to set `ttl` on an
- /// unconnected `TcpStream` is undefined behavior.
+ /// unconnected `TcpStream` is unspecified behavior.
pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
self.inner.set_ttl(ttl)
}
@@ -145,7 +182,7 @@ impl TcpStream {
///
/// On Windows make sure the stream is connected before calling this method,
/// by receiving an (writable) event. Trying to get `ttl` on an
- /// unconnected `TcpStream` is undefined behavior.
+ /// unconnected `TcpStream` is unspecified behavior.
///
/// [link]: #method.set_ttl
pub fn ttl(&self) -> io::Result<u32> {
@@ -170,53 +207,111 @@ impl TcpStream {
pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
self.inner.peek(buf)
}
+
+ /// Execute an I/O operation ensuring that the socket receives more events
+ /// if it hits a [`WouldBlock`] error.
+ ///
+ /// # Notes
+ ///
+ /// This method is required to be called for **all** I/O operations to
+ /// ensure the user will receive events once the socket is ready again after
+ /// returning a [`WouldBlock`] error.
+ ///
+ /// [`WouldBlock`]: io::ErrorKind::WouldBlock
+ ///
+ /// # Examples
+ ///
+ #[cfg_attr(unix, doc = "```no_run")]
+ #[cfg_attr(windows, doc = "```ignore")]
+ /// # use std::error::Error;
+ /// #
+ /// # fn main() -> Result<(), Box<dyn Error>> {
+ /// use std::io;
+ /// #[cfg(unix)]
+ /// use std::os::unix::io::AsRawFd;
+ /// #[cfg(windows)]
+ /// use std::os::windows::io::AsRawSocket;
+ /// use mio::net::TcpStream;
+ ///
+ /// let address = "127.0.0.1:8080".parse().unwrap();
+ /// let stream = TcpStream::connect(address)?;
+ ///
+ /// // Wait until the stream is readable...
+ ///
+ /// // Read from the stream using a direct libc call, of course the
+ /// // `io::Read` implementation would be easier to use.
+ /// let mut buf = [0; 512];
+ /// let n = stream.try_io(|| {
+ /// let buf_ptr = &mut buf as *mut _ as *mut _;
+ /// #[cfg(unix)]
+ /// let res = unsafe { libc::recv(stream.as_raw_fd(), buf_ptr, buf.len(), 0) };
+ /// #[cfg(windows)]
+ /// let res = unsafe { libc::recvfrom(stream.as_raw_socket() as usize, buf_ptr, buf.len() as i32, 0, std::ptr::null_mut(), std::ptr::null_mut()) };
+ /// if res != -1 {
+ /// Ok(res as usize)
+ /// } else {
+ /// // If EAGAIN or EWOULDBLOCK is set by libc::recv, the closure
+ /// // should return `WouldBlock` error.
+ /// Err(io::Error::last_os_error())
+ /// }
+ /// })?;
+ /// eprintln!("read {} bytes", n);
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn try_io<F, T>(&self, f: F) -> io::Result<T>
+ where
+ F: FnOnce() -> io::Result<T>,
+ {
+ self.inner.do_io(|_| f())
+ }
}
impl Read for TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).read(buf))
+ self.inner.do_io(|mut inner| inner.read(buf))
}
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
+ self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}
impl<'a> Read for &'a TcpStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).read(buf))
+ self.inner.do_io(|mut inner| inner.read(buf))
}
fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
+ self.inner.do_io(|mut inner| inner.read_vectored(bufs))
}
}
impl Write for TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).write(buf))
+ self.inner.do_io(|mut inner| inner.write(buf))
}
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
+ self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}
fn flush(&mut self) -> io::Result<()> {
- self.inner.do_io(|inner| (&*inner).flush())
+ self.inner.do_io(|mut inner| inner.flush())
}
}
impl<'a> Write for &'a TcpStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).write(buf))
+ self.inner.do_io(|mut inner| inner.write(buf))
}
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
- self.inner.do_io(|inner| (&*inner).write_vectored(bufs))
+ self.inner.do_io(|mut inner| inner.write_vectored(bufs))
}
fn flush(&mut self) -> io::Result<()> {
- self.inner.do_io(|inner| (&*inner).flush())
+ self.inner.do_io(|mut inner| inner.flush())
}
}
@@ -303,3 +398,30 @@ impl FromRawSocket for TcpStream {
TcpStream::from_std(FromRawSocket::from_raw_socket(socket))
}
}
+
+#[cfg(target_os = "wasi")]
+impl IntoRawFd for TcpStream {
+ fn into_raw_fd(self) -> RawFd {
+ self.inner.into_inner().into_raw_fd()
+ }
+}
+
+#[cfg(target_os = "wasi")]
+impl AsRawFd for TcpStream {
+ fn as_raw_fd(&self) -> RawFd {
+ self.inner.as_raw_fd()
+ }
+}
+
+#[cfg(target_os = "wasi")]
+impl FromRawFd for TcpStream {
+ /// Converts a `RawFd` to a `TcpStream`.
+ ///
+ /// # Notes
+ ///
+ /// The caller is responsible for ensuring that the socket is in
+ /// non-blocking mode.
+ unsafe fn from_raw_fd(fd: RawFd) -> TcpStream {
+ TcpStream::from_std(FromRawFd::from_raw_fd(fd))
+ }
+}