diff options
Diffstat (limited to 'pw_stream/socket_stream.cc')
-rw-r--r-- | pw_stream/socket_stream.cc | 277 |
1 files changed, 247 insertions, 30 deletions
diff --git a/pw_stream/socket_stream.cc b/pw_stream/socket_stream.cc index b3125439c..d7cec21f0 100644 --- a/pw_stream/socket_stream.cc +++ b/pw_stream/socket_stream.cc @@ -15,6 +15,8 @@ #include "pw_stream/socket_stream.h" #if defined(_WIN32) && _WIN32 +#include <fcntl.h> +#include <io.h> #include <winsock2.h> #include <ws2tcpip.h> #define SHUT_RDWR SD_BOTH @@ -22,6 +24,7 @@ #include <arpa/inet.h> #include <netdb.h> #include <netinet/in.h> +#include <poll.h> #include <sys/socket.h> #include <sys/types.h> #include <unistd.h> @@ -56,6 +59,25 @@ void ConfigureSocket([[maybe_unused]] int socket) { #if defined(_WIN32) && _WIN32 int close(SOCKET s) { return closesocket(s); } +ssize_t write(int fd, const void* buf, size_t count) { + return _write(fd, buf, count); +} + +int poll(struct pollfd* fds, unsigned int nfds, int timeout) { + return WSAPoll(fds, nfds, timeout); +} + +int pipe(int pipefd[2]) { return _pipe(pipefd, 256, O_BINARY); } + +int setsockopt( + int fd, int level, int optname, const void* optval, unsigned int optlen) { + return setsockopt(static_cast<SOCKET>(fd), + level, + optname, + static_cast<const char*>(optval), + static_cast<int>(optlen)); +} + class WinsockInitializer { public: WinsockInitializer() { @@ -93,37 +115,72 @@ Status SocketStream::SocketStream::Connect(const char* host, uint16_t port) { } struct addrinfo* rp; + int connection_fd; for (rp = res; rp != nullptr; rp = rp->ai_next) { - connection_fd_ = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); - if (connection_fd_ != kInvalidFd) { + connection_fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + if (connection_fd != kInvalidFd) { break; } } - if (connection_fd_ == kInvalidFd) { + if (connection_fd == kInvalidFd) { PW_LOG_ERROR("Failed to create a socket: %s", std::strerror(errno)); freeaddrinfo(res); return Status::Unknown(); } - ConfigureSocket(connection_fd_); - if (connect(connection_fd_, rp->ai_addr, rp->ai_addrlen) == -1) { - close(connection_fd_); - connection_fd_ = kInvalidFd; + ConfigureSocket(connection_fd); + if (connect(connection_fd, rp->ai_addr, rp->ai_addrlen) == -1) { + close(connection_fd); PW_LOG_ERROR( "Failed to connect to %s:%d: %s", host, port, std::strerror(errno)); freeaddrinfo(res); return Status::Unknown(); } + // Mark as ready and take ownership of the connection by this object. + { + std::lock_guard lock(connection_mutex_); + connection_fd_ = connection_fd; + TakeConnectionWithLockHeld(); + ready_ = true; + } + freeaddrinfo(res); return OkStatus(); } +// Configures socket options. +int SocketStream::SetSockOpt(int level, + int optname, + const void* optval, + unsigned int optlen) { + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return EBADF; + } + return setsockopt(ownership.fd(), level, optname, optval, optlen); +} + void SocketStream::Close() { - if (connection_fd_ != kInvalidFd) { - close(connection_fd_); - connection_fd_ = kInvalidFd; + ConnectionOwnership ownership(this); + { + std::lock_guard lock(connection_mutex_); + if (ready_) { + // Shutdown the connection and send tear down notification to unblock any + // waiters. + if (connection_fd_ != kInvalidFd) { + shutdown(connection_fd_, SHUT_RDWR); + } + if (connection_pipe_w_fd_ != kInvalidFd) { + write(connection_pipe_w_fd_, "T", 1); + } + + // Release ownership of the connection by this object and mark as no + // longer ready. + ReleaseConnectionWithLockHeld(); + ready_ = false; + } } } @@ -135,10 +192,17 @@ Status SocketStream::DoWrite(span<const std::byte> data) { send_flags |= MSG_NOSIGNAL; #endif // defined(__linux__) - ssize_t bytes_sent = send(connection_fd_, - reinterpret_cast<const char*>(data.data()), - data.size_bytes(), - send_flags); + ssize_t bytes_sent; + { + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return Status::Unknown(); + } + bytes_sent = send(ownership.fd(), + reinterpret_cast<const char*>(data.data()), + data.size_bytes(), + send_flags); + } if (bytes_sent < 0 || static_cast<size_t>(bytes_sent) != data.size()) { if (errno == EPIPE) { @@ -153,7 +217,23 @@ Status SocketStream::DoWrite(span<const std::byte> data) { } StatusWithSize SocketStream::DoRead(ByteSpan dest) { - ssize_t bytes_rcvd = recv(connection_fd_, + ConnectionOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return StatusWithSize::Unknown(); + } + + // Wait for data to read or a tear down notification. + pollfd fds_to_poll[2]; + fds_to_poll[0].fd = ownership.fd(); + fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP; + fds_to_poll[1].fd = ownership.pipe_r_fd(); + fds_to_poll[1].events = POLLIN; + poll(fds_to_poll, 2, -1); + if (!(fds_to_poll[0].revents & POLLIN)) { + return StatusWithSize::Unknown(); + } + + ssize_t bytes_rcvd = recv(ownership.fd(), reinterpret_cast<char*>(dest.data()), dest.size_bytes(), 0); @@ -174,18 +254,67 @@ StatusWithSize SocketStream::DoRead(ByteSpan dest) { return StatusWithSize(bytes_rcvd); } +int SocketStream::TakeConnection() { + std::lock_guard lock(connection_mutex_); + return TakeConnectionWithLockHeld(); +} + +int SocketStream::TakeConnectionWithLockHeld() { + ++connection_own_count_; + + if (ready_ && (connection_fd_ != kInvalidFd) && + (connection_pipe_r_fd_ == kInvalidFd)) { + int fd_list[2]; + if (pipe(fd_list) >= 0) { + connection_pipe_r_fd_ = fd_list[0]; + connection_pipe_w_fd_ = fd_list[1]; + } + } + + if (!ready_ || (connection_pipe_r_fd_ == kInvalidFd) || + (connection_pipe_w_fd_ == kInvalidFd)) { + return kInvalidFd; + } + return connection_fd_; +} + +void SocketStream::ReleaseConnection() { + std::lock_guard lock(connection_mutex_); + ReleaseConnectionWithLockHeld(); +} + +void SocketStream::ReleaseConnectionWithLockHeld() { + --connection_own_count_; + + if (connection_own_count_ <= 0) { + ready_ = false; + if (connection_fd_ != kInvalidFd) { + close(connection_fd_); + connection_fd_ = kInvalidFd; + } + if (connection_pipe_r_fd_ != kInvalidFd) { + close(connection_pipe_r_fd_); + connection_pipe_r_fd_ = kInvalidFd; + } + if (connection_pipe_w_fd_ != kInvalidFd) { + close(connection_pipe_w_fd_); + connection_pipe_w_fd_ = kInvalidFd; + } + } +} + // Listen for connections on the given port. // If port is 0, a random unused port is chosen and can be retrieved with // port(). Status ServerSocket::Listen(uint16_t port) { - socket_fd_ = socket(AF_INET6, SOCK_STREAM, 0); - if (socket_fd_ == kInvalidFd) { + int socket_fd = socket(AF_INET6, SOCK_STREAM, 0); + if (socket_fd == kInvalidFd) { return Status::Unknown(); } // Allow binding to an address that may still be in use by a closed socket. constexpr int value = 1; - setsockopt(socket_fd_, + setsockopt(socket_fd, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char*>(&value), @@ -197,27 +326,37 @@ Status ServerSocket::Listen(uint16_t port) { addr.sin6_family = AF_INET6; addr.sin6_port = htons(port); addr.sin6_addr = in6addr_any; - if (bind(socket_fd_, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) { + if (bind(socket_fd, reinterpret_cast<sockaddr*>(&addr), addr_len) < 0) { + close(socket_fd); return Status::Unknown(); } } - if (listen(socket_fd_, kServerBacklogLength) < 0) { + if (listen(socket_fd, kServerBacklogLength) < 0) { + close(socket_fd); return Status::Unknown(); } // Find out which port the socket is listening on, and fill in port_. struct sockaddr_in6 addr = {}; socklen_t addr_len = sizeof(addr); - if (getsockname(socket_fd_, reinterpret_cast<sockaddr*>(&addr), &addr_len) < + if (getsockname(socket_fd, reinterpret_cast<sockaddr*>(&addr), &addr_len) < 0 || static_cast<size_t>(addr_len) > sizeof(addr)) { - close(socket_fd_); + close(socket_fd); return Status::Unknown(); } port_ = ntohs(addr.sin6_port); + // Mark as ready and take ownership of the socket by this object. + { + std::lock_guard lock(socket_mutex_); + socket_fd_ = socket_fd; + TakeSocketWithLockHeld(); + ready_ = true; + } + return OkStatus(); } @@ -227,23 +366,101 @@ Result<SocketStream> ServerSocket::Accept() { struct sockaddr_in6 sockaddr_client_ = {}; socklen_t len = sizeof(sockaddr_client_); - int connection_fd = - accept(socket_fd_, reinterpret_cast<sockaddr*>(&sockaddr_client_), &len); + SocketOwnership ownership(this); + if (ownership.fd() == kInvalidFd) { + return Status::Unknown(); + } + + // Wait for a connection or a tear down notification. + pollfd fds_to_poll[2]; + fds_to_poll[0].fd = ownership.fd(); + fds_to_poll[0].events = POLLIN | POLLERR | POLLHUP; + fds_to_poll[1].fd = ownership.pipe_r_fd(); + fds_to_poll[1].events = POLLIN; + int rv = poll(fds_to_poll, 2, -1); + if ((rv <= 0) || !(fds_to_poll[0].revents & POLLIN)) { + return Status::Unknown(); + } + + int connection_fd = accept( + ownership.fd(), reinterpret_cast<sockaddr*>(&sockaddr_client_), &len); if (connection_fd == kInvalidFd) { return Status::Unknown(); } ConfigureSocket(connection_fd); - SocketStream client_stream; - client_stream.connection_fd_ = connection_fd; - return client_stream; + return SocketStream(connection_fd); } // Close the server socket, preventing further connections. void ServerSocket::Close() { - if (socket_fd_ != kInvalidFd) { - close(socket_fd_); - socket_fd_ = kInvalidFd; + SocketOwnership ownership(this); + { + std::lock_guard lock(socket_mutex_); + if (ready_) { + // Shutdown the socket and send tear down notification to unblock any + // waiters. + if (socket_fd_ != kInvalidFd) { + shutdown(socket_fd_, SHUT_RDWR); + } + if (socket_pipe_w_fd_ != kInvalidFd) { + write(socket_pipe_w_fd_, "T", 1); + } + + // Release ownership of the socket by this object and mark as no longer + // ready. + ReleaseSocketWithLockHeld(); + ready_ = false; + } + } +} + +int ServerSocket::TakeSocket() { + std::lock_guard lock(socket_mutex_); + return TakeSocketWithLockHeld(); +} + +int ServerSocket::TakeSocketWithLockHeld() { + ++socket_own_count_; + + if (ready_ && (socket_fd_ != kInvalidFd) && + (socket_pipe_r_fd_ == kInvalidFd)) { + int fd_list[2]; + if (pipe(fd_list) >= 0) { + socket_pipe_r_fd_ = fd_list[0]; + socket_pipe_w_fd_ = fd_list[1]; + } + } + + if (!ready_ || (socket_pipe_r_fd_ == kInvalidFd) || + (socket_pipe_w_fd_ == kInvalidFd)) { + return kInvalidFd; + } + return socket_fd_; +} + +void ServerSocket::ReleaseSocket() { + std::lock_guard lock(socket_mutex_); + ReleaseSocketWithLockHeld(); +} + +void ServerSocket::ReleaseSocketWithLockHeld() { + --socket_own_count_; + + if (socket_own_count_ <= 0) { + ready_ = false; + if (socket_fd_ != kInvalidFd) { + close(socket_fd_); + socket_fd_ = kInvalidFd; + } + if (socket_pipe_r_fd_ != kInvalidFd) { + close(socket_pipe_r_fd_); + socket_pipe_r_fd_ = kInvalidFd; + } + if (socket_pipe_w_fd_ != kInvalidFd) { + close(socket_pipe_w_fd_); + socket_pipe_w_fd_ = kInvalidFd; + } } } |