aboutsummaryrefslogtreecommitdiff
path: root/src/sync/blocking.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/sync/blocking.rs')
-rw-r--r--src/sync/blocking.rs192
1 files changed, 192 insertions, 0 deletions
diff --git a/src/sync/blocking.rs b/src/sync/blocking.rs
new file mode 100644
index 0000000..e95694d
--- /dev/null
+++ b/src/sync/blocking.rs
@@ -0,0 +1,192 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::future::Future;
+use std::ptr;
+use std::sync::atomic::{AtomicI32, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures::pin_mut;
+use futures::task::{waker_ref, ArcWake};
+
+// Randomly generated values to indicate the state of the current thread.
+const WAITING: i32 = 0x25de_74d1;
+const WOKEN: i32 = 0x72d3_2c9f;
+
+const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
+const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
+
+thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
+
+#[repr(transparent)]
+struct Waker(AtomicI32);
+
+extern {
+ #[cfg_attr(target_os = "android", link_name = "__errno")]
+ #[cfg_attr(target_os = "linux", link_name = "__errno_location")]
+ fn errno_location() -> *mut libc::c_int;
+}
+
+impl ArcWake for Waker {
+ fn wake_by_ref(arc_self: &Arc<Self>) {
+ let state = arc_self.0.swap(WOKEN, Ordering::Release);
+ if state == WAITING {
+ // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &arc_self.0,
+ FUTEX_WAKE_PRIVATE,
+ libc::INT_MAX, // val
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0 as libc::c_int, // val3
+ )
+ };
+ if res < 0 {
+ panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe {
+ *errno_location()
+ });
+ }
+ }
+ }
+}
+
+/// Run a future to completion on the current thread.
+///
+/// This method will block the current thread until `f` completes. Useful when you need to call an
+/// async fn from a non-async context.
+pub fn block_on<F: Future>(f: F) -> F::Output {
+ pin_mut!(f);
+
+ PER_THREAD_WAKER.with(|thread_waker| {
+ let waker = waker_ref(thread_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ loop {
+ if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
+ return t;
+ }
+
+ let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
+ if state == WAITING {
+ // If we weren't already woken up then wait until we are. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &thread_waker.0,
+ FUTEX_WAIT_PRIVATE,
+ state,
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0 as libc::c_int, // val3
+ )
+ };
+
+ if res < 0 {
+ // Safe because libc guarantees that this is a valid pointer.
+ match unsafe { *errno_location() } {
+ libc::EAGAIN | libc::EINTR => {}
+ e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
+ }
+ }
+
+ // Clear the state to prevent unnecessary extra loop iterations and also to allow
+ // nested usage of `block_on`.
+ thread_waker.0.store(WAITING, Ordering::Release);
+ }
+ }
+ })
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll, Waker};
+ use std::thread;
+ use std::time::Duration;
+
+ use crate::sync::SpinLock;
+
+ struct TimerState {
+ fired: bool,
+ waker: Option<Waker>,
+ }
+ struct Timer {
+ state: Arc<SpinLock<TimerState>>,
+ }
+
+ impl Future for Timer {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ let mut state = self.state.lock();
+ if state.fired {
+ return Poll::Ready(());
+ }
+
+ state.waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+
+ fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
+ let state = Arc::new(SpinLock::new(TimerState {
+ fired: false,
+ waker: None,
+ }));
+
+ let thread_state = Arc::clone(&state);
+ thread::spawn(move || {
+ thread::sleep(dur);
+ let mut ts = thread_state.lock();
+ ts.fired = true;
+ if let Some(waker) = ts.waker.take() {
+ waker.wake();
+ }
+ drop(ts);
+
+ if let Some(tx) = notify {
+ tx.send(()).expect("Failed to send completion notification");
+ }
+ });
+
+ Timer { state }
+ }
+
+ #[test]
+ fn it_works() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ #[test]
+ fn nested() {
+ async fn inner() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ block_on(inner());
+ }
+
+ #[test]
+ fn ready_before_poll() {
+ let (tx, rx) = channel();
+
+ let timer = start_timer(Duration::from_millis(50), Some(tx));
+
+ rx.recv()
+ .expect("Failed to receive completion notification");
+
+ // We know the timer has already fired so the poll should complete immediately.
+ block_on(timer);
+ }
+}