Diffstat (limited to 'host/frontend/gcastv2/https/RunLoop.cpp')
-rw-r--r--  host/frontend/gcastv2/https/RunLoop.cpp  448
1 file changed, 448 insertions, 0 deletions
diff --git a/host/frontend/gcastv2/https/RunLoop.cpp b/host/frontend/gcastv2/https/RunLoop.cpp
new file mode 100644
index 000000000..f9d50c659
--- /dev/null
+++ b/host/frontend/gcastv2/https/RunLoop.cpp
@@ -0,0 +1,448 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <https/RunLoop.h>
+
+#include <https/Support.h>
+
+#include <android-base/logging.h>
+
+#include <cerrno>
+#include <cstring>
+#include <fcntl.h>
+#include <iostream>
+#include <poll.h>
+#include <unistd.h>
+
+#include <condition_variable>
+#include <map>
+#include <mutex>
+#include <vector>
+
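+// Ordering used by insert(): elements without a trigger time sort ahead of
+// timed ones, timed elements are ordered by trigger time, and elements that
+// compare equal keep their insertion order.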
+bool RunLoop::QueueElem::operator<=(const QueueElem &other) const {
+ if (mWhen) {
+ if (other.mWhen) {
+ return mWhen <= other.mWhen;
+ }
+
+ return false;
+ }
+
+ if (other.mWhen) {
+ return true;
+ }
+
+ // This ensures that two events posted without a trigger time are queued in
+ // the order they were post()ed in.
+ return true;
+}
+
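+// The control pipe wakes up poll() in run(): interrupt() writes a byte to
+// mControlFds[1] and run() drains the non-blocking read end, mControlFds[0].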
+RunLoop::RunLoop()
+ : mDone(false),
+ mPThread(0),
+ mNextToken(1) {
+ int res = pipe(mControlFds);
+ CHECK_GE(res, 0);
+
+ makeFdNonblocking(mControlFds[0]);
+}
+
+RunLoop::RunLoop(std::string_view name)
+ : RunLoop() {
+ mName = name;
+
+ mThread = std::thread([this]{ run(); });
+}
+
+RunLoop::~RunLoop() {
+ stop();
+
+ close(mControlFds[1]);
+ mControlFds[1] = -1;
+
+ close(mControlFds[0]);
+ mControlFds[0] = -1;
+}
+
+void RunLoop::stop() {
+ mDone = true;
+ interrupt();
+
+ if (mThread.joinable()) {
+ mThread.join();
+ }
+}
+
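+// Schedules fn to run on the loop as soon as possible. Returns a token that
+// can be passed to cancelToken() while the call is still pending.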
+RunLoop::Token RunLoop::post(AsyncFunction fn) {
+ CHECK(fn != nullptr);
+
+ auto token = mNextToken++;
+ insert({ std::nullopt, fn, token });
+
+ return token;
+}
+
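+// Posts fn and, unless called from the loop's own thread, blocks until it has
+// run. Returns true if the call actually waited for fn to complete.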
+bool RunLoop::postAndAwait(AsyncFunction fn) {
+ if (isCurrentThread()) {
+ // Waiting from the run loop's own thread would deadlock, so just
+ // post the function and return without waiting.
+ post(fn);
+ return false;
+ }
+
+ std::mutex mtx;
+ bool ran = false;
+ std::condition_variable cond_var;
+
+ post([&cond_var, &mtx, &ran, fn](){
+ fn();
+ {
+ std::unique_lock<std::mutex> lock(mtx);
+ ran = true;
+ // Notify while holding the mutex, otherwise the condition variable
+ // could be destroyed before the call to notify_all.
+ cond_var.notify_all();
+ }
+ });
+
+ {
+ std::unique_lock<std::mutex> lock(mtx);
+ cond_var.wait(lock, [&ran](){ return ran;});
+ }
+ return ran;
+}
+
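+// Schedules fn to run no earlier than the given delay from now. Returns a
+// cancellation token, just like post().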
+RunLoop::Token RunLoop::postWithDelay(
+ std::chrono::steady_clock::duration delay, AsyncFunction fn) {
+ CHECK(fn != nullptr);
+
+ auto token = mNextToken++;
+ insert({ std::chrono::steady_clock::now() + delay, fn, token });
+
+ return token;
+}
+
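+// Removes a still-pending item from the queue. Returns whether it was found
+// and wakes the loop if the head of the queue changed, so that the poll()
+// timeout is recomputed.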
+bool RunLoop::cancelToken(Token token) {
+ std::lock_guard<std::mutex> autoLock(mLock);
+
+ bool found = false;
+ for (auto it = mQueue.begin(); it != mQueue.end(); ++it) {
+ if (it->mToken == token) {
+ // Note whether we're removing the head of the queue *before*
+ // erasing; erase() invalidates the iterator.
+ const bool wasFront = (it == mQueue.begin());
+ mQueue.erase(it);
+
+ if (wasFront) {
+ interrupt();
+ }
+
+ found = true;
+ break;
+ }
+ }
+
+ return found;
+}
+
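+// Registers a one-shot callback that fires once sock becomes readable; the
+// callback is consumed when it fires and has to be re-armed by posting again.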
+void RunLoop::postSocketRecv(int sock, AsyncFunction fn) {
+ CHECK_GE(sock, 0);
+ CHECK(fn != nullptr);
+
+ std::lock_guard<std::mutex> autoLock(mLock);
+ mAddInfos.push_back({ sock, InfoType::RECV, fn });
+ interrupt();
+}
+
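+// Registers a one-shot callback that fires once sock becomes writable.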
+void RunLoop::postSocketSend(int sock, AsyncFunction fn) {
+ CHECK_GE(sock, 0);
+ CHECK(fn != nullptr);
+
+ std::lock_guard<std::mutex> autoLock(mLock);
+ mAddInfos.push_back({ sock, InfoType::SEND, fn });
+ interrupt();
+}
+
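+// Drops any pending recv/send callbacks for sock and stops polling it.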
+void RunLoop::cancelSocket(int sock) {
+ CHECK_GE(sock, 0);
+
+ std::lock_guard<std::mutex> autoLock(mLock);
+ mAddInfos.push_back({ sock, InfoType::CANCEL, nullptr });
+ interrupt();
+}
+
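+// Inserts elem into the sorted queue and interrupts a poll() in progress if
+// the new element became the head, since the current wait timeout may now be
+// too long.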
+void RunLoop::insert(const QueueElem &elem) {
+ std::lock_guard<std::mutex> autoLock(mLock);
+
+ auto it = mQueue.begin();
+ while (it != mQueue.end() && *it <= elem) {
+ ++it;
+ }
+
+ if (it == mQueue.begin()) {
+ interrupt();
+ }
+
+ mQueue.insert(it, elem);
+}
+
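+// Body of the loop: apply queued socket registrations (mAddInfos), compute a
+// poll() timeout from the head of the timer queue, poll the control pipe and
+// all registered sockets, then invoke any due callbacks outside of mLock.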
+void RunLoop::run() {
+ mPThread = pthread_self();
+
+ std::map<int, SocketCallbacks> socketCallbacksByFd;
+ std::vector<pollfd> pollFds;
+
+ auto removePollFdAt = [&socketCallbacksByFd, &pollFds](size_t i) {
+ if (i + 1 == pollFds.size()) {
+ pollFds.pop_back();
+ } else {
+ // Instead of leaving a hole in the middle of the
+ // pollFds vector, we copy the last item into
+ // that hole and reduce the size of the vector by 1,
+ // taking care of updating the corresponding callback
+ // with its new index.
+ pollFds[i] = pollFds.back();
+ pollFds.pop_back();
+ socketCallbacksByFd[pollFds[i].fd].mPollFdIndex = i;
+ }
+ };
+
+ // The control channel's pollFd will always be at index 0.
+ pollFds.push_back({ mControlFds[0], POLLIN, 0 });
+
+ for (;;) {
+ int timeoutMs = -1; // wait forever
+
+ {
+ std::lock_guard<std::mutex> autoLock(mLock);
+
+ if (mDone) {
+ break;
+ }
+
+ for (const auto &addInfo : mAddInfos) {
+ const int sock = addInfo.mSock;
+ const auto fn = addInfo.mFn;
+
+ auto it = socketCallbacksByFd.find(sock);
+
+ switch (addInfo.mType) {
+ case InfoType::RECV:
+ {
+ if (it == socketCallbacksByFd.end()) {
+ socketCallbacksByFd[sock] = { fn, nullptr, pollFds.size() };
+ pollFds.push_back({ sock, POLLIN, 0 });
+ } else {
+ // There's already a pollFd for this socket.
+ CHECK(it->second.mSendFn != nullptr);
+
+ CHECK(it->second.mRecvFn == nullptr);
+ it->second.mRecvFn = fn;
+
+ pollFds[it->second.mPollFdIndex].events |= POLLIN;
+ }
+ break;
+ }
+
+ case InfoType::SEND:
+ {
+ if (it == socketCallbacksByFd.end()) {
+ socketCallbacksByFd[sock] = { nullptr, fn, pollFds.size() };
+ pollFds.push_back({ sock, POLLOUT, 0 });
+ } else {
+ // There's already a pollFd for this socket.
+ if (it->second.mRecvFn == nullptr) {
+ LOG(ERROR)
+ << "There's an entry but no recv callback "
+ "registered for socket "
+ << sock;
+ }
+
+ CHECK(it->second.mRecvFn != nullptr);
+
+ if (it->second.mSendFn != nullptr) {
+ LOG(ERROR)
+ << "There's already a pending send "
+ "notification for socket "
+ << sock;
+ }
+ CHECK(it->second.mSendFn == nullptr);
+ it->second.mSendFn = fn;
+
+ pollFds[it->second.mPollFdIndex].events |= POLLOUT;
+ }
+ break;
+ }
+
+ case InfoType::CANCEL:
+ {
+ if (it != socketCallbacksByFd.end()) {
+ const size_t i = it->second.mPollFdIndex;
+
+ socketCallbacksByFd.erase(it);
+ removePollFdAt(i);
+ }
+ break;
+ }
+ }
+ }
+
+ mAddInfos.clear();
+
+ if (!mQueue.empty()) {
+ timeoutMs = 0;
+
+ if (mQueue.front().mWhen) {
+ auto duration =
+ *mQueue.front().mWhen - std::chrono::steady_clock::now();
+
+ auto durationMs =
+ std::chrono::duration_cast<std::chrono::milliseconds>(duration);
+
+ if (durationMs.count() > 0) {
+ timeoutMs = static_cast<int>(durationMs.count());
+ }
+ }
+ }
+ }
+
+ int pollRes = 0;
+ if (timeoutMs != 0) {
+ // NOTE: The inequality is deliberate: we only call poll() if
+ // timeoutMs is -1 (wait forever) or positive, not if it's 0,
+ // i.e. if the head of the queue is already due.
+
+ pollRes = poll(
+ pollFds.data(),
+ static_cast<nfds_t>(pollFds.size()),
+ timeoutMs);
+ }
+
+ if (pollRes < 0) {
+ if (errno != EINTR) {
+ std::cerr
+ << "poll FAILED w/ "
+ << errno
+ << " ("
+ << strerror(errno)
+ << ")"
+ << std::endl;
+ }
+
+ CHECK_EQ(errno, EINTR);
+ continue;
+ }
+
+ std::vector<AsyncFunction> fnArray;
+
+ {
+ std::lock_guard<std::mutex> autoLock(mLock);
+
+ if (pollRes > 0) {
+ if (pollFds[0].revents & POLLIN) {
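+ // interrupt() wrote to the control pipe; drain the non-blocking
+ // read end completely before handling socket events.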
+ ssize_t res;
+ do {
+ uint8_t c[32];
+ while ((res = read(mControlFds[0], c, sizeof(c))) < 0
+ && errno == EINTR) {
+ }
+ } while (res > 0);
+ CHECK(res < 0 && errno == EWOULDBLOCK);
+
+ --pollRes;
+ }
+
+ // NOTE: Skip index 0, as we already handled it above.
+ // Also, bail early if we exhausted all actionable pollFds
+ // according to pollRes.
+ for (size_t i = pollFds.size(); pollRes && i-- > 1;) {
+ pollfd &pollFd = pollFds[i];
+ const short revents = pollFd.revents;
+
+ if (revents) {
+ --pollRes;
+ }
+
+ const bool readable = (revents & POLLIN);
+ const bool writable = (revents & POLLOUT);
+ const bool dead = (revents & POLLNVAL);
+
+ bool removeCallback = dead;
+
+ if (readable || writable || dead) {
+ const int sock = pollFd.fd;
+
+ const auto it = socketCallbacksByFd.find(sock);
+ auto &cb = it->second;
+ CHECK_EQ(cb.mPollFdIndex, i);
+
+ if (readable) {
+ CHECK(cb.mRecvFn != nullptr);
+ fnArray.push_back(cb.mRecvFn);
+ cb.mRecvFn = nullptr;
+ pollFd.events &= ~POLLIN;
+
+ removeCallback |= (cb.mSendFn == nullptr);
+ }
+
+ if (writable) {
+ CHECK(cb.mSendFn != nullptr);
+ fnArray.push_back(cb.mSendFn);
+ cb.mSendFn = nullptr;
+ pollFd.events &= ~POLLOUT;
+
+ removeCallback |= (cb.mRecvFn == nullptr);
+ }
+
+ if (removeCallback) {
+ socketCallbacksByFd.erase(it);
+ removePollFdAt(i);
+ }
+ }
+ }
+ } else {
+ // No interrupt and no socket notifications: poll() either timed
+ // out or was skipped because the head of the queue was already
+ // due, so run the earliest queued function.
+ fnArray.push_back(mQueue.front().mFn);
+ mQueue.pop_front();
+ }
+ }
+
+ for (const auto &fn : fnArray) {
+ fn();
+ }
+ }
+}
+
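+// Wakes up a poll() in progress (or makes the next one return immediately) by
+// writing a single byte to the control pipe.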
+void RunLoop::interrupt() {
+ uint8_t c = 1;
+ ssize_t res;
+ while ((res = write(mControlFds[1], &c, 1)) < 0 && errno == EINTR) {
+ }
+
+ CHECK_EQ(res, 1);
+}
+
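+// Trivial subclass used by RunLoop::main() below to construct the shared
+// main loop.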
+struct MainRunLoop : public RunLoop {
+};
+
+static std::mutex gLock;
+static std::shared_ptr<RunLoop> gMainRunLoop;
+
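+// Returns the lazily-created, process-wide main run loop. Note that, unlike
+// the named constructor, the default constructor used here does not spawn a
+// worker thread.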
+// static
+std::shared_ptr<RunLoop> RunLoop::main() {
+ std::lock_guard<std::mutex> autoLock(gLock);
+ if (!gMainRunLoop) {
+ gMainRunLoop = std::make_shared<MainRunLoop>();
+ }
+ return gMainRunLoop;
+}
+
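+// Returns true when called from the thread that is (or last was) executing
+// run().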
+bool RunLoop::isCurrentThread() const {
+ return pthread_equal(pthread_self(), mPThread);
+}
+