diff options
Diffstat (limited to 'src/net/tcp/stream.rs')
-rw-r--r-- | src/net/tcp/stream.rs | 156 |
1 files changed, 139 insertions, 17 deletions
diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index cdbd46a..a7a9aa1 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -3,11 +3,14 @@ use std::io::{self, IoSlice, IoSliceMut, Read, Write}; use std::net::{self, Shutdown, SocketAddr}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; +#[cfg(target_os = "wasi")] +use std::os::wasi::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawSocket, FromRawSocket, IntoRawSocket, RawSocket}; use crate::io_source::IoSource; -use crate::net::TcpSocket; +#[cfg(not(target_os = "wasi"))] +use crate::sys::tcp::{connect, new_for_addr}; use crate::{event, Interest, Registry, Token}; /// A non-blocking TCP stream between a local socket and a remote socket. @@ -49,9 +52,43 @@ pub struct TcpStream { impl TcpStream { /// Create a new TCP stream and issue a non-blocking connect to the /// specified address. + /// + /// # Notes + /// + /// The returned `TcpStream` may not be connected (and thus usable), unlike + /// the API found in `std::net::TcpStream`. Because Mio issues a + /// *non-blocking* connect it will not block the thread and instead return + /// an unconnected `TcpStream`. + /// + /// Ensuring the returned stream is connected is surprisingly complex when + /// considering cross-platform support. Doing this properly should follow + /// the steps below, an example implementation can be found + /// [here](https://github.com/Thomasdezeeuw/heph/blob/0c4f1ab3eaf08bea1d65776528bfd6114c9f8374/src/net/tcp/stream.rs#L560-L622). + /// + /// 1. Call `TcpStream::connect` + /// 2. Register the returned stream with at least [write interest]. + /// 3. Wait for a (writable) event. + /// 4. Check `TcpStream::peer_addr`. If it returns `libc::EINPROGRESS` or + /// `ErrorKind::NotConnected` it means the stream is not yet connected, + /// go back to step 3. If it returns an address it means the stream is + /// connected, go to step 5. If another error is returned something + /// whent wrong. + /// 5. Now the stream can be used. + /// + /// This may return a `WouldBlock` in which case the socket connection + /// cannot be completed immediately, it usually means there are insufficient + /// entries in the routing cache. + /// + /// [write interest]: Interest::WRITABLE + #[cfg(not(target_os = "wasi"))] pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> { - let socket = TcpSocket::new_for_addr(addr)?; - socket.connect(addr) + let socket = new_for_addr(addr)?; + #[cfg(unix)] + let stream = unsafe { TcpStream::from_raw_fd(socket) }; + #[cfg(windows)] + let stream = unsafe { TcpStream::from_raw_socket(socket as _) }; + connect(&stream.inner, addr)?; + Ok(stream) } /// Creates a new `TcpStream` from a standard `net::TcpStream`. @@ -103,7 +140,7 @@ impl TcpStream { /// /// On Windows make sure the stream is connected before calling this method, /// by receiving an (writable) event. Trying to set `nodelay` on an - /// unconnected `TcpStream` is undefined behavior. + /// unconnected `TcpStream` is unspecified behavior. pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> { self.inner.set_nodelay(nodelay) } @@ -118,7 +155,7 @@ impl TcpStream { /// /// On Windows make sure the stream is connected before calling this method, /// by receiving an (writable) event. Trying to get `nodelay` on an - /// unconnected `TcpStream` is undefined behavior. + /// unconnected `TcpStream` is unspecified behavior. pub fn nodelay(&self) -> io::Result<bool> { self.inner.nodelay() } @@ -132,7 +169,7 @@ impl TcpStream { /// /// On Windows make sure the stream is connected before calling this method, /// by receiving an (writable) event. Trying to set `ttl` on an - /// unconnected `TcpStream` is undefined behavior. + /// unconnected `TcpStream` is unspecified behavior. pub fn set_ttl(&self, ttl: u32) -> io::Result<()> { self.inner.set_ttl(ttl) } @@ -145,7 +182,7 @@ impl TcpStream { /// /// On Windows make sure the stream is connected before calling this method, /// by receiving an (writable) event. Trying to get `ttl` on an - /// unconnected `TcpStream` is undefined behavior. + /// unconnected `TcpStream` is unspecified behavior. /// /// [link]: #method.set_ttl pub fn ttl(&self) -> io::Result<u32> { @@ -170,53 +207,111 @@ impl TcpStream { pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> { self.inner.peek(buf) } + + /// Execute an I/O operation ensuring that the socket receives more events + /// if it hits a [`WouldBlock`] error. + /// + /// # Notes + /// + /// This method is required to be called for **all** I/O operations to + /// ensure the user will receive events once the socket is ready again after + /// returning a [`WouldBlock`] error. + /// + /// [`WouldBlock`]: io::ErrorKind::WouldBlock + /// + /// # Examples + /// + #[cfg_attr(unix, doc = "```no_run")] + #[cfg_attr(windows, doc = "```ignore")] + /// # use std::error::Error; + /// # + /// # fn main() -> Result<(), Box<dyn Error>> { + /// use std::io; + /// #[cfg(unix)] + /// use std::os::unix::io::AsRawFd; + /// #[cfg(windows)] + /// use std::os::windows::io::AsRawSocket; + /// use mio::net::TcpStream; + /// + /// let address = "127.0.0.1:8080".parse().unwrap(); + /// let stream = TcpStream::connect(address)?; + /// + /// // Wait until the stream is readable... + /// + /// // Read from the stream using a direct libc call, of course the + /// // `io::Read` implementation would be easier to use. + /// let mut buf = [0; 512]; + /// let n = stream.try_io(|| { + /// let buf_ptr = &mut buf as *mut _ as *mut _; + /// #[cfg(unix)] + /// let res = unsafe { libc::recv(stream.as_raw_fd(), buf_ptr, buf.len(), 0) }; + /// #[cfg(windows)] + /// let res = unsafe { libc::recvfrom(stream.as_raw_socket() as usize, buf_ptr, buf.len() as i32, 0, std::ptr::null_mut(), std::ptr::null_mut()) }; + /// if res != -1 { + /// Ok(res as usize) + /// } else { + /// // If EAGAIN or EWOULDBLOCK is set by libc::recv, the closure + /// // should return `WouldBlock` error. + /// Err(io::Error::last_os_error()) + /// } + /// })?; + /// eprintln!("read {} bytes", n); + /// # Ok(()) + /// # } + /// ``` + pub fn try_io<F, T>(&self, f: F) -> io::Result<T> + where + F: FnOnce() -> io::Result<T>, + { + self.inner.do_io(|_| f()) + } } impl Read for TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl<'a> Read for &'a TcpStream { fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).read(buf)) + self.inner.do_io(|mut inner| inner.read(buf)) } fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).read_vectored(bufs)) + self.inner.do_io(|mut inner| inner.read_vectored(bufs)) } } impl Write for TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } impl<'a> Write for &'a TcpStream { fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).write(buf)) + self.inner.do_io(|mut inner| inner.write(buf)) } fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> { - self.inner.do_io(|inner| (&*inner).write_vectored(bufs)) + self.inner.do_io(|mut inner| inner.write_vectored(bufs)) } fn flush(&mut self) -> io::Result<()> { - self.inner.do_io(|inner| (&*inner).flush()) + self.inner.do_io(|mut inner| inner.flush()) } } @@ -303,3 +398,30 @@ impl FromRawSocket for TcpStream { TcpStream::from_std(FromRawSocket::from_raw_socket(socket)) } } + +#[cfg(target_os = "wasi")] +impl IntoRawFd for TcpStream { + fn into_raw_fd(self) -> RawFd { + self.inner.into_inner().into_raw_fd() + } +} + +#[cfg(target_os = "wasi")] +impl AsRawFd for TcpStream { + fn as_raw_fd(&self) -> RawFd { + self.inner.as_raw_fd() + } +} + +#[cfg(target_os = "wasi")] +impl FromRawFd for TcpStream { + /// Converts a `RawFd` to a `TcpStream`. + /// + /// # Notes + /// + /// The caller is responsible for ensuring that the socket is in + /// non-blocking mode. + unsafe fn from_raw_fd(fd: RawFd) -> TcpStream { + TcpStream::from_std(FromRawFd::from_raw_fd(fd)) + } +} |