diff options
author | Bernie Innocenti <codewiz@google.com> | 2018-05-18 20:50:25 +0900 |
---|---|---|
committer | Bernie Innocenti <codewiz@google.com> | 2018-05-22 18:37:36 +0900 |
commit | d8e33ae8633bfd34e744f09f93b1cad09834d4eb (patch) | |
tree | eb5d9ef00ec415bdd637a53990d1914d435b78b9 | |
parent | f8f683fea2855ffb2d13ea93cfad1dada399c072 (diff) | |
download | netd-d8e33ae8633bfd34e744f09f93b1cad09834d4eb.tar.gz |
netd: Convert DnsTlsSocket from select() to poll()
Change-Id: Ib6ef5867c5b8190c49194233d890056afbd48b09
Test: system/netd/tests/runtests.sh
Bug: 79838856
-rw-r--r-- | server/dns/DnsTlsSocket.cpp | 54 |
1 files changed, 24 insertions, 30 deletions
diff --git a/server/dns/DnsTlsSocket.cpp b/server/dns/DnsTlsSocket.cpp index a6ea6c1d..ca1cdc9e 100644 --- a/server/dns/DnsTlsSocket.cpp +++ b/server/dns/DnsTlsSocket.cpp @@ -25,7 +25,7 @@ #include <errno.h> #include <linux/tcp.h> #include <openssl/err.h> -#include <sys/select.h> +#include <sys/poll.h> #include "dns/DnsTlsSessionCache.h" #include "dns/IDnsTlsSocketObserver.h" @@ -51,20 +51,14 @@ namespace { constexpr const char kCaCertDir[] = "/system/etc/security/cacerts"; int waitForReading(int fd) { - fd_set fds; - FD_ZERO(&fds); - FD_SET(fd, &fds); - const int ret = TEMP_FAILURE_RETRY(select(fd + 1, &fds, nullptr, nullptr, nullptr)); - ALOGV_IF(ret <= 0, "select failed during read"); + struct pollfd fds = { .fd = fd, .events = POLLIN }; + const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; } int waitForWriting(int fd) { - fd_set fds; - FD_ZERO(&fds); - FD_SET(fd, &fds); - const int ret = TEMP_FAILURE_RETRY(select(fd + 1, nullptr, &fds, nullptr, nullptr)); - ALOGV_IF(ret <= 0, "select failed during write"); + struct pollfd fds = { .fd = fd, .events = POLLOUT }; + const int ret = TEMP_FAILURE_RETRY(poll(&fds, 1, -1)); return ret; } @@ -239,7 +233,7 @@ bssl::UniquePtr<SSL> DnsTlsSocket::sslConnect(int fd) { switch (ssl_err) { case SSL_ERROR_WANT_READ: if (waitForReading(fd) != 1) { - ALOGW("SSL_connect read error"); + ALOGW("SSL_connect read error: %d", errno); return nullptr; } break; @@ -350,41 +344,41 @@ void DnsTlsSocket::loop() { // Buffer at most one query. Query q; - fd_set readFds, writeFds; - FD_ZERO(&readFds); - FD_ZERO(&writeFds); - const int maxFd = std::max(mSslFd.get(), mIpcOutFd.get()); + const int timeout_msecs = DnsTlsSocket::kIdleTimeout.count() * 1000; while (true) { - timeval timeout = { .tv_sec = DnsTlsSocket::kIdleTimeout.count() }; + // poll() ignores negative fds + struct pollfd fds[2] = { { .fd = -1 }, { .fd = -1 } }; + enum { SSLFD = 0, IPCFD = 1 }; + // Always listen for a response from server. - FD_SET(mSslFd.get(), &readFds); + fds[SSLFD].fd = mSslFd.get(); + fds[SSLFD].events = POLLIN; + // If we have a pending query, also wait for space // to write it, otherwise listen for a new query. if (!q.query.empty()) { - FD_SET(mSslFd.get(), &writeFds); - FD_CLR(mIpcOutFd.get(), &readFds); + fds[SSLFD].events |= POLLOUT; } else { - FD_CLR(mSslFd.get(), &writeFds); - FD_SET(mIpcOutFd.get(), &readFds); + fds[IPCFD].fd = mIpcOutFd.get(); + fds[IPCFD].events = POLLIN; } - // Deviating from POSIX, Linux will decrement the timeout on each retry. - // Either behavior is OK here. - const int s = TEMP_FAILURE_RETRY(select(maxFd + 1, &readFds, &writeFds, nullptr, &timeout)); + + const int s = TEMP_FAILURE_RETRY(poll(fds, ARRAY_SIZE(fds), timeout_msecs)); if (s == 0) { ALOGV("Idle timeout"); break; } if (s < 0) { - ALOGV("Select failed: %d", errno); + ALOGV("Poll failed: %d", errno); break; } - if (FD_ISSET(mSslFd.get(), &readFds)) { + if (fds[SSLFD].revents & (POLLIN | POLLERR)) { if (!readResponse()) { ALOGV("SSL remote close or read error."); break; } } - if (FD_ISSET(mIpcOutFd.get(), &readFds)) { + if (fds[IPCFD].revents & (POLLIN | POLLERR)) { int res = read(mIpcOutFd.get(), &q, sizeof(q)); if (res < 0) { ALOGW("Error during IPC read"); @@ -396,7 +390,7 @@ void DnsTlsSocket::loop() { ALOGE("Struct size mismatch: %d != %zu", res, sizeof(q)); break; } - } else if (FD_ISSET(mSslFd.get(), &writeFds)) { + } else if (fds[SSLFD].revents & POLLOUT) { // query cannot be null here. if (!sendQuery(q)) { break; @@ -457,7 +451,7 @@ int DnsTlsSocket::sslRead(const Slice buffer, bool wait) { const int ssl_err = SSL_get_error(mSsl.get(), ret); if (wait && ssl_err == SSL_ERROR_WANT_READ) { if (waitForReading(mSslFd.get()) != 1) { - ALOGV("Select failed in sslRead"); + ALOGV("Poll failed in sslRead: %d", errno); return SSL_ERROR_SYSCALL; } continue; |