author     Victor Hsieh <victorhsieh@google.com>    2020-12-15 13:59:26 -0800
committer  Victor Hsieh <victorhsieh@google.com>    2021-01-08 08:25:09 -0800
commit     5cc8a92ff386267e0c69f7dad909f5cc3132ebb9 (patch)
tree       adf0ed37ac7d7272ca7ba5c1ed11a742e0f3bed6
parent     4061c0fb78928db1e424afbaf808d36800fbbfe7 (diff)
download   libchromeos-rs-5cc8a92ff386267e0c69f7dad909f5cc3132ebb9.tar.gz
Import libchromeos-rs
This change is imported from the libchromeos-rs directory of
https://chromium.googlesource.com/chromiumos/platform2

Extra files added:
- Android.bp
- patches/Android.bp.patch

File modified:
- OWNERS

Bug: 174797066
Test: mma
Change-Id: I473686bb8dd9633759e6fb58266c9750322d52f2
-rw-r--r--  Android.bp                  92
-rw-r--r--  Cargo.lock                 284
-rw-r--r--  Cargo.toml                  19
-rw-r--r--  LICENSE                     27
-rw-r--r--  METADATA                    14
-rw-r--r--  MODULE_LICENSE_BSD           0
-rw-r--r--  OWNERS                       5
-rw-r--r--  README.md                    8
-rw-r--r--  TEST_MAPPING                13
-rw-r--r--  patches/Android.bp.patch    31
-rw-r--r--  src/lib.rs                  25
-rw-r--r--  src/linux.rs                18
-rw-r--r--  src/net.rs                 269
-rw-r--r--  src/read_dir.rs            150
-rw-r--r--  src/scoped_path.rs         115
-rw-r--r--  src/sync.rs                 14
-rw-r--r--  src/sync/blocking.rs       192
-rw-r--r--  src/sync/cv.rs            1251
-rw-r--r--  src/sync/mu.rs            2400
-rw-r--r--  src/sync/spin.rs           271
-rw-r--r--  src/sync/waiter.rs         317
-rw-r--r--  src/syslog.rs               73
-rw-r--r--  src/vsock.rs               491
23 files changed, 6079 insertions, 0 deletions
diff --git a/Android.bp b/Android.bp
new file mode 100644
index 0000000..63b9fdd
--- /dev/null
+++ b/Android.bp
@@ -0,0 +1,92 @@
+// This file is generated by cargo2android.py --run --device --tests --dependencies --patch=patches/Android.bp.patch.
+
+rust_defaults {
+ name: "libchromeos-rs_defaults",
+ crate_name: "libchromeos",
+ srcs: ["src/lib.rs"],
+ test_suites: ["general-tests"],
+ auto_gen_config: true,
+ edition: "2018",
+ rustlibs: [
+ "libdata_model",
+ "libfutures",
+ "libfutures_executor",
+ "libfutures_util",
+ "libintrusive_collections",
+ "liblibc",
+ "liblog_rust",
+ "libprotobuf",
+ ],
+}
+
+rust_test_host {
+ name: "libchromeos-rs_host_test_src_lib",
+ defaults: ["libchromeos-rs_defaults"],
+}
+
+rust_test {
+ name: "libchromeos-rs_device_test_src_lib",
+ defaults: ["libchromeos-rs_defaults"],
+ // Manually limit to 64-bit to avoid depending on the currently non-existent
+ // 32-bit build of libdata_model.
+ compile_multilib: "64",
+}
+
+rust_library {
+ name: "liblibchromeos",
+ host_supported: true,
+ crate_name: "libchromeos",
+ srcs: ["src/lib.rs"],
+ edition: "2018",
+ rustlibs: [
+ "libdata_model",
+ "libfutures",
+ "libintrusive_collections",
+ "liblibc",
+ "liblog_rust",
+ "libprotobuf",
+ ],
+ apex_available: [
+ "//apex_available:platform",
+ "com.android.virt",
+ ],
+ // This library depends on libdata_model, which is part of the crosvm project.
+ // Projects within crosvm on Android have only the 64-bit target build enabled.
+ // As a result, we need to manually limit this build to 64-bit only, too.
+ // This is fine because this library is currently only used by crosvm (thus
+ // 64-bit only).
+ compile_multilib: "64",
+}
+
+// dependent_library ["feature_list"]
+// ../crosvm/assertions/src/lib.rs
+// ../crosvm/data_model/src/lib.rs
+// autocfg-1.0.1
+// cfg-if-0.1.10
+// futures-0.3.8 "alloc,async-await,default,executor,futures-executor,std"
+// futures-channel-0.3.8 "alloc,futures-sink,sink,std"
+// futures-core-0.3.8 "alloc,std"
+// futures-executor-0.3.8 "default,num_cpus,std,thread-pool"
+// futures-io-0.3.8 "std"
+// futures-macro-0.3.8
+// futures-sink-0.3.8 "alloc,std"
+// futures-task-0.3.8 "alloc,once_cell,std"
+// futures-util-0.3.8 "alloc,async-await,async-await-macro,channel,default,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std"
+// intrusive-collections-0.9.0 "alloc,default"
+// libc-0.2.82 "default,std"
+// log-0.4.11
+// memchr-2.3.4 "default,std"
+// memoffset-0.5.6 "default"
+// num_cpus-1.13.0
+// once_cell-1.5.2 "alloc,std"
+// pin-project-1.0.3
+// pin-project-internal-1.0.3
+// pin-utils-0.1.0
+// proc-macro-hack-0.5.19
+// proc-macro-nested-0.1.6
+// proc-macro2-1.0.24 "default,proc-macro"
+// protobuf-2.20.0
+// quote-1.0.8 "default,proc-macro"
+// slab-0.4.2
+// syn-1.0.58 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
+// unicode-xid-0.2.1 "default"
diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644
index 0000000..4a3acbe
--- /dev/null
+++ b/Cargo.lock
@@ -0,0 +1,284 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+[[package]]
+name = "assertions"
+version = "0.1.0"
+
+[[package]]
+name = "autocfg"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
+
+[[package]]
+name = "cfg-if"
+version = "0.1.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
+
+[[package]]
+name = "data_model"
+version = "0.1.0"
+dependencies = [
+ "assertions",
+ "libc",
+]
+
+[[package]]
+name = "futures"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b3b0c040a1fe6529d30b3c5944b280c7f0dcb2930d2c3062bca967b602583d0"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-channel"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b7109687aa4e177ef6fe84553af6280ef2778bdb7783ba44c9dc3399110fe64"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+]
+
+[[package]]
+name = "futures-core"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "847ce131b72ffb13b6109a221da9ad97a64cbe48feb1028356b836b47b8f1748"
+
+[[package]]
+name = "futures-executor"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4caa2b2b68b880003057c1dd49f1ed937e38f22fcf6c212188a121f08cf40a65"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+ "num_cpus",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "611834ce18aaa1bd13c4b374f5d653e1027cf99b6b502584ff8c9a64413b30bb"
+
+[[package]]
+name = "futures-macro"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77408a692f1f97bcc61dc001d752e00643408fbc922e4d634c655df50d595556"
+dependencies = [
+ "proc-macro-hack",
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "futures-sink"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f878195a49cee50e006b02b93cf7e0a95a38ac7b776b4c4d9cc1207cd20fcb3d"
+
+[[package]]
+name = "futures-task"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7c554eb5bf48b2426c4771ab68c6b14468b6e76cc90996f528c3338d761a4d0d"
+dependencies = [
+ "once_cell",
+]
+
+[[package]]
+name = "futures-util"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d304cff4a7b99cfb7986f7d43fbe93d175e72e704a8860787cc95e9ffd85cbd2"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-io",
+ "futures-macro",
+ "futures-sink",
+ "futures-task",
+ "memchr",
+ "pin-project",
+ "pin-utils",
+ "proc-macro-hack",
+ "proc-macro-nested",
+ "slab",
+]
+
+[[package]]
+name = "hermit-abi"
+version = "0.1.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "intrusive-collections"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4bca8c0bb831cd60d4dda79a58e3705ca6eb47efb65d665651a8d672213ec3db"
+dependencies = [
+ "memoffset",
+]
+
+[[package]]
+name = "libc"
+version = "0.2.80"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4d58d1b70b004888f764dfbf6a26a3b0342a1632d33968e4a179d8011c760614"
+
+[[package]]
+name = "libchromeos"
+version = "0.1.0"
+dependencies = [
+ "data_model",
+ "futures",
+ "futures-executor",
+ "futures-util",
+ "intrusive-collections",
+ "libc",
+ "log",
+ "protobuf",
+]
+
+[[package]]
+name = "log"
+version = "0.4.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
+name = "memchr"
+version = "2.3.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525"
+
+[[package]]
+name = "memoffset"
+version = "0.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "num_cpus"
+version = "1.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3"
+dependencies = [
+ "hermit-abi",
+ "libc",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "13bd41f508810a131401606d54ac32a467c97172d74ba7662562ebba5ad07fa0"
+
+[[package]]
+name = "pin-project"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9ccc2237c2c489783abd8c4c80e5450fc0e98644555b1364da68cc29aa151ca7"
+dependencies = [
+ "pin-project-internal",
+]
+
+[[package]]
+name = "pin-project-internal"
+version = "1.0.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f8e8d2bf0b23038a4424865103a4df472855692821aab4e4f5c3312d461d9e5f"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "pin-utils"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
+
+[[package]]
+name = "proc-macro-hack"
+version = "0.5.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5"
+
+[[package]]
+name = "proc-macro-nested"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eba180dafb9038b050a4c280019bbedf9f2467b61e5d892dcad585bb57aadc5a"
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71"
+dependencies = [
+ "unicode-xid",
+]
+
+[[package]]
+name = "protobuf"
+version = "2.18.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da78e04bc0e40f36df43ecc6575e4f4b180e8156c4efd73f13d5619479b05696"
+
+[[package]]
+name = "quote"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "slab"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8"
+
+[[package]]
+name = "syn"
+version = "1.0.50"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "443b4178719c5a851e1bde36ce12da21d74a0e60b4d982ec3385a933c812f0f6"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-xid",
+]
+
+[[package]]
+name = "unicode-xid"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564"
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..d0de4ce
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "libchromeos"
+version = "0.1.0"
+authors = ["The Chromium OS Authors"]
+edition = "2018"
+
+[dependencies]
+data_model = { path = "../crosvm/data_model" } # provided by ebuild
+libc = "0.2"
+log = "0.4"
+protobuf = "2.1"
+intrusive-collections = "0.9"
+futures = { version = "0.3", default-features = false, features = ["alloc"] }
+
+[dev-dependencies]
+futures = { version = "0.3", features = ["async-await"] }
+futures-executor = { version = "0.3", features = ["thread-pool"] }
+futures-util = "0.3"
+
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..b9e779f
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,27 @@
+// Copyright 2014 The Chromium OS Authors. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/METADATA b/METADATA
new file mode 100644
index 0000000..c4cea7d
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,14 @@
+name: "libchromeos-rs"
+description:
+ "libchromeos-rs contains Rust code that can be reused across any Chrome OS "
+ "project. The crate is copied from a bigger platform2 repo."
+
+third_party {
+ url {
+ type: GIT
+ value: "https://chromium.googlesource.com/chromiumos/platform2"
+ }
+ version: "a39ed76e487e191a009cba4676bdd9a33006cbc0"
+ last_upgrade_date { year: 2020 month: 12 day: 8 }
+ license_type: NOTICE
+}
diff --git a/MODULE_LICENSE_BSD b/MODULE_LICENSE_BSD
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/MODULE_LICENSE_BSD
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..2a801a4
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1,5 @@
+# ANDROID: Remove chromium owners to not confuse Gerrit with unrecognized
+# owners.
+#chirantan@chromium.org
+#smbarber@chromium.org
+#zachr@chromium.org
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..02d55d2
--- /dev/null
+++ b/README.md
@@ -0,0 +1,8 @@
+# libchromeos-rs - The Rust crate for common Chrome OS code
+
+`libchromeos-rs` contains Rust code that can be reused across any Chrome OS
+project. It's the Rust equivalent of [libbrillo](../libbrillo/).
+
+Current modules include:
+* `syslog` - an adaptor for using the generic `log` crate with syslog
+* `vsock` - wrappers for dealing with AF_VSOCK sockets
diff --git a/TEST_MAPPING b/TEST_MAPPING
new file mode 100644
index 0000000..9882312
--- /dev/null
+++ b/TEST_MAPPING
@@ -0,0 +1,13 @@
+// Generated by cargo2android.py for tests in Android.bp
+{
+ "presubmit": [
+ {
+ "host": true,
+ "name": "libchromeos-rs_host_test_src_lib"
+// Presubmit tries to run x86, but we only support 64-bit builds.
+// },
+// {
+// "name": "libchromeos-rs_device_test_src_lib"
+ }
+ ]
+}
diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch
new file mode 100644
index 0000000..32234ed
--- /dev/null
+++ b/patches/Android.bp.patch
@@ -0,0 +1,31 @@
+diff --git a/Android.bp b/Android.bp
+index fc2c8b8..c4e3c5b 100644
+--- a/Android.bp
++++ b/Android.bp
+@@ -27,6 +27,9 @@ rust_test_host {
+ rust_test {
+ name: "libchromeos-rs_device_test_src_lib",
+ defaults: ["libchromeos-rs_defaults"],
++ // Manually limit to 64-bit to avoid depending on the currently non-existent
++ // 32-bit build of libdata_model.
++ compile_multilib: "64",
+ }
+
+ rust_library {
+@@ -43,6 +46,16 @@ rust_library {
+ "liblog_rust",
+ "libprotobuf",
+ ],
++ apex_available: [
++ "//apex_available:platform",
++ "com.android.virt",
++ ],
++ // This library depends on libdata_model, which is part of the crosvm project.
++ // Projects within crosvm on Android have only the 64-bit target build enabled.
++ // As a result, we need to manually limit this build to 64-bit only, too.
++ // This is fine because this library is currently only used by crosvm (thus
++ // 64-bit only).
++ compile_multilib: "64",
+ }
+
+ // dependent_library ["feature_list"]
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..9d34580
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,25 @@
+// Copyright 2019 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+pub mod linux;
+pub mod net;
+mod read_dir;
+pub mod scoped_path;
+pub mod sync;
+pub mod syslog;
+pub mod vsock;
+
+pub use read_dir::*;
+
+#[macro_export]
+macro_rules! syscall {
+ ($e:expr) => {{
+ let res = $e;
+ if res < 0 {
+ Err(::std::io::Error::last_os_error())
+ } else {
+ Ok(res)
+ }
+ }};
+}
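The `syscall!` macro above converts a raw libc-style return value (negative on error) into an `io::Result`. A minimal usage sketch, assuming the crate is linked under the name `libchromeos` and `libc` is a dependency:

    use std::io;

    fn duplicate_stdin() -> io::Result<libc::c_int> {
        // Safe because dup() does not touch memory and the macro checks the return value.
        let fd = libchromeos::syscall!(unsafe { libc::dup(0) })?;
        Ok(fd)
    }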
diff --git a/src/linux.rs b/src/linux.rs
new file mode 100644
index 0000000..4826730
--- /dev/null
+++ b/src/linux.rs
@@ -0,0 +1,18 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/// Provides safe implementations of common low level functions that assume a Linux environment.
+use libc::{syscall, SYS_gettid};
+
+pub type Pid = libc::pid_t;
+
+pub fn getpid() -> Pid {
+ // Calling getpid() is always safe.
+ unsafe { libc::getpid() }
+}
+
+pub fn gettid() -> Pid {
+ // Calling the gettid() syscall is always safe.
+ unsafe { syscall(SYS_gettid) as Pid }
+}
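As a small illustration (not part of the patch), the helpers above can be used to tag log lines or temporary paths with the current process and thread IDs; `scoped_path::get_temp_path` further down does exactly that:

    fn thread_label() -> String {
        // getpid()/gettid() are the safe wrappers defined in src/linux.rs.
        format!("{}-{}", libchromeos::linux::getpid(), libchromeos::linux::gettid())
    }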
diff --git a/src/net.rs b/src/net.rs
new file mode 100644
index 0000000..be1cd25
--- /dev/null
+++ b/src/net.rs
@@ -0,0 +1,269 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/// Structs to supplement std::net.
+use std::io;
+use std::mem::{self, size_of};
+use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
+
+use libc::{
+ c_int, in6_addr, in_addr, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET,
+ AF_INET6, SOCK_CLOEXEC, SOCK_STREAM,
+};
+
+/// Assist in handling both IP version 4 and IP version 6.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum InetVersion {
+ V4,
+ V6,
+}
+
+impl InetVersion {
+ pub fn from_sockaddr(s: &SocketAddr) -> Self {
+ match s {
+ SocketAddr::V4(_) => InetVersion::V4,
+ SocketAddr::V6(_) => InetVersion::V6,
+ }
+ }
+}
+
+impl Into<sa_family_t> for InetVersion {
+ fn into(self) -> sa_family_t {
+ match self {
+ InetVersion::V4 => AF_INET as sa_family_t,
+ InetVersion::V6 => AF_INET6 as sa_family_t,
+ }
+ }
+}
+
+fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
+ sockaddr_in {
+ sin_family: AF_INET as sa_family_t,
+ sin_port: s.port().to_be(),
+ sin_addr: in_addr {
+ s_addr: u32::from_ne_bytes(s.ip().octets()),
+ },
+ sin_zero: [0; 8],
+ }
+}
+
+fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
+ sockaddr_in6 {
+ sin6_family: AF_INET6 as sa_family_t,
+ sin6_port: s.port().to_be(),
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr {
+ s6_addr: s.ip().octets(),
+ },
+ sin6_scope_id: 0,
+ }
+}
+
+/// A TCP socket.
+///
+/// Do not use this type unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead, use either `TcpStream` or
+/// `TcpListener`.
+#[derive(Debug)]
+pub struct TcpSocket {
+ inet_version: InetVersion,
+ fd: RawFd,
+}
+
+impl TcpSocket {
+ pub fn new(inet_version: InetVersion) -> io::Result<Self> {
+ let fd = unsafe {
+ libc::socket(
+ Into::<sa_family_t>::into(inet_version) as c_int,
+ SOCK_STREAM | SOCK_CLOEXEC,
+ 0,
+ )
+ };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(TcpSocket { inet_version, fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+ size_of::<sockaddr_in6>() as socklen_t,
+ )
+ }
+ }
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+ size_of::<sockaddr_in6>() as socklen_t,
+ )
+ }
+ }
+ };
+
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpStream.
+ Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<TcpListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ Err(listen_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpListener.
+ Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u16> {
+ match self.inet_version {
+ InetVersion::V4 => {
+ let mut sin = sockaddr_in {
+ sin_family: 0,
+ sin_port: 0,
+ sin_addr: in_addr { s_addr: 0 },
+ sin_zero: [0; 8],
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin as *mut sockaddr_in as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
+
+ Ok(sin.sin_port)
+ }
+ }
+ InetVersion::V6 => {
+ let mut sin6 = sockaddr_in6 {
+ sin6_family: 0,
+ sin6_port: 0,
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr { s6_addr: [0; 16] },
+ sin6_scope_id: 0,
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in6>());
+
+ Ok(sin6.sin6_port)
+ }
+ }
+ }
+ }
+}
+
+impl IntoRawFd for TcpSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for TcpSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for TcpSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
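A rough sketch (illustrative only; the address and error handling are arbitrary) of driving `TcpSocket` directly: create a socket, optionally tweak options through the raw fd, bind, and then convert it into a standard `TcpListener`:

    use std::io;
    use std::net::TcpListener;

    use libchromeos::net::{InetVersion, TcpSocket};

    fn listen_on_loopback() -> io::Result<TcpListener> {
        let mut sock = TcpSocket::new(InetVersion::V4)?;
        // Socket options could be set here via sock.as_raw_fd() before binding.
        sock.bind("127.0.0.1:0")?;
        sock.listen()
    }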
diff --git a/src/read_dir.rs b/src/read_dir.rs
new file mode 100644
index 0000000..eb8a549
--- /dev/null
+++ b/src/read_dir.rs
@@ -0,0 +1,150 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::ffi::CStr;
+use std::io;
+use std::mem::size_of;
+use std::os::unix::io::AsRawFd;
+
+use data_model::DataInit;
+
+use crate::syscall;
+
+#[repr(C, packed)]
+#[derive(Clone, Copy)]
+struct LinuxDirent64 {
+ d_ino: libc::ino64_t,
+ d_off: libc::off64_t,
+ d_reclen: libc::c_ushort,
+ d_ty: libc::c_uchar,
+}
+unsafe impl DataInit for LinuxDirent64 {}
+
+pub struct DirEntry<'r> {
+ pub ino: libc::ino64_t,
+ pub offset: u64,
+ pub type_: u8,
+ pub name: &'r CStr,
+}
+
+pub struct ReadDir<'d, D> {
+ buf: [u8; 256],
+ dir: &'d mut D,
+ current: usize,
+ end: usize,
+}
+
+impl<'d, D: AsRawFd> ReadDir<'d, D> {
+ /// Return the next directory entry. This is implemented as a separate method rather than via
+ /// the `Iterator` trait because Rust doesn't currently support generic associated types.
+ #[allow(clippy::should_implement_trait)]
+ pub fn next(&mut self) -> Option<io::Result<DirEntry>> {
+ if self.current >= self.end {
+ let res = syscall!(unsafe {
+ libc::syscall(
+ libc::SYS_getdents64,
+ self.dir.as_raw_fd(),
+ self.buf.as_mut_ptr() as *mut LinuxDirent64,
+ self.buf.len() as libc::c_int,
+ )
+ });
+ match res {
+ Ok(end) => {
+ self.current = 0;
+ self.end = end as usize;
+ }
+ Err(e) => return Some(Err(e)),
+ }
+ }
+
+ let rem = &self.buf[self.current..self.end];
+ if rem.is_empty() {
+ return None;
+ }
+
+ // We only use debug asserts here because these values are coming from the kernel and we
+ // trust them implicitly.
+ debug_assert!(
+ rem.len() >= size_of::<LinuxDirent64>(),
+ "not enough space left in `rem`"
+ );
+
+ let (front, back) = rem.split_at(size_of::<LinuxDirent64>());
+
+ let dirent64 =
+ LinuxDirent64::from_slice(front).expect("unable to get LinuxDirent64 from slice");
+
+ let namelen = dirent64.d_reclen as usize - size_of::<LinuxDirent64>();
+ debug_assert!(namelen <= back.len(), "back is smaller than `namelen`");
+
+ // The kernel will pad the name with additional nul bytes until it is 8-byte aligned so
+ // we need to strip those off here.
+ let name = strip_padding(&back[..namelen]);
+ let entry = DirEntry {
+ ino: dirent64.d_ino,
+ offset: dirent64.d_off as u64,
+ type_: dirent64.d_ty,
+ name,
+ };
+
+ debug_assert!(
+ rem.len() >= dirent64.d_reclen as usize,
+ "rem is smaller than `d_reclen`"
+ );
+ self.current += dirent64.d_reclen as usize;
+ Some(Ok(entry))
+ }
+}
+
+pub fn read_dir<D: AsRawFd>(dir: &mut D, offset: libc::off64_t) -> io::Result<ReadDir<D>> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ syscall!(unsafe { libc::lseek64(dir.as_raw_fd(), offset, libc::SEEK_SET) })?;
+
+ Ok(ReadDir {
+ buf: [0u8; 256],
+ dir,
+ current: 0,
+ end: 0,
+ })
+}
+
+// Like `CStr::from_bytes_with_nul` but strips any bytes after the first '\0'-byte. Panics if `b`
+// doesn't contain any '\0' bytes.
+fn strip_padding(b: &[u8]) -> &CStr {
+ // It would be nice if we could use memchr here but that's locked behind an unstable gate.
+ let pos = b
+ .iter()
+ .position(|&c| c == 0)
+ .expect("`b` doesn't contain any nul bytes");
+
+ // Safe because we are creating this string with the first nul-byte we found so we can
+ // guarantee that it is nul-terminated and doesn't contain any interior nuls.
+ unsafe { CStr::from_bytes_with_nul_unchecked(&b[..pos + 1]) }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn padded_cstrings() {
+ assert_eq!(strip_padding(b".\0\0\0\0\0\0\0").to_bytes(), b".");
+ assert_eq!(strip_padding(b"..\0\0\0\0\0\0").to_bytes(), b"..");
+ assert_eq!(
+ strip_padding(b"normal cstring\0").to_bytes(),
+ b"normal cstring"
+ );
+ assert_eq!(strip_padding(b"\0\0\0\0").to_bytes(), b"");
+ assert_eq!(
+ strip_padding(b"interior\0nul bytes\0\0\0").to_bytes(),
+ b"interior"
+ );
+ }
+
+ #[test]
+ #[should_panic(expected = "`b` doesn't contain any nul bytes")]
+ fn no_nul_byte() {
+ strip_padding(b"no nul bytes in string");
+ }
+}
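Because `ReadDir` intentionally does not implement `Iterator`, callers drive it with a `while let` loop. A minimal sketch, assuming the directory is opened as a `std::fs::File` (which works on Linux, where `getdents64` operates on a directory fd):

    use std::fs::File;
    use std::io;

    fn list_entries(path: &str) -> io::Result<()> {
        let mut dir = File::open(path)?;
        let mut entries = libchromeos::read_dir(&mut dir, 0)?;
        while let Some(entry) = entries.next() {
            let entry = entry?;
            println!("{:?} (inode {})", entry.name, entry.ino);
        }
        Ok(())
    }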
diff --git a/src/scoped_path.rs b/src/scoped_path.rs
new file mode 100644
index 0000000..22d0ad9
--- /dev/null
+++ b/src/scoped_path.rs
@@ -0,0 +1,115 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::env::{current_exe, temp_dir};
+use std::fs::{create_dir_all, remove_dir_all};
+use std::ops::Deref;
+use std::path::{Path, PathBuf};
+use std::thread::panicking;
+
+use super::linux::{getpid, gettid};
+
+/// Returns a stable path based on the label, pid, and tid. If the label isn't provided, the
+/// path of the current executable is used instead.
+pub fn get_temp_path(label: Option<&str>) -> PathBuf {
+ if let Some(label) = label {
+ temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid()))
+ } else {
+ get_temp_path(Some(current_exe().unwrap().to_str().unwrap()))
+ }
+}
+
+/// Automatically deletes the path it contains when it goes out of scope, unless this is a test
+/// and the drop happens while panicking.
+///
+/// This is particularly useful for creating temporary directories for use with tests.
+pub struct ScopedPath<P: AsRef<Path>>(P);
+
+impl<P: AsRef<Path>> ScopedPath<P> {
+ pub fn create(p: P) -> Result<Self, std::io::Error> {
+ create_dir_all(p.as_ref())?;
+ Ok(ScopedPath(p))
+ }
+}
+
+impl<P: AsRef<Path>> AsRef<Path> for ScopedPath<P> {
+ fn as_ref(&self) -> &Path {
+ self.0.as_ref()
+ }
+}
+
+impl<P: AsRef<Path>> Deref for ScopedPath<P> {
+ type Target = Path;
+
+ fn deref(&self) -> &Self::Target {
+ self.0.as_ref()
+ }
+}
+
+impl<P: AsRef<Path>> Drop for ScopedPath<P> {
+ fn drop(&mut self) {
+ // Leave the files on a failed test run for debugging.
+ if panicking() && cfg!(test) {
+ eprintln!("NOTE: Not removing {}", self.display());
+ return;
+ }
+ if let Err(e) = remove_dir_all(&**self) {
+ eprintln!("Failed to remove {}: {}", self.display(), e);
+ }
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+
+ use std::panic::catch_unwind;
+
+ #[test]
+ fn gettemppath() {
+ assert_ne!("", get_temp_path(None).to_string_lossy());
+ assert_eq!(
+ get_temp_path(None),
+ get_temp_path(Some(current_exe().unwrap().to_str().unwrap()))
+ );
+ assert_ne!(
+ get_temp_path(Some("label")),
+ get_temp_path(Some(current_exe().unwrap().to_str().unwrap()))
+ );
+ }
+
+ #[test]
+ fn scopedpath_exists() {
+ let tmp_path = get_temp_path(None);
+ {
+ let scoped_path = ScopedPath::create(&tmp_path).unwrap();
+ assert!(scoped_path.exists());
+ }
+ assert!(!tmp_path.exists());
+ }
+
+ #[test]
+ fn scopedpath_notexists() {
+ let tmp_path = get_temp_path(None);
+ {
+ let _scoped_path = ScopedPath(&tmp_path);
+ }
+ assert!(!tmp_path.exists());
+ }
+
+ #[test]
+ fn scopedpath_panic() {
+ let tmp_path = get_temp_path(None);
+ assert!(catch_unwind(|| {
+ {
+ let scoped_path = ScopedPath::create(&tmp_path).unwrap();
+ assert!(scoped_path.exists());
+ panic!()
+ }
+ })
+ .is_err());
+ assert!(tmp_path.exists());
+ remove_dir_all(&tmp_path).unwrap();
+ }
+}
diff --git a/src/sync.rs b/src/sync.rs
new file mode 100644
index 0000000..80dec38
--- /dev/null
+++ b/src/sync.rs
@@ -0,0 +1,14 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+mod blocking;
+mod cv;
+mod mu;
+mod spin;
+mod waiter;
+
+pub use blocking::block_on;
+pub use cv::Condvar;
+pub use mu::Mutex;
+pub use spin::SpinLock;
diff --git a/src/sync/blocking.rs b/src/sync/blocking.rs
new file mode 100644
index 0000000..e95694d
--- /dev/null
+++ b/src/sync/blocking.rs
@@ -0,0 +1,192 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::future::Future;
+use std::ptr;
+use std::sync::atomic::{AtomicI32, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures::pin_mut;
+use futures::task::{waker_ref, ArcWake};
+
+// Randomly generated values to indicate the state of the current thread.
+const WAITING: i32 = 0x25de_74d1;
+const WOKEN: i32 = 0x72d3_2c9f;
+
+const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
+const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
+
+thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
+
+#[repr(transparent)]
+struct Waker(AtomicI32);
+
+extern {
+ #[cfg_attr(target_os = "android", link_name = "__errno")]
+ #[cfg_attr(target_os = "linux", link_name = "__errno_location")]
+ fn errno_location() -> *mut libc::c_int;
+}
+
+impl ArcWake for Waker {
+ fn wake_by_ref(arc_self: &Arc<Self>) {
+ let state = arc_self.0.swap(WOKEN, Ordering::Release);
+ if state == WAITING {
+ // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &arc_self.0,
+ FUTEX_WAKE_PRIVATE,
+ libc::INT_MAX, // val
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0 as libc::c_int, // val3
+ )
+ };
+ if res < 0 {
+ panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe {
+ *errno_location()
+ });
+ }
+ }
+ }
+}
+
+/// Run a future to completion on the current thread.
+///
+/// This method will block the current thread until `f` completes. Useful when you need to call an
+/// async fn from a non-async context.
+pub fn block_on<F: Future>(f: F) -> F::Output {
+ pin_mut!(f);
+
+ PER_THREAD_WAKER.with(|thread_waker| {
+ let waker = waker_ref(thread_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ loop {
+ if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
+ return t;
+ }
+
+ let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
+ if state == WAITING {
+ // If we weren't already woken up then wait until we are. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &thread_waker.0,
+ FUTEX_WAIT_PRIVATE,
+ state,
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0 as libc::c_int, // val3
+ )
+ };
+
+ if res < 0 {
+ // Safe because libc guarantees that this is a valid pointer.
+ match unsafe { *errno_location() } {
+ libc::EAGAIN | libc::EINTR => {}
+ e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
+ }
+ }
+
+ // Clear the state to prevent unnecessary extra loop iterations and also to allow
+ // nested usage of `block_on`.
+ thread_waker.0.store(WAITING, Ordering::Release);
+ }
+ }
+ })
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll, Waker};
+ use std::thread;
+ use std::time::Duration;
+
+ use crate::sync::SpinLock;
+
+ struct TimerState {
+ fired: bool,
+ waker: Option<Waker>,
+ }
+ struct Timer {
+ state: Arc<SpinLock<TimerState>>,
+ }
+
+ impl Future for Timer {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ let mut state = self.state.lock();
+ if state.fired {
+ return Poll::Ready(());
+ }
+
+ state.waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+
+ fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
+ let state = Arc::new(SpinLock::new(TimerState {
+ fired: false,
+ waker: None,
+ }));
+
+ let thread_state = Arc::clone(&state);
+ thread::spawn(move || {
+ thread::sleep(dur);
+ let mut ts = thread_state.lock();
+ ts.fired = true;
+ if let Some(waker) = ts.waker.take() {
+ waker.wake();
+ }
+ drop(ts);
+
+ if let Some(tx) = notify {
+ tx.send(()).expect("Failed to send completion notification");
+ }
+ });
+
+ Timer { state }
+ }
+
+ #[test]
+ fn it_works() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ #[test]
+ fn nested() {
+ async fn inner() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ block_on(inner());
+ }
+
+ #[test]
+ fn ready_before_poll() {
+ let (tx, rx) = channel();
+
+ let timer = start_timer(Duration::from_millis(50), Some(tx));
+
+ rx.recv()
+ .expect("Failed to receive completion notification");
+
+ // We know the timer has already fired so the poll should complete immediately.
+ block_on(timer);
+ }
+}
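For completeness, a tiny sketch of calling `block_on` from synchronous code (assuming the crate is available as `libchromeos`); the future here is trivial and resolves immediately:

    fn main() {
        // block_on polls the future on the current thread, parking it on a private futex while pending.
        let value = libchromeos::sync::block_on(async { 2 + 2 });
        assert_eq!(value, 4);
    }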
diff --git a/src/sync/cv.rs b/src/sync/cv.rs
new file mode 100644
index 0000000..714c6d6
--- /dev/null
+++ b/src/sync/cv.rs
@@ -0,0 +1,1251 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::mem;
+use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
+use std::sync::Arc;
+
+use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex};
+use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor};
+
+const SPINLOCK: usize = 1 << 0;
+const HAS_WAITERS: usize = 1 << 1;
+
+/// A primitive to wait for an event to occur without consuming CPU time.
+///
+/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some
+/// condition to become true. The condition must always be verified while holding the `Mutex` lock.
+/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on
+/// the `Condvar`.
+///
+/// # Examples
+///
+/// ```edition2018
+/// use std::sync::Arc;
+/// use std::thread;
+/// use std::sync::mpsc::channel;
+///
+/// use libchromeos::sync::{block_on, Condvar, Mutex};
+///
+/// const N: usize = 13;
+///
+/// // Spawn a few threads to increment a shared variable (non-atomically), and
+/// // let all threads waiting on the Condvar know once the increments are done.
+/// let data = Arc::new(Mutex::new(0));
+/// let cv = Arc::new(Condvar::new());
+///
+/// for _ in 0..N {
+/// let (data, cv) = (data.clone(), cv.clone());
+/// thread::spawn(move || {
+/// let mut data = block_on(data.lock());
+/// *data += 1;
+/// if *data == N {
+/// cv.notify_all();
+/// }
+/// });
+/// }
+///
+/// let mut val = block_on(data.lock());
+/// while *val != N {
+/// val = block_on(cv.wait(val));
+/// }
+/// ```
+pub struct Condvar {
+ state: AtomicUsize,
+ waiters: UnsafeCell<WaiterList>,
+ mu: UnsafeCell<usize>,
+}
+
+impl Condvar {
+ /// Creates a new condition variable ready to be waited on and notified.
+ pub fn new() -> Condvar {
+ Condvar {
+ state: AtomicUsize::new(0),
+ waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
+ mu: UnsafeCell::new(0),
+ }
+ }
+
+ /// Block the current thread until this `Condvar` is notified by another thread.
+ ///
+ /// This method will atomically unlock the `Mutex` held by `guard` and then block the current
+ /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up
+ /// the thread.
+ ///
+ /// To allow for more efficient scheduling, this call may return even when the programmer
+ /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
+ /// loop that checks the predicate before continuing.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Condvar` is notified.
+ ///
+ /// # Panics
+ ///
+ /// This method will panic if used with more than one `Mutex` at the same time.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use std::sync::Arc;
+ /// # use std::thread;
+ ///
+ /// # use libchromeos::sync::{block_on, Condvar, Mutex};
+ ///
+ /// # let mu = Arc::new(Mutex::new(false));
+ /// # let cv = Arc::new(Condvar::new());
+ /// # let (mu2, cv2) = (mu.clone(), cv.clone());
+ ///
+ /// # let t = thread::spawn(move || {
+ /// # *block_on(mu2.lock()) = true;
+ /// # cv2.notify_all();
+ /// # });
+ ///
+ /// let mut ready = block_on(mu.lock());
+ /// while !*ready {
+ /// ready = block_on(cv.wait(ready));
+ /// }
+ ///
+ /// # t.join().expect("failed to join thread");
+ /// ```
+ // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
+ // that doesn't compile.
+ #[allow(clippy::needless_lifetimes)]
+ pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> {
+ let waiter = Arc::new(Waiter::new(
+ WaiterKind::Exclusive,
+ cancel_waiter,
+ self as *const Condvar as usize,
+ WaitingFor::Condvar,
+ ));
+
+ self.add_waiter(waiter.clone(), guard.as_raw_mutex());
+
+ // Get a reference to the mutex and then drop the lock.
+ let mu = guard.into_inner();
+
+ // Wait to be woken up.
+ waiter.wait().await;
+
+ // Now re-acquire the lock.
+ mu.lock_from_cv().await
+ }
+
+ /// Like `wait()` but takes and returns a `MutexReadGuard` instead.
+ // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
+ // that doesn't compile.
+ #[allow(clippy::needless_lifetimes)]
+ pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> {
+ let waiter = Arc::new(Waiter::new(
+ WaiterKind::Shared,
+ cancel_waiter,
+ self as *const Condvar as usize,
+ WaitingFor::Condvar,
+ ));
+
+ self.add_waiter(waiter.clone(), guard.as_raw_mutex());
+
+ // Get a reference to the mutex and then drop the lock.
+ let mu = guard.into_inner();
+
+ // Wait to be woken up.
+ waiter.wait().await;
+
+ // Now re-acquire the lock.
+ mu.read_lock_from_cv().await
+ }
+
+ fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) {
+ // Acquire the spin lock.
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while (oldstate & SPINLOCK) != 0
+ || self.state.compare_and_swap(
+ oldstate,
+ oldstate | SPINLOCK | HAS_WAITERS,
+ Ordering::Acquire,
+ ) != oldstate
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access and the reference does not escape
+ // this function.
+ let mu = unsafe { &mut *self.mu.get() };
+ let muptr = raw_mutex as *const RawMutex as usize;
+
+ match *mu {
+ 0 => *mu = muptr,
+ p if p == muptr => {}
+ _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"),
+ }
+
+ // Safe because the spin lock guarantees exclusive access.
+ unsafe { (*self.waiters.get()).push_back(waiter) };
+
+ // Release the spin lock. Use a direct store here because no other thread can modify
+ // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
+ // because we just added a waiter.
+ self.state.store(HAS_WAITERS, Ordering::Release);
+ }
+
+ /// Notify at most one thread currently waiting on the `Condvar`.
+ ///
+ /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
+ /// `wait`.
+ ///
+ /// Unlike more traditional condition variable interfaces, this method requires a reference to
+ /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
+ /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
+ /// a reference to the `Mutex` here allows us to make some optimizations that can improve
+ /// performance by reducing unnecessary wakeups.
+ pub fn notify_one(&self) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 {
+ // No waiters.
+ return;
+ }
+
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire)
+ != oldstate
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+ let (mut wake_list, all_readers) = get_wake_list(waiters);
+
+ // Safe because the spin lock guarantees exclusive access.
+ let muptr = unsafe { (*self.mu.get()) as *const RawMutex };
+
+ let newstate = if waiters.is_empty() {
+ // Also clear the mutex associated with this Condvar since there are no longer any
+ // waiters. Safe because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ // We are releasing the spin lock and there are no more waiters so we can clear all bits
+ // in `self.state`.
+ 0
+ } else {
+ // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
+ HAS_WAITERS
+ };
+
+ // Try to transfer waiters before releasing the spin lock.
+ if !wake_list.is_empty() {
+ // Safe because there was a waiter in the queue and the thread that owns the waiter also
+ // owns a reference to the Mutex, guaranteeing that the pointer is valid.
+ unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) };
+ }
+
+ // Release the spin lock.
+ self.state.store(newstate, Ordering::Release);
+
+ // Now wake any waiters still left in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+ }
+
+ /// Notify all threads currently waiting on the `Condvar`.
+ ///
+ /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
+ ///
+ /// Unlike more traditional condition variable interfaces, this method requires a reference to
+ /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
+ /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
+ /// a reference to the `Mutex` here allows us to make some optimizations that can improve
+ /// performance by reducing unnecessary wakeups.
+ pub fn notify_all(&self) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 {
+ // No waiters.
+ return;
+ }
+
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire)
+ != oldstate
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access to `self.waiters`.
+ let mut wake_list = unsafe { (*self.waiters.get()).take() };
+
+ // Safe because the spin lock guarantees exclusive access.
+ let muptr = unsafe { (*self.mu.get()) as *const RawMutex };
+
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ // Try to transfer waiters before releasing the spin lock.
+ if !wake_list.is_empty() {
+ // Safe because there was a waiter in the queue and the thread that owns the waiter also
+ // owns a reference to the Mutex, guaranteeing that the pointer is valid.
+ unsafe { (*muptr).transfer_waiters(&mut wake_list, false) };
+ }
+
+ // Mark any waiters left as no longer waiting for the Condvar.
+ for w in &wake_list {
+ w.set_waiting_for(WaitingFor::None);
+ }
+
+ // Release the spin lock. We can clear all bits in the state since we took all the waiters.
+ self.state.store(0, Ordering::Release);
+
+ // Now wake any waiters still left in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+ }
+
+ fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while oldstate & SPINLOCK != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock provides exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let waiting_for = waiter.is_waiting_for();
+ if waiting_for == WaitingFor::Mutex {
+ // The waiter was moved to the mutex's list. Retry the cancel.
+ let set_on_release = if waiters.is_empty() {
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ 0
+ } else {
+ HAS_WAITERS
+ };
+
+ self.state.store(set_on_release, Ordering::Release);
+
+ false
+ } else {
+ // Don't drop the old waiter now as we're still holding the spin lock.
+ let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
+ // Safe because we know that the waiter is still linked and is waiting for the Condvar,
+ // which guarantees that it is still in `self.waiters`.
+ let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
+ cursor.remove()
+ } else {
+ None
+ };
+
+ let (mut wake_list, all_readers) = if wake_next || waiting_for == WaitingFor::None {
+ // Either the waiter was already woken or it's been removed from the condvar's waiter
+ // list and is going to be woken. Either way, we need to wake up another thread.
+ get_wake_list(waiters)
+ } else {
+ (WaiterList::new(WaiterAdapter::new()), false)
+ };
+
+ // Safe because the spin lock guarantees exclusive access.
+ let muptr = unsafe { (*self.mu.get()) as *const RawMutex };
+
+ // Try to transfer waiters before releasing the spin lock.
+ if !wake_list.is_empty() {
+ // Safe because there was a waiter in the queue and the thread that owns the waiter also
+ // owns a reference to the Mutex, guaranteeing that the pointer is valid.
+ unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) };
+ }
+
+ let set_on_release = if waiters.is_empty() {
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ 0
+ } else {
+ HAS_WAITERS
+ };
+
+ self.state.store(set_on_release, Ordering::Release);
+
+ // Now wake any waiters still left in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+
+ mem::drop(old_waiter);
+ true
+ }
+ }
+}
+
+unsafe impl Send for Condvar {}
+unsafe impl Sync for Condvar {}
+
+impl Default for Condvar {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// Scan `waiters` and return all waiters that should be woken up. If all waiters in the returned
+// wait list are readers then the returned bool will be true.
+//
+// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
+// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
+//
+// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
+// the rest of the list is not scanned.
+fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, bool) {
+ let mut to_wake = WaiterList::new(WaiterAdapter::new());
+ let mut cursor = waiters.front_mut();
+
+ let mut waking_readers = false;
+ let mut all_readers = true;
+ while let Some(w) = cursor.get() {
+ match w.kind() {
+ WaiterKind::Exclusive if !waking_readers => {
+ // This is the first waiter and it's a writer. No need to check the other waiters.
+ // Also mark the waiter as having been removed from the Condvar's waiter list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ all_readers = false;
+ break;
+ }
+
+ WaiterKind::Shared => {
+ // This is a reader and the first waiter in the list was not a writer so wake up all
+ // the readers in the wait list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ waking_readers = true;
+ }
+
+ WaiterKind::Exclusive => {
+ debug_assert!(waking_readers);
+ if all_readers {
+ // We are waking readers but we need to ensure that at least one writer is woken
+ // up. Since we haven't yet woken up a writer, wake up this one.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ all_readers = false;
+ } else {
+ // We are waking readers and have already woken one writer. Skip this one.
+ cursor.move_next();
+ }
+ }
+ }
+ }
+
+ (to_wake, all_readers)
+}
+
+fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) -> bool {
+ let condvar = cv as *const Condvar;
+
+ // Safe because the thread that owns the waiter being canceled must also own a reference to the
+ // Condvar, which guarantees that this pointer is valid.
+ unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::mem;
+ use std::ptr;
+ use std::rc::Rc;
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll};
+ use std::thread::{self, JoinHandle};
+ use std::time::Duration;
+
+ use futures::channel::oneshot;
+ use futures::task::{waker_ref, ArcWake};
+ use futures::{select, FutureExt};
+ use futures_executor::{LocalPool, LocalSpawner, ThreadPool};
+ use futures_util::task::LocalSpawnExt;
+
+ use crate::sync::{block_on, Mutex};
+
+ // Dummy waker used when we want to manually drive futures.
+ struct TestWaker;
+ impl ArcWake for TestWaker {
+ fn wake_by_ref(_arc_self: &Arc<Self>) {}
+ }
+
+ #[test]
+ fn smoke() {
+ let cv = Condvar::new();
+ cv.notify_one();
+ cv.notify_all();
+ }
+
+ #[test]
+ fn notify_one() {
+ let mu = Arc::new(Mutex::new(()));
+ let cv = Arc::new(Condvar::new());
+
+ let mu2 = mu.clone();
+ let cv2 = cv.clone();
+
+ let guard = block_on(mu.lock());
+ thread::spawn(move || {
+ let _g = block_on(mu2.lock());
+ cv2.notify_one();
+ });
+
+ let guard = block_on(cv.wait(guard));
+ mem::drop(guard);
+ }
+
+ #[test]
+ fn multi_mutex() {
+ const NUM_THREADS: usize = 5;
+
+ let mu = Arc::new(Mutex::new(false));
+ let cv = Arc::new(Condvar::new());
+
+ let mut threads = Vec::with_capacity(NUM_THREADS);
+ for _ in 0..NUM_THREADS {
+ let mu = mu.clone();
+ let cv = cv.clone();
+
+ threads.push(thread::spawn(move || {
+ let mut ready = block_on(mu.lock());
+ while !*ready {
+ ready = block_on(cv.wait(ready));
+ }
+ }));
+ }
+
+ let mut g = block_on(mu.lock());
+ *g = true;
+ mem::drop(g);
+ cv.notify_all();
+
+ threads
+ .into_iter()
+ .map(JoinHandle::join)
+ .collect::<thread::Result<()>>()
+ .expect("Failed to join threads");
+
+ // Now use the Condvar with a different mutex.
+ let alt_mu = Arc::new(Mutex::new(None));
+ let alt_mu2 = alt_mu.clone();
+ let cv2 = cv.clone();
+ let handle = thread::spawn(move || {
+ let mut g = block_on(alt_mu2.lock());
+ while g.is_none() {
+ g = block_on(cv2.wait(g));
+ }
+ });
+
+ let mut alt_g = block_on(alt_mu.lock());
+ *alt_g = Some(());
+ mem::drop(alt_g);
+ cv.notify_all();
+
+ handle
+ .join()
+ .expect("Failed to join thread alternate mutex");
+ }
+
+ #[test]
+ fn notify_one_single_thread_async() {
+ async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) {
+ let _g = mu.lock().await;
+ cv.notify_one();
+ }
+
+ async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
+ let mu2 = Rc::clone(&mu);
+ let cv2 = Rc::clone(&cv);
+
+ let g = mu.lock().await;
+ // Has to be spawned _after_ acquiring the lock to prevent a race
+ // where the notify happens before the waiter has acquired the lock.
+ spawner
+ .spawn_local(notify(mu2, cv2))
+ .expect("Failed to spawn `notify` task");
+ let _g = cv.wait(g).await;
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(()));
+ let cv = Rc::new(Condvar::new());
+
+ spawner
+ .spawn_local(wait(mu, cv, spawner.clone()))
+ .expect("Failed to spawn `wait` task");
+
+ ex.run();
+ }
+
+ #[test]
+ fn notify_one_multi_thread_async() {
+ async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) {
+ let _g = mu.lock().await;
+ cv.notify_one();
+ }
+
+ async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
+ let mu2 = Arc::clone(&mu);
+ let cv2 = Arc::clone(&cv);
+
+ let g = mu.lock().await;
+ // Has to be spawned _after_ acquiring the lock to prevent a race
+ // where the notify happens before the waiter has acquired the lock.
+ pool.spawn_ok(notify(mu2, cv2));
+ let _g = cv.wait(g).await;
+
+ tx.send(()).expect("Failed to send completion notification");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(()));
+ let cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ ex.spawn_ok(wait(mu, cv, tx, ex.clone()));
+
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion notification");
+ }
+
+ #[test]
+ fn notify_one_with_cancel() {
+ const TASKS: usize = 17;
+ const OBSERVERS: usize = 7;
+ const ITERATIONS: usize = 103;
+
+ async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
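+            // Volatile read so the observation of the count is not optimized away.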
+ let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
+ }
+
+ async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+ *count -= 1;
+ }
+
+ async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
+ for _ in 0..TASKS * OBSERVERS * ITERATIONS {
+ *mu.lock().await += 1;
+ cv.notify_one();
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn observe_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = observe(&mu, &cv).fuse() => {},
+ () = observe(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn decrement_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = decrement(&mu, &cv).fuse() => {},
+ () = decrement(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let cv = Arc::new(Condvar::new());
+ let alt_cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(decrement_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ for _ in 0..OBSERVERS {
+ ex.spawn_ok(observe_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
+ ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
+
+ for _ in 0..TASKS + OBSERVERS + 2 {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
+ );
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_all_with_cancel() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+ *count -= 1;
+ }
+
+ async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
+ for _ in 0..TASKS * ITERATIONS {
+ *mu.lock().await += 1;
+ cv.notify_all();
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn decrement_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = decrement(&mu, &cv).fuse() => {},
+ () = decrement(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let cv = Arc::new(Condvar::new());
+ let alt_cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(decrement_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
+ ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
+
+ for _ in 0..TASKS + 2 {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ TASKS * ITERATIONS
+ );
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
+    }
+
+    #[test]
+ fn notify_all() {
+ const THREADS: usize = 13;
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+ let (tx, rx) = channel();
+
+ let mut threads = Vec::with_capacity(THREADS);
+ for _ in 0..THREADS {
+ let mu2 = mu.clone();
+ let cv2 = cv.clone();
+ let tx2 = tx.clone();
+
+ threads.push(thread::spawn(move || {
+ let mut count = block_on(mu2.lock());
+ *count += 1;
+ if *count == THREADS {
+ tx2.send(()).unwrap();
+ }
+
+ while *count != 0 {
+ count = block_on(cv2.wait(count));
+ }
+ }));
+ }
+
+ mem::drop(tx);
+
+ // Wait till all threads have started.
+ rx.recv_timeout(Duration::from_secs(5)).unwrap();
+
+ let mut count = block_on(mu.lock());
+ *count = 0;
+ mem::drop(count);
+ cv.notify_all();
+
+ for t in threads {
+ t.join().unwrap();
+ }
+ }
+
+ #[test]
+ fn notify_all_single_thread_async() {
+ const TASKS: usize = 13;
+
+ async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) {
+ let mut count = mu.lock().await;
+ *count = 0;
+ cv.notify_all();
+ }
+
+ async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
+ let mut count = mu.lock().await;
+ *count += 1;
+ if *count == TASKS {
+ spawner
+ .spawn_local(reset(mu.clone(), cv.clone()))
+ .expect("Failed to spawn reset task");
+ }
+
+ while *count != 0 {
+ count = cv.wait(count).await;
+ }
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(0));
+ let cv = Rc::new(Condvar::new());
+
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
+ .expect("Failed to spawn watcher task");
+ }
+
+ ex.run();
+ }
+
+ #[test]
+ fn notify_all_multi_thread_async() {
+ const TASKS: usize = 13;
+
+ async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ *count = 0;
+ cv.notify_all();
+ }
+
+ async fn watcher(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ pool: ThreadPool,
+ tx: Sender<()>,
+ ) {
+ let mut count = mu.lock().await;
+ *count += 1;
+ if *count == TASKS {
+ pool.spawn_ok(reset(mu.clone(), cv.clone()));
+ }
+
+ while *count != 0 {
+ count = cv.wait(count).await;
+ }
+
+ tx.send(()).expect("Failed to send completion notification");
+ }
+
+ let pool = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion notification");
+ }
+ }
+
+ #[test]
+ fn wake_all_readers() {
+ async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) {
+ let mut ready = mu.read_lock().await;
+ while !*ready {
+ ready = cv.wait_read(ready).await;
+ }
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let cv = Arc::new(Condvar::new());
+ let mut readers = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First have all the readers wait on the Condvar.
+ for r in &mut readers {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("reader unexpectedly ready");
+ }
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ // Now make the condition true and notify the condvar. Even though we will call notify_one,
+ // all the readers should be woken up.
+ *block_on(mu.lock()) = true;
+ cv.notify_one();
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ // All readers should now be able to complete.
+ for r in &mut readers {
+ if let Poll::Pending = r.as_mut().poll(&mut cx) {
+ panic!("reader unable to complete");
+ }
+ }
+ }
+
+ #[test]
+ fn cancel_before_notify() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ *block_on(mu.lock()) = 2;
+ // Drop fut1 before notifying the cv.
+ mem::drop(fut1);
+ cv.notify_one();
+
+ // fut2 should now be ready to complete.
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ if let Poll::Pending = fut2.as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn cancel_after_notify() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ *block_on(mu.lock()) = 2;
+ cv.notify_one();
+
+ // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
+ mem::drop(fut1);
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ if let Poll::Pending = fut2.as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn cancel_after_transfer() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ let mut count = block_on(mu.lock());
+ *count = 2;
+
+ // Notify the cv while holding the lock. Only transfer one waiter.
+ cv.notify_one();
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ // Drop the lock and then the future. This should not cause fut2 to become runnable as it
+ // should still be in the Condvar's wait queue.
+ mem::drop(count);
+ mem::drop(fut1);
+
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+
+ // Now wake up fut2. Since the lock isn't held, it should wake up immediately.
+ cv.notify_one();
+ if let Poll::Pending = fut2.as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn cancel_after_transfer_and_wake() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ let mut count = block_on(mu.lock());
+ *count = 2;
+
+ // Notify the cv while holding the lock. This should transfer both waiters to the mutex's
+ // wait queue.
+ cv.notify_all();
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ mem::drop(count);
+
+ mem::drop(fut1);
+
+ if let Poll::Pending = fut2.as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn timed_wait() {
+ async fn wait_deadline(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ timeout: oneshot::Receiver<()>,
+ ) {
+ let mut count = mu.lock().await;
+
+ if *count == 0 {
+ let mut rx = timeout.fuse();
+
+ while *count == 0 {
+ select! {
+ res = rx => {
+ if let Err(e) = res {
+ panic!("Error while receiving timeout notification: {}", e);
+ }
+
+ return;
+ },
+ c = cv.wait(count).fuse() => count = c,
+ }
+ }
+ }
+
+ *count += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let (tx, rx) = oneshot::channel();
+ let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));
+
+ if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
+ panic!("wait_deadline unexpectedly ready");
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);
+
+ // Signal the channel, which should cancel the wait.
+ tx.send(()).expect("Failed to send wakeup");
+
+        // The simulated timeout has fired, so the future should now be able to complete.
+ if let Poll::Pending = wait.as_mut().poll(&mut cx) {
+ panic!("wait_deadline unable to complete in time");
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(*block_on(mu.lock()), 0);
+ }
+}
diff --git a/src/sync/mu.rs b/src/sync/mu.rs
new file mode 100644
index 0000000..26735d1
--- /dev/null
+++ b/src/sync/mu.rs
@@ -0,0 +1,2400 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::mem;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::thread::yield_now;
+
+use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor};
+
+// Set when the mutex is exclusively locked.
+const LOCKED: usize = 1 << 0;
+// Set when there are one or more threads waiting to acquire the lock.
+const HAS_WAITERS: usize = 1 << 1;
+// Set when a thread has been woken up from the wait queue. Cleared when that thread either acquires
+// the lock or adds itself back into the wait queue. Used to prevent unnecessary wake ups when a
+// thread has been removed from the wait queue but has not gotten CPU time yet.
+const DESIGNATED_WAKER: usize = 1 << 2;
+// Used to provide exclusive access to the `waiters` field in `Mutex`. Should only be held while
+// modifying the waiter list.
+const SPINLOCK: usize = 1 << 3;
+// Set when a thread that wants an exclusive lock adds itself to the wait queue. New threads
+// attempting to acquire a shared lock will be prevented from getting it when this bit is set.
+// However, this bit is ignored once a thread has gone through the wait queue at least once.
+const WRITER_WAITING: usize = 1 << 4;
+// Set when a thread has gone through the wait queue many times but has failed to acquire the lock
+// every time it is woken up. When this bit is set, all other threads are prevented from acquiring
+// the lock until the thread that set the `LONG_WAIT` bit has acquired the lock.
+const LONG_WAIT: usize = 1 << 5;
+// The bit that is added to the mutex state in order to acquire a shared lock. Since more than one
+// thread can acquire a shared lock, we cannot use a single bit. Instead we use all the remaining
+// bits in the state to track the number of threads that have acquired a shared lock.
+const READ_LOCK: usize = 1 << 8;
+// Mask used for checking if any threads currently hold a shared lock.
+const READ_MASK: usize = !0xff;
+// Mask used to check if the lock is held in either shared or exclusive mode.
+const ANY_LOCK: usize = LOCKED | READ_MASK;
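+//
+// As an illustrative sketch of how these bits compose: with three readers holding the lock and
+// one writer queued behind them, the state would be
+// `3 * READ_LOCK | HAS_WAITERS | WRITER_WAITING` (ignoring transient bits like `SPINLOCK`).
+// Masking with `READ_MASK` then yields `3 * READ_LOCK`, and `state & ANY_LOCK` is non-zero, so
+// the lock is considered held.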
+
+// The number of times the thread should just spin and attempt to re-acquire the lock.
+const SPIN_THRESHOLD: usize = 7;
+
+// The number of times the thread needs to go through the wait queue before it sets the `LONG_WAIT`
+// bit and forces all other threads to wait for it to acquire the lock. This value is set relatively
+// high so that we don't lose the benefit of having running threads unless it is absolutely
+// necessary.
+const LONG_WAIT_THRESHOLD: usize = 19;
+
+// Common methods between shared and exclusive locks.
+trait Kind {
+ // The bits that must be zero for the thread to acquire this kind of lock. If any of these bits
+ // are not zero then the thread will first spin and retry a few times before adding itself to
+ // the wait queue.
+ fn zero_to_acquire() -> usize;
+
+ // The bit that must be added in order to acquire this kind of lock. This should either be
+ // `LOCKED` or `READ_LOCK`.
+ fn add_to_acquire() -> usize;
+
+ // The bits that should be set when a thread adds itself to the wait queue while waiting to
+ // acquire this kind of lock.
+ fn set_when_waiting() -> usize;
+
+ // The bits that should be cleared when a thread acquires this kind of lock.
+ fn clear_on_acquire() -> usize;
+
+ // The waiter that a thread should use when waiting to acquire this kind of lock.
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter>;
+}
+
+// A lock type for shared read-only access to the data. More than one thread may hold this kind of
+// lock simultaneously.
+struct Shared;
+
+impl Kind for Shared {
+ fn zero_to_acquire() -> usize {
+ LOCKED | WRITER_WAITING | LONG_WAIT
+ }
+
+ fn add_to_acquire() -> usize {
+ READ_LOCK
+ }
+
+ fn set_when_waiting() -> usize {
+ 0
+ }
+
+ fn clear_on_acquire() -> usize {
+ 0
+ }
+
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter> {
+ Arc::new(Waiter::new(
+ WaiterKind::Shared,
+ cancel_waiter,
+ raw as *const RawMutex as usize,
+ WaitingFor::Mutex,
+ ))
+ }
+}
+
+// A lock type for mutually exclusive read-write access to the data. Only one thread can hold this
+// kind of lock at a time.
+struct Exclusive;
+
+impl Kind for Exclusive {
+ fn zero_to_acquire() -> usize {
+ LOCKED | READ_MASK | LONG_WAIT
+ }
+
+ fn add_to_acquire() -> usize {
+ LOCKED
+ }
+
+ fn set_when_waiting() -> usize {
+ WRITER_WAITING
+ }
+
+ fn clear_on_acquire() -> usize {
+ WRITER_WAITING
+ }
+
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter> {
+ Arc::new(Waiter::new(
+ WaiterKind::Exclusive,
+ cancel_waiter,
+ raw as *const RawMutex as usize,
+ WaitingFor::Mutex,
+ ))
+ }
+}
+
+// Scan `waiters` and return the ones that should be woken up. Also returns any bits that should be
+// set in the mutex state when the current thread releases the spin lock protecting the waiter list.
+//
+// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
+// waiting for a shared lock are also woken up. If any waiters waiting for an exclusive lock are
+// found when iterating through the list, then the returned `usize` contains the `WRITER_WAITING`
+// bit, which should be set when the thread releases the spin lock.
+//
+// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
+// no bits are set in the returned `usize`.
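+//
+// As a concrete illustration of the above: for a list of [Shared, Shared, Exclusive, Shared], all
+// three readers are returned and `WRITER_WAITING` is included in the returned bits; for a list
+// whose first waiter is Exclusive, only that writer is returned and no extra bits are set.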
+fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, usize) {
+ let mut to_wake = WaiterList::new(WaiterAdapter::new());
+ let mut set_on_release = 0;
+ let mut cursor = waiters.front_mut();
+
+ let mut waking_readers = false;
+ while let Some(w) = cursor.get() {
+ match w.kind() {
+ WaiterKind::Exclusive if !waking_readers => {
+ // This is the first waiter and it's a writer. No need to check the other waiters.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ break;
+ }
+
+ WaiterKind::Shared => {
+ // This is a reader and the first waiter in the list was not a writer so wake up all
+ // the readers in the wait list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ waking_readers = true;
+ }
+
+ WaiterKind::Exclusive => {
+ // We found a writer while looking for more readers to wake up. Set the
+ // `WRITER_WAITING` bit to prevent any new readers from acquiring the lock. All
+ // readers currently in the wait list will ignore this bit since they already waited
+ // once.
+ set_on_release |= WRITER_WAITING;
+ cursor.move_next();
+ }
+ }
+ }
+
+ (to_wake, set_on_release)
+}
+
+#[inline]
+fn cpu_relax(iterations: usize) {
+ for _ in 0..iterations {
+ spin_loop_hint();
+ }
+}
+
+pub(crate) struct RawMutex {
+ state: AtomicUsize,
+ waiters: UnsafeCell<WaiterList>,
+}
+
+impl RawMutex {
+ pub fn new() -> RawMutex {
+ RawMutex {
+ state: AtomicUsize::new(0),
+ waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
+ }
+ }
+
+ #[inline]
+ pub async fn lock(&self) {
+ match self
+ .state
+ .compare_exchange_weak(0, LOCKED, Ordering::Acquire, Ordering::Relaxed)
+ {
+ Ok(_) => {}
+ Err(oldstate) => {
+ // If any bits that should be zero are not zero or if we fail to acquire the lock
+ // with a single compare_exchange then go through the slow path.
+ if (oldstate & Exclusive::zero_to_acquire()) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + Exclusive::add_to_acquire())
+ & !Exclusive::clear_on_acquire(),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ self.lock_slow::<Exclusive>(0, 0).await;
+ }
+ }
+ }
+ }
+
+ #[inline]
+ pub async fn read_lock(&self) {
+ match self
+ .state
+ .compare_exchange_weak(0, READ_LOCK, Ordering::Acquire, Ordering::Relaxed)
+ {
+ Ok(_) => {}
+ Err(oldstate) => {
+ if (oldstate & Shared::zero_to_acquire()) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + Shared::add_to_acquire()) & !Shared::clear_on_acquire(),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ self.lock_slow::<Shared>(0, 0).await;
+ }
+ }
+ }
+ }
+
+ // Slow path for acquiring the lock. `clear` should contain any bits that need to be cleared
+ // when the lock is acquired. Any bits set in `zero_mask` are cleared from the bits returned by
+ // `K::zero_to_acquire()`.
+ #[cold]
+ async fn lock_slow<K: Kind>(&self, mut clear: usize, zero_mask: usize) {
+ let mut zero_to_acquire = K::zero_to_acquire() & !zero_mask;
+
+ let mut spin_count = 0;
+ let mut wait_count = 0;
+ let mut waiter = None;
+ loop {
+ let oldstate = self.state.load(Ordering::Relaxed);
+ // If all the bits in `zero_to_acquire` are actually zero then try to acquire the lock
+ // directly.
+ if (oldstate & zero_to_acquire) == 0 {
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + K::add_to_acquire()) & !(clear | K::clear_on_acquire()),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ return;
+ }
+ } else if (oldstate & SPINLOCK) == 0 {
+ // The mutex is locked and the spin lock is available. Try to add this thread
+ // to the waiter queue.
+ let w = waiter.get_or_insert_with(|| K::new_waiter(self));
+ w.reset(WaitingFor::Mutex);
+
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate | SPINLOCK | HAS_WAITERS | K::set_when_waiting()) & !clear,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ let mut set_on_release = 0;
+
+ // Safe because we have acquired the spin lock and it provides exclusive
+ // access to the waiter queue.
+ if wait_count < LONG_WAIT_THRESHOLD {
+ // Add the waiter to the back of the queue.
+ unsafe { (*self.waiters.get()).push_back(w.clone()) };
+ } else {
+ // This waiter has gone through the queue too many times. Put it in the
+ // front of the queue and block all other threads from acquiring the lock
+ // until this one has acquired it at least once.
+ unsafe { (*self.waiters.get()).push_front(w.clone()) };
+
+ // Set the LONG_WAIT bit to prevent all other threads from acquiring the
+ // lock.
+ set_on_release |= LONG_WAIT;
+
+ // Make sure we clear the LONG_WAIT bit when we do finally get the lock.
+ clear |= LONG_WAIT;
+
+ // Since we set the LONG_WAIT bit we shouldn't allow that bit to prevent us
+ // from acquiring the lock.
+ zero_to_acquire &= !LONG_WAIT;
+ }
+
+ // Release the spin lock.
+ let mut state = oldstate;
+ loop {
+ match self.state.compare_exchange_weak(
+ state,
+ (state | set_on_release) & !SPINLOCK,
+ Ordering::Release,
+ Ordering::Relaxed,
+ ) {
+ Ok(_) => break,
+ Err(w) => state = w,
+ }
+ }
+
+ // Now wait until we are woken.
+ w.wait().await;
+
+ // The `DESIGNATED_WAKER` bit gets set when this thread is woken up by the
+ // thread that originally held the lock. While this bit is set, no other waiters
+ // will be woken up so it's important to clear it the next time we try to
+ // acquire the main lock or the spin lock.
+ clear |= DESIGNATED_WAKER;
+
+ // Now that the thread has waited once, we no longer care if there is a writer
+ // waiting. Only the limits of mutual exclusion can prevent us from acquiring
+ // the lock.
+ zero_to_acquire &= !WRITER_WAITING;
+
+ // Reset the spin count since we just went through the wait queue.
+ spin_count = 0;
+
+ // Increment the wait count since we went through the wait queue.
+ wait_count += 1;
+
+ // Skip the `cpu_relax` below.
+ continue;
+ }
+ }
+
+ // Both the lock and the spin lock are held by one or more other threads. First, we'll
+ // spin a few times in case we can acquire the lock or the spin lock. If that fails then
+ // we yield because we might be preventing the threads that do hold the 2 locks from
+ // getting cpu time.
+ if spin_count < SPIN_THRESHOLD {
+ cpu_relax(1 << spin_count);
+ spin_count += 1;
+ } else {
+ yield_now();
+ }
+ }
+ }
+
+ #[inline]
+ pub fn unlock(&self) {
+ // Fast path, if possible. We can directly clear the locked bit since we have exclusive
+ // access to the mutex.
+ let oldstate = self.state.fetch_sub(LOCKED, Ordering::Release);
+
+        // Panic (in debug builds) if we just tried to unlock a mutex that wasn't held by this
+        // thread. This shouldn't really be possible since `unlock` is not a public method.
+ debug_assert_eq!(
+ oldstate & READ_MASK,
+ 0,
+ "`unlock` called on mutex held in read-mode"
+ );
+ debug_assert_ne!(
+ oldstate & LOCKED,
+ 0,
+ "`unlock` called on mutex not held in write-mode"
+ );
+
+ if (oldstate & HAS_WAITERS) != 0 && (oldstate & DESIGNATED_WAKER) == 0 {
+ // The oldstate has waiters but no designated waker has been chosen yet.
+ self.unlock_slow();
+ }
+ }
+
+ #[inline]
+ pub fn read_unlock(&self) {
+ // Fast path, if possible. We can directly subtract the READ_LOCK bit since we had
+ // previously added it.
+ let oldstate = self.state.fetch_sub(READ_LOCK, Ordering::Release);
+
+ debug_assert_eq!(
+ oldstate & LOCKED,
+ 0,
+ "`read_unlock` called on mutex held in write-mode"
+ );
+ debug_assert_ne!(
+ oldstate & READ_MASK,
+ 0,
+ "`read_unlock` called on mutex not held in read-mode"
+ );
+
+ if (oldstate & HAS_WAITERS) != 0
+ && (oldstate & DESIGNATED_WAKER) == 0
+ && (oldstate & READ_MASK) == READ_LOCK
+ {
+ // There are waiters, no designated waker has been chosen yet, and the last reader is
+ // unlocking so we have to take the slow path.
+ self.unlock_slow();
+ }
+ }
+
+ #[cold]
+ fn unlock_slow(&self) {
+ let mut spin_count = 0;
+
+ loop {
+ let oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 || (oldstate & DESIGNATED_WAKER) != 0 {
+ // No more waiters or a designated waker has been chosen. Nothing left for us to do.
+ return;
+ } else if (oldstate & SPINLOCK) == 0 {
+ // The spin lock is not held by another thread. Try to acquire it. Also set the
+ // `DESIGNATED_WAKER` bit since we are likely going to wake up one or more threads.
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK | DESIGNATED_WAKER,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ // Acquired the spinlock. Try to wake a waiter. We may also end up wanting to
+                    // clear the HAS_WAITERS and DESIGNATED_WAKER bits, so start collecting the
+                    // bits to be cleared.
+ let mut clear = SPINLOCK;
+
+ // Safe because the spinlock guarantees exclusive access to the waiter list and
+ // the reference does not escape this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+ let (wake_list, set_on_release) = get_wake_list(waiters);
+
+ // If the waiter list is now empty, clear the HAS_WAITERS bit.
+ if waiters.is_empty() {
+ clear |= HAS_WAITERS;
+ }
+
+ if wake_list.is_empty() {
+ // Since we are not going to wake any waiters clear the DESIGNATED_WAKER bit
+ // that we set when we acquired the spin lock.
+ clear |= DESIGNATED_WAKER;
+ }
+
+ // Release the spin lock and clear any other bits as necessary. Also, set any
+ // bits returned by `get_wake_list`. For now, this is just the `WRITER_WAITING`
+ // bit, which needs to be set when we are waking up a bunch of readers and there
+ // are still writers in the wait queue. This will prevent any readers that
+ // aren't in `wake_list` from acquiring the read lock.
+ let mut state = oldstate;
+ loop {
+ match self.state.compare_exchange_weak(
+ state,
+ (state | set_on_release) & !clear,
+ Ordering::Release,
+ Ordering::Relaxed,
+ ) {
+ Ok(_) => break,
+ Err(w) => state = w,
+ }
+ }
+
+ // Now wake the waiters, if any.
+ for w in wake_list {
+ w.wake();
+ }
+
+ // We're done.
+ return;
+ }
+ }
+
+ // Spin and try again. It's ok to block here as we have already released the lock.
+ if spin_count < SPIN_THRESHOLD {
+ cpu_relax(1 << spin_count);
+ spin_count += 1;
+ } else {
+ yield_now();
+ }
+ }
+ }
+
+ // Transfer waiters from the `Condvar` wait list to the `Mutex` wait list. `all_readers` may
+ // be set to true if all waiters are waiting to acquire a shared lock but should not be true if
+ // there is even one waiter waiting on an exclusive lock.
+ //
+ // This is similar to what the `FUTEX_CMP_REQUEUE` flag does on linux.
+ pub fn transfer_waiters(&self, new_waiters: &mut WaiterList, all_readers: bool) {
+ if new_waiters.is_empty() {
+ return;
+ }
+
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ let can_acquire_read_lock = (oldstate & Shared::zero_to_acquire()) == 0;
+
+ // The lock needs to be held in some mode or else the waiters we transfer now may never get
+ // woken up. Additionally, if all the new waiters are readers and can acquire the lock now
+ // then we can just wake them up.
+ if (oldstate & ANY_LOCK) == 0 || (all_readers && can_acquire_read_lock) {
+ // Nothing to do here. The Condvar will wake up all the waiters left in `new_waiters`.
+ return;
+ }
+
+ if (oldstate & SPINLOCK) == 0
+ && self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK | HAS_WAITERS,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ let mut transferred_writer = false;
+
+ // Safe because the spin lock guarantees exclusive access and the reference does not
+ // escape this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let mut current = new_waiters.front_mut();
+ while let Some(w) = current.get() {
+ match w.kind() {
+ WaiterKind::Shared => {
+ if can_acquire_read_lock {
+ current.move_next();
+ } else {
+ // We need to update the cancellation function since we're moving this
+                            // waiter into our queue. Also update the waiting-for state to
+                            // indicate that it is now in the Mutex's waiter list.
+ let w = current.remove().unwrap();
+ w.set_cancel(cancel_waiter, self as *const RawMutex as usize);
+ w.set_waiting_for(WaitingFor::Mutex);
+ waiters.push_back(w);
+ }
+ }
+ WaiterKind::Exclusive => {
+ transferred_writer = true;
+ // We need to update the cancellation function since we're moving this
+                        // waiter into our queue. Also update the waiting-for state to indicate
+                        // that it is now in the Mutex's waiter list.
+ let w = current.remove().unwrap();
+ w.set_cancel(cancel_waiter, self as *const RawMutex as usize);
+ w.set_waiting_for(WaitingFor::Mutex);
+ waiters.push_back(w);
+ }
+ }
+ }
+
+ let set_on_release = if transferred_writer {
+ WRITER_WAITING
+ } else {
+ 0
+ };
+
+ // If we didn't actually transfer any waiters, clear the HAS_WAITERS bit that we set
+ // earlier when we acquired the spin lock.
+ let clear = if waiters.is_empty() {
+ SPINLOCK | HAS_WAITERS
+ } else {
+ SPINLOCK
+ };
+
+ while self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate | set_on_release) & !clear,
+ Ordering::Release,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+ }
+
+ // The Condvar will wake up any waiters still left in the queue.
+ }
+
+ fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while oldstate & SPINLOCK != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock provides exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let mut clear = SPINLOCK;
+
+ // If we are about to remove the first waiter in the wait list, then clear the LONG_WAIT
+ // bit. Also clear the bit if we are going to be waking some other waiters. In this case the
+ // waiter that set the bit may have already been removed from the waiter list (and could be
+ // the one that is currently being dropped). If it is still in the waiter list then clearing
+ // this bit may starve it for one more iteration through the lock_slow() loop, whereas not
+ // clearing this bit could cause a deadlock if the waiter that set it is the one that is
+ // being dropped.
+ if wake_next
+ || waiters
+ .front()
+ .get()
+ .map(|front| front as *const Waiter == waiter as *const Waiter)
+ .unwrap_or(false)
+ {
+ clear |= LONG_WAIT;
+ }
+
+ // Don't drop the old waiter while holding the spin lock.
+ let old_waiter = if waiter.is_linked() && waiter.is_waiting_for() == WaitingFor::Mutex {
+            // We know that the waiter is still linked and is waiting for the mutex, which
+ // guarantees that it is still linked into `self.waiters`.
+ let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
+ cursor.remove()
+ } else {
+ None
+ };
+
+ let (wake_list, set_on_release) =
+ if wake_next || waiter.is_waiting_for() == WaitingFor::None {
+ // Either the waiter was already woken or it's been removed from the mutex's waiter
+ // list and is going to be woken. Either way, we need to wake up another thread.
+ get_wake_list(waiters)
+ } else {
+ (WaiterList::new(WaiterAdapter::new()), 0)
+ };
+
+ if waiters.is_empty() {
+ clear |= HAS_WAITERS;
+ }
+
+ if wake_list.is_empty() {
+ // We're not waking any other threads so clear the DESIGNATED_WAKER bit. In the worst
+ // case this leads to an additional thread being woken up but we risk a deadlock if we
+ // don't clear it.
+ clear |= DESIGNATED_WAKER;
+ }
+
+ if let WaiterKind::Exclusive = waiter.kind() {
+ // The waiter being dropped is a writer so clear the writer waiting bit for now. If we
+ // found more writers in the list while fetching waiters to wake up then this bit will
+ // be set again via `set_on_release`.
+ clear |= WRITER_WAITING;
+ }
+
+ while self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate & !clear) | set_on_release,
+ Ordering::Release,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ for w in wake_list {
+ w.wake();
+ }
+
+ mem::drop(old_waiter);
+
+        // Canceling a waiter is always successful.
+ true
+ }
+}
+
+unsafe impl Send for RawMutex {}
+unsafe impl Sync for RawMutex {}
+
+fn cancel_waiter(raw: usize, waiter: &Waiter, wake_next: bool) -> bool {
+ let raw_mutex = raw as *const RawMutex;
+
+ // Safe because the thread that owns the waiter that is being canceled must
+ // also own a reference to the mutex, which ensures that this pointer is
+ // valid.
+ unsafe { (*raw_mutex).cancel_waiter(waiter, wake_next) }
+}
+
+/// A high-level primitive that provides safe, mutable access to a shared resource.
+///
+/// Unlike more traditional mutexes, `Mutex` can safely provide both shared, immutable access (via
+/// `read_lock()`) and exclusive, mutable access (via `lock()`) to an underlying resource with no
+/// loss of performance.
+///
+/// # Poisoning
+///
+/// `Mutex` does not support lock poisoning, so if a thread panics while holding the lock, the
+/// data it was protecting (which may be in an inconsistent state) will still be accessible to
+/// other threads in your program. If you need to guarantee that
+/// other threads cannot access poisoned data then you may wish to wrap this `Mutex` inside another
+/// type that provides the poisoning feature. See the implementation of `std::sync::Mutex` for an
+/// example of this.
+///
+/// # Fairness
+///
+/// This `Mutex` implementation does not guarantee that threads will acquire the lock in the same
+/// order that they call `lock()` or `read_lock()`. However it will attempt to prevent long-term
+/// starvation: if a thread repeatedly fails to acquire the lock beyond a threshold then all other
+/// threads will fail to acquire the lock until the starved thread has acquired it.
+///
+/// Similarly, this `Mutex` will attempt to balance reader and writer threads: once there is a
+/// writer thread waiting to acquire the lock no new reader threads will be allowed to acquire it.
+/// However, any reader threads that were already waiting will still be allowed to acquire it.
+///
+/// # Examples
+///
+/// ```edition2018
+/// use std::sync::Arc;
+/// use std::thread;
+/// use std::sync::mpsc::channel;
+///
+/// use libchromeos::sync::{block_on, Mutex};
+///
+/// const N: usize = 10;
+///
+/// // Spawn a few threads to increment a shared variable (non-atomically), and
+/// // let the main thread know once all increments are done.
+/// //
+/// // Here we're using an Arc to share memory among threads, and the data inside
+/// // the Arc is protected with a mutex.
+/// let data = Arc::new(Mutex::new(0));
+///
+/// let (tx, rx) = channel();
+/// for _ in 0..N {
+/// let (data, tx) = (Arc::clone(&data), tx.clone());
+/// thread::spawn(move || {
+/// // The shared state can only be accessed once the lock is held.
+/// // Our non-atomic increment is safe because we're the only thread
+/// // which can access the shared state when the lock is held.
+/// let mut data = block_on(data.lock());
+/// *data += 1;
+/// if *data == N {
+/// tx.send(()).unwrap();
+/// }
+/// // the lock is unlocked here when `data` goes out of scope.
+/// });
+/// }
+///
+/// rx.recv().unwrap();
+/// ```
+pub struct Mutex<T: ?Sized> {
+ raw: RawMutex,
+ value: UnsafeCell<T>,
+}
+
+impl<T> Mutex<T> {
+ /// Create a new, unlocked `Mutex` ready for use.
+ pub fn new(v: T) -> Mutex<T> {
+ Mutex {
+ raw: RawMutex::new(),
+ value: UnsafeCell::new(v),
+ }
+ }
+
+ /// Consume the `Mutex` and return the contained value. This method does not perform any locking
+ /// as the compiler will guarantee that there are no other references to `self` and the caller
+ /// owns the `Mutex`.
+ pub fn into_inner(self) -> T {
+ // Don't need to acquire the lock because the compiler guarantees that there are
+ // no references to `self`.
+ self.value.into_inner()
+ }
+}
+
+impl<T: ?Sized> Mutex<T> {
+ /// Acquires exclusive, mutable access to the resource protected by the `Mutex`, blocking the
+ /// current thread until it is able to do so. Upon returning, the current thread will be the
+ /// only thread with access to the resource. The `Mutex` will be released when the returned
+ /// `MutexGuard` is dropped.
+ ///
+ /// Calling `lock()` while holding a `MutexGuard` or a `MutexReadGuard` will cause a deadlock.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Mutex` is acquired.
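+    ///
+    /// A minimal sketch of acquiring the lock from outside an async context:
+    ///
+    /// ```edition2018
+    /// use libchromeos::sync::{block_on, Mutex};
+    ///
+    /// let mu = Mutex::new(0);
+    /// *block_on(mu.lock()) += 1;
+    /// assert_eq!(*block_on(mu.lock()), 1);
+    /// ```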
+ #[inline]
+ pub async fn lock(&self) -> MutexGuard<'_, T> {
+ self.raw.lock().await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexGuard {
+ mu: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ /// Acquires shared, immutable access to the resource protected by the `Mutex`, blocking the
+ /// current thread until it is able to do so. Upon returning there may be other threads that
+ /// also have immutable access to the resource but there will not be any threads that have
+ /// mutable access to the resource. When the returned `MutexReadGuard` is dropped the thread
+ /// releases its access to the resource.
+ ///
+ /// Calling `read_lock()` while holding a `MutexReadGuard` may deadlock. Calling `read_lock()`
+ /// while holding a `MutexGuard` will deadlock.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Mutex` is acquired.
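+    ///
+    /// A minimal sketch of taking the shared lock from outside an async context:
+    ///
+    /// ```edition2018
+    /// use libchromeos::sync::{block_on, Mutex};
+    ///
+    /// let mu = Mutex::new(vec![1, 2, 3]);
+    /// let data = block_on(mu.read_lock());
+    /// assert_eq!(data.len(), 3);
+    /// ```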
+ #[inline]
+ pub async fn read_lock(&self) -> MutexReadGuard<'_, T> {
+ self.raw.read_lock().await;
+
+ // Safe because we have shared read-only access to `self.value`.
+ MutexReadGuard {
+ mu: self,
+ value: unsafe { &*self.value.get() },
+ }
+ }
+
+ // Called from `Condvar::wait` when the thread wants to reacquire the lock. Since we may
+ // directly transfer waiters from the `Condvar` wait list to the `Mutex` wait list (see
+    // `RawMutex::transfer_waiters` above), we cannot call `Mutex::lock` as we also need to clear the
+ // `DESIGNATED_WAKER` bit when acquiring the lock. Not doing so will prevent us from waking up
+ // any other threads in the wait list.
+ #[inline]
+ pub(crate) async fn lock_from_cv(&self) -> MutexGuard<'_, T> {
+ self.raw.lock_slow::<Exclusive>(DESIGNATED_WAKER, 0).await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexGuard {
+ mu: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ // Like `lock_from_cv` but for acquiring a shared lock.
+ #[inline]
+ pub(crate) async fn read_lock_from_cv(&self) -> MutexReadGuard<'_, T> {
+ // Threads that have waited in the Condvar's waiter list don't have to care if there is a
+ // writer waiting. This also prevents a deadlock in the following case:
+ //
+ // * Thread A holds a write lock.
+ // * Thread B is in the mutex's waiter list, also waiting on a write lock.
+ // * Threads C, D, and E are in the condvar's waiter list. C and D want a read lock; E
+ // wants a write lock.
+ // * A calls `cv.notify_all()` while still holding the lock, which transfers C, D, and E
+ // onto the mutex's wait list.
+ // * A releases the lock, which wakes up B.
+ // * B acquires the lock, does some work, and releases the lock. This wakes up C and D.
+ // However, when iterating through the waiter list we find E, which is waiting for a
+ // write lock so we set the WRITER_WAITING bit.
+ // * C and D go through this function to acquire the lock. If we didn't clear the
+ // WRITER_WAITING bit from the zero_to_acquire set then it would prevent C and D from
+ // acquiring the lock and they would add themselves back into the waiter list.
+ // * Now C, D, and E will sit in the waiter list indefinitely unless some other thread
+ // comes along and acquires the lock. On release, it would wake up E and everything would
+ // go back to normal.
+ self.raw
+ .lock_slow::<Shared>(DESIGNATED_WAKER, WRITER_WAITING)
+ .await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexReadGuard {
+ mu: self,
+ value: unsafe { &*self.value.get() },
+ }
+ }
+
+ #[inline]
+ fn unlock(&self) {
+ self.raw.unlock();
+ }
+
+ #[inline]
+ fn read_unlock(&self) {
+ self.raw.read_unlock();
+ }
+
+ pub fn get_mut(&mut self) -> &mut T {
+        // Safe because the compiler statically guarantees that there are no other references to
+        // `self`.
+ // This is also why we don't need to acquire the lock first.
+ unsafe { &mut *self.value.get() }
+ }
+}
+
+unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
+unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
+
+impl<T: ?Sized + Default> Default for Mutex<T> {
+ fn default() -> Self {
+ Self::new(Default::default())
+ }
+}
+
+impl<T> From<T> for Mutex<T> {
+ fn from(source: T) -> Self {
+ Self::new(source)
+ }
+}
+
+/// An RAII implementation of a "scoped exclusive lock" for a `Mutex`. When this structure is
+/// dropped, the lock will be released. The resource protected by the `Mutex` can be accessed via
+/// the `Deref` and `DerefMut` implementations of this structure.
+pub struct MutexGuard<'a, T: ?Sized + 'a> {
+ mu: &'a Mutex<T>,
+ value: &'a mut T,
+}
+
+impl<'a, T: ?Sized> MutexGuard<'a, T> {
+ pub(crate) fn into_inner(self) -> &'a Mutex<T> {
+ self.mu
+ }
+
+ pub(crate) fn as_raw_mutex(&self) -> &RawMutex {
+ &self.mu.raw
+ }
+}
+
+impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> {
+ fn drop(&mut self) {
+ self.mu.unlock()
+ }
+}
+
+/// An RAII implementation of a "scoped shared lock" for a `Mutex`. When this structure is dropped,
+/// the lock will be released. The resource protected by the `Mutex` can be accessed via the `Deref`
+/// implementation of this structure.
+pub struct MutexReadGuard<'a, T: ?Sized + 'a> {
+ mu: &'a Mutex<T>,
+ value: &'a T,
+}
+
+impl<'a, T: ?Sized> MutexReadGuard<'a, T> {
+ pub(crate) fn into_inner(self) -> &'a Mutex<T> {
+ self.mu
+ }
+
+ pub(crate) fn as_raw_mutex(&self) -> &RawMutex {
+ &self.mu.raw
+ }
+}
+
+impl<'a, T: ?Sized> Deref for MutexReadGuard<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for MutexReadGuard<'a, T> {
+ fn drop(&mut self) {
+ self.mu.read_unlock()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::mem;
+ use std::pin::Pin;
+ use std::rc::Rc;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll, Waker};
+ use std::thread;
+ use std::time::Duration;
+
+ use futures::channel::oneshot;
+ use futures::task::{waker_ref, ArcWake};
+ use futures::{pending, select, FutureExt};
+ use futures_executor::{LocalPool, ThreadPool};
+ use futures_util::task::LocalSpawnExt;
+
+ use crate::sync::{block_on, Condvar, SpinLock};
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct NonCopy(u32);
+
+ // Dummy waker used when we want to manually drive futures.
+ struct TestWaker;
+ impl ArcWake for TestWaker {
+ fn wake_by_ref(_arc_self: &Arc<Self>) {}
+ }
+
+ #[test]
+ fn it_works() {
+ let mu = Mutex::new(NonCopy(13));
+
+ assert_eq!(*block_on(mu.lock()), NonCopy(13));
+ }
+
+ #[test]
+ fn smoke() {
+ let mu = Mutex::new(NonCopy(7));
+
+ mem::drop(block_on(mu.lock()));
+ mem::drop(block_on(mu.lock()));
+ }
+
+ #[test]
+ fn rw_smoke() {
+ let mu = Mutex::new(NonCopy(7));
+
+ mem::drop(block_on(mu.lock()));
+ mem::drop(block_on(mu.read_lock()));
+ mem::drop((block_on(mu.read_lock()), block_on(mu.read_lock())));
+ mem::drop(block_on(mu.lock()));
+ }
+
+ #[test]
+ fn async_smoke() {
+ async fn lock(mu: Rc<Mutex<NonCopy>>) {
+ mu.lock().await;
+ }
+
+ async fn read_lock(mu: Rc<Mutex<NonCopy>>) {
+ mu.read_lock().await;
+ }
+
+ async fn double_read_lock(mu: Rc<Mutex<NonCopy>>) {
+ let first = mu.read_lock().await;
+ mu.read_lock().await;
+
+ // Make sure first lives past the second read lock.
+ first.as_raw_mutex();
+ }
+
+ let mu = Rc::new(Mutex::new(NonCopy(7)));
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ spawner
+ .spawn_local(lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(read_lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(double_read_lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+
+ ex.run();
+ }
+
+ #[test]
+ fn send() {
+ let mu = Mutex::new(NonCopy(19));
+
+ thread::spawn(move || {
+ let value = block_on(mu.lock());
+ assert_eq!(*value, NonCopy(19));
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_nested() {
+ // Tests nested mutexes and access to underlying data.
+ let mu = Mutex::new(1);
+ let arc = Arc::new(Mutex::new(mu));
+ thread::spawn(move || {
+ let nested = block_on(arc.lock());
+ let lock2 = block_on(nested.lock());
+ assert_eq!(*lock2, 1);
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_access_in_unwind() {
+ let arc = Arc::new(Mutex::new(1));
+ let arc2 = arc.clone();
+ thread::spawn(move || {
+ struct Unwinder {
+ i: Arc<Mutex<i32>>,
+ }
+ impl Drop for Unwinder {
+ fn drop(&mut self) {
+ *block_on(self.i.lock()) += 1;
+ }
+ }
+ let _u = Unwinder { i: arc2 };
+ panic!();
+ })
+ .join()
+ .expect_err("thread did not panic");
+ let lock = block_on(arc.lock());
+ assert_eq!(*lock, 2);
+ }
+
+ #[test]
+ fn unsized_value() {
+ let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]);
+ {
+ let b = &mut *block_on(mutex.lock());
+ b[0] = 4;
+ b[2] = 5;
+ }
+ let expected: &[i32] = &[4, 2, 5];
+ assert_eq!(&*block_on(mutex.lock()), expected);
+    }
+
+    #[test]
+ fn high_contention() {
+ const THREADS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ let mut threads = Vec::with_capacity(THREADS);
+
+ let mu = Arc::new(Mutex::new(0usize));
+ for _ in 0..THREADS {
+ let mu2 = mu.clone();
+ threads.push(thread::spawn(move || {
+ for _ in 0..ITERATIONS {
+ *block_on(mu2.lock()) += 1;
+ }
+ }));
+ }
+
+ for t in threads.into_iter() {
+ t.join().unwrap();
+ }
+
+ assert_eq!(*block_on(mu.read_lock()), THREADS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn high_contention_with_cancel() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ async fn increment(mu: Arc<Mutex<usize>>, alt_mu: Arc<Mutex<usize>>, tx: Sender<()>) {
+ for _ in 0..ITERATIONS {
+ select! {
+ mut count = mu.lock().fuse() => *count += 1,
+ mut count = alt_mu.lock().fuse() => *count += 1,
+ }
+ }
+ tx.send(()).expect("Failed to send completion signal");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&alt_mu), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ TASKS * ITERATIONS
+ );
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn single_thread_async() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ // Async closures are unstable.
+ async fn increment(mu: Rc<Mutex<usize>>) {
+ for _ in 0..ITERATIONS {
+ *mu.lock().await += 1;
+ }
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(0usize));
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(increment(Rc::clone(&mu)))
+ .expect("Failed to spawn task");
+ }
+
+ ex.run();
+
+ assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn multi_thread_async() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ // Async closures are unstable.
+ async fn increment(mu: Arc<Mutex<usize>>, tx: Sender<()>) {
+ for _ in 0..ITERATIONS {
+ *mu.lock().await += 1;
+ }
+ tx.send(()).expect("Failed to send completion signal");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(increment(Arc::clone(&mu), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion signal");
+ }
+ assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn get_mut() {
+ let mut mu = Mutex::new(NonCopy(13));
+ *mu.get_mut() = NonCopy(17);
+
+ assert_eq!(mu.into_inner(), NonCopy(17));
+ }
+
+ #[test]
+ fn into_inner() {
+ let mu = Mutex::new(NonCopy(29));
+ assert_eq!(mu.into_inner(), NonCopy(29));
+ }
+
+ #[test]
+ fn into_inner_drop() {
+ struct NeedsDrop(Arc<AtomicUsize>);
+ impl Drop for NeedsDrop {
+ fn drop(&mut self) {
+ self.0.fetch_add(1, Ordering::AcqRel);
+ }
+ }
+
+ let value = Arc::new(AtomicUsize::new(0));
+ let needs_drop = Mutex::new(NeedsDrop(value.clone()));
+ assert_eq!(value.load(Ordering::Acquire), 0);
+
+ {
+ let inner = needs_drop.into_inner();
+ assert_eq!(inner.0.load(Ordering::Acquire), 0);
+ }
+
+ assert_eq!(value.load(Ordering::Acquire), 1);
+ }
+
+ #[test]
+ fn rw_arc() {
+ const THREADS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Arc::new(Mutex::new(0isize));
+ let mu2 = mu.clone();
+
+ let (tx, rx) = channel();
+ thread::spawn(move || {
+ let mut guard = block_on(mu2.lock());
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ thread::yield_now();
+ *guard = tmp + 1;
+ }
+ tx.send(()).unwrap();
+ });
+
+ let mut readers = Vec::with_capacity(10);
+ for _ in 0..THREADS {
+ let mu3 = mu.clone();
+ let handle = thread::spawn(move || {
+ let guard = block_on(mu3.read_lock());
+ assert!(*guard >= 0);
+ });
+
+ readers.push(handle);
+ }
+
+ // Wait for the readers to finish their checks.
+ for r in readers {
+ r.join().expect("One or more readers saw a negative value");
+ }
+
+ // Wait for the writer to finish.
+ rx.recv_timeout(Duration::from_secs(5)).unwrap();
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn rw_single_thread_async() {
+        // A Future that returns `Poll::Pending` the first time it is polled and `Poll::Ready` every
+ // time after that.
+ struct TestFuture {
+ polled: bool,
+ waker: Arc<SpinLock<Option<Waker>>>,
+ }
+
+ impl Future for TestFuture {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ if self.polled {
+ Poll::Ready(())
+ } else {
+ self.polled = true;
+ *self.waker.lock() = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+ }
+
+ fn wake_future(waker: Arc<SpinLock<Option<Waker>>>) {
+ loop {
+ if let Some(waker) = waker.lock().take() {
+ waker.wake();
+ return;
+ } else {
+ thread::sleep(Duration::from_millis(10));
+ }
+ }
+ }
+
+ async fn writer(mu: Rc<Mutex<isize>>) {
+ let mut guard = mu.lock().await;
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ let waker = Arc::new(SpinLock::new(None));
+ let waker2 = Arc::clone(&waker);
+ thread::spawn(move || wake_future(waker2));
+ let fut = TestFuture {
+ polled: false,
+ waker,
+ };
+ fut.await;
+ *guard = tmp + 1;
+ }
+ }
+
+ async fn reader(mu: Rc<Mutex<isize>>) {
+ let guard = mu.read_lock().await;
+ assert!(*guard >= 0);
+ }
+
+ const TASKS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Rc::new(Mutex::new(0isize));
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ spawner
+ .spawn_local(writer(Rc::clone(&mu)))
+ .expect("Failed to spawn writer");
+
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(reader(Rc::clone(&mu)))
+ .expect("Failed to spawn reader");
+ }
+
+ ex.run();
+
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn rw_multi_thread_async() {
+ async fn writer(mu: Arc<Mutex<isize>>, tx: Sender<()>) {
+ let mut guard = mu.lock().await;
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ thread::yield_now();
+ *guard = tmp + 1;
+ }
+
+ mem::drop(guard);
+ tx.send(()).unwrap();
+ }
+
+ async fn reader(mu: Arc<Mutex<isize>>, tx: Sender<()>) {
+ let guard = mu.read_lock().await;
+ assert!(*guard >= 0);
+
+ mem::drop(guard);
+ tx.send(()).expect("Failed to send completion message");
+ }
+
+ const TASKS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Arc::new(Mutex::new(0isize));
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let (txw, rxw) = channel();
+ ex.spawn_ok(writer(Arc::clone(&mu), txw));
+
+ let (txr, rxr) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(reader(Arc::clone(&mu), txr.clone()));
+ }
+
+ // Wait for the readers to finish their checks.
+ for _ in 0..TASKS {
+ rxr.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion message from reader");
+ }
+
+ // Wait for the writer to finish.
+ rxw.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion message from writer");
+
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn wake_all_readers() {
+ async fn read(mu: Arc<Mutex<()>>) {
+ let g = mu.read_lock().await;
+ pending!();
+ mem::drop(g);
+ }
+
+ async fn write(mu: Arc<Mutex<()>>) {
+ mu.lock().await;
+ }
+
+ let mu = Arc::new(Mutex::new(()));
+ let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [
+ Box::pin(read(mu.clone())),
+ Box::pin(read(mu.clone())),
+ Box::pin(read(mu.clone())),
+ Box::pin(write(mu.clone())),
+ Box::pin(read(mu.clone())),
+ ];
+ const NUM_READERS: usize = 4;
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // Acquire the lock so that the futures cannot get it.
+ let g = block_on(mu.lock());
+
+ for r in &mut futures {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop the lock. This should allow all readers to make progress. Since they already waited
+ // once they should ignore the WRITER_WAITING bit that is currently set.
+ mem::drop(g);
+ for r in &mut futures {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ // Check that all readers were able to acquire the lock.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ READ_LOCK * NUM_READERS
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ let mut needs_poll = None;
+
+ // All the readers can now finish but the writer needs to be polled again.
+ for (i, r) in futures.iter_mut().enumerate() {
+ match r.as_mut().poll(&mut cx) {
+ Poll::Ready(()) => {}
+ Poll::Pending => {
+ if needs_poll.is_some() {
+ panic!("More than one future unable to complete");
+ }
+ needs_poll = Some(i);
+ }
+ }
+ }
+
+ if let Poll::Pending = futures[needs_poll.expect("Writer unexpectedly able to complete")]
+ .as_mut()
+ .poll(&mut cx)
+ {
+ panic!("Writer unable to complete");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn long_wait() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // This time the tight loop will fail to acquire the lock.
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ // Which will finally allow the mark_ready function to make progress.
+ if let Poll::Pending = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready not able to make progress");
+ }
+
+ // Now the tight loop will finish.
+ if let Poll::Pending = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop not able to finish");
+ }
+
+ assert!(*block_on(mu.lock()));
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_long_wait_before_wake() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // Now drop the mark_ready future, which should clear the LONG_WAIT bit.
+ mem::drop(mark);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), LOCKED);
+
+ mem::drop(tl);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_long_wait_after_wake() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // This time the tight loop will fail to acquire the lock.
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ // Now drop the mark_ready future, which should clear the LONG_WAIT bit.
+ mem::drop(mark);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & LONG_WAIT, 0);
+
+ // Since the lock is not held, we should be able to spawn a future to set the ready flag.
+ block_on(mark_ready(mu.clone()));
+
+ // Now the tight loop will finish.
+ if let Poll::Pending = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop not able to finish");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn designated_waker() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut futures = [
+ Box::pin(inc(mu.clone())),
+ Box::pin(inc(mu.clone())),
+ Box::pin(inc(mu.clone())),
+ ];
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let count = block_on(mu.lock());
+
+ // Poll 2 futures. Since neither will be able to acquire the lock, they should get added to
+ // the waiter list.
+ if let Poll::Ready(()) = futures[0].as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = futures[1].as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING,
+ );
+
+ // Now drop the lock. This should set the DESIGNATED_WAKER bit and wake up the first future
+ // in the wait list.
+ mem::drop(count);
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ DESIGNATED_WAKER | HAS_WAITERS | WRITER_WAITING,
+ );
+
+ // Now poll the third future. It should be able to acquire the lock immediately.
+ if let Poll::Pending = futures[2].as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 1);
+
+ // There should still be a waiter in the wait list and the DESIGNATED_WAKER bit should still
+ // be set.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER,
+ DESIGNATED_WAKER
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ // Now let the future that was woken up run.
+ if let Poll::Pending = futures[0].as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 2);
+
+ if let Poll::Pending = futures[1].as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 3);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_designated_waker() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let count = block_on(mu.lock());
+
+ if let Poll::Ready(()) = fut.as_mut().poll(&mut cx) {
+ panic!("Future unexpectedly ready when lock is held");
+ }
+
+ // Drop the lock. This will wake up the future.
+ mem::drop(count);
+
+ // Now drop the future without polling. This should clear all the state in the mutex.
+ mem::drop(fut);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_before_wake() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut1 = Box::pin(inc(mu.clone()));
+
+ let mut fut2 = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First acquire the lock.
+ let count = block_on(mu.lock());
+
+ // Now poll the futures. Since the lock is acquired they will both get queued in the waiter
+ // list.
+ match fut1.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop fut1. This should remove it from the waiter list but shouldn't wake fut2.
+ mem::drop(fut1);
+
+ // There should be no designated waker.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, 0);
+
+ // Since the waiter was a writer, we should clear the WRITER_WAITING bit.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ // Now drop the lock. This should mark fut2 as ready to make progress.
+ mem::drop(count);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => panic!("Future is not ready to make progress"),
+ Poll::Ready(()) => {}
+ }
+
+ // Verify that we only incremented the count once.
+ assert_eq!(*block_on(mu.lock()), 1);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_after_wake() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut1 = Box::pin(inc(mu.clone()));
+
+ let mut fut2 = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First acquire the lock.
+ let count = block_on(mu.lock());
+
+ // Now poll the futures. Since the lock is acquired they will both get queued in the waiter
+ // list.
+ match fut1.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop the lock. This should mark fut1 as ready to make progress.
+ mem::drop(count);
+
+ // Now drop fut1. This should make fut2 ready to make progress.
+ mem::drop(fut1);
+
+ // Since there was still another waiter in the list we shouldn't have cleared the
+ // DESIGNATED_WAKER bit.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER,
+ DESIGNATED_WAKER
+ );
+
+ // Since the waiter was a writer, we should clear the WRITER_WAITING bit.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => panic!("Future is not ready to make progress"),
+ Poll::Ready(()) => {}
+ }
+
+ // Verify that we only incremented the count once.
+ assert_eq!(*block_on(mu.lock()), 1);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn timeout() {
+ async fn timed_lock(timer: oneshot::Receiver<()>, mu: Arc<Mutex<()>>) {
+ select! {
+ res = timer.fuse() => {
+ match res {
+ Ok(()) => {},
+ Err(e) => panic!("Timer unexpectedly canceled: {}", e),
+ }
+ }
+            _ = mu.lock().fuse() => panic!("Successfully acquired lock"),
+ }
+ }
+
+ let mu = Arc::new(Mutex::new(()));
+ let (tx, rx) = oneshot::channel();
+
+ let mut timeout = Box::pin(timed_lock(rx, mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // Acquire the lock.
+ let g = block_on(mu.lock());
+
+ // Poll the future.
+ if let Poll::Ready(()) = timeout.as_mut().poll(&mut cx) {
+ panic!("timed_lock unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ // Signal the channel, which should cancel the lock.
+ tx.send(()).expect("Failed to send wakeup");
+
+ // Now the future should have completed without acquiring the lock.
+ if let Poll::Pending = timeout.as_mut().poll(&mut cx) {
+ panic!("timed_lock not ready after timeout");
+ }
+
+ // The mutex state should not show any waiters.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+
+ mem::drop(g);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn writer_waiting() {
+ async fn read_zero(mu: Arc<Mutex<usize>>) {
+ let val = mu.read_lock().await;
+ pending!();
+
+ assert_eq!(*val, 0);
+ }
+
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ async fn read_one(mu: Arc<Mutex<usize>>) {
+ let val = mu.read_lock().await;
+
+ assert_eq!(*val, 1);
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut r1 = Box::pin(read_zero(mu.clone()));
+ let mut r2 = Box::pin(read_zero(mu.clone()));
+
+ let mut w = Box::pin(inc(mu.clone()));
+ let mut r3 = Box::pin(read_one(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ if let Poll::Ready(()) = r1.as_mut().poll(&mut cx) {
+ panic!("read_zero unexpectedly ready");
+ }
+ if let Poll::Ready(()) = r2.as_mut().poll(&mut cx) {
+ panic!("read_zero unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ 2 * READ_LOCK
+ );
+
+ if let Poll::Ready(()) = w.as_mut().poll(&mut cx) {
+ panic!("inc unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // The WRITER_WAITING bit should prevent the next reader from acquiring the lock.
+ if let Poll::Ready(()) = r3.as_mut().poll(&mut cx) {
+ panic!("read_one unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ 2 * READ_LOCK
+ );
+
+ if let Poll::Pending = r1.as_mut().poll(&mut cx) {
+ panic!("read_zero unable to complete");
+ }
+ if let Poll::Pending = r2.as_mut().poll(&mut cx) {
+ panic!("read_zero unable to complete");
+ }
+ if let Poll::Pending = w.as_mut().poll(&mut cx) {
+ panic!("inc unable to complete");
+ }
+ if let Poll::Pending = r3.as_mut().poll(&mut cx) {
+ panic!("read_one unable to complete");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn transfer_notify_one() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+ }
+
+ async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut readers = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+ let mut writer = Box::pin(write(mu.clone(), cv.clone()));
+
+ for r in &mut readers {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("reader unexpectedly ready");
+ }
+ }
+ if let Poll::Ready(()) = writer.as_mut().poll(&mut cx) {
+ panic!("writer unexpectedly ready");
+ }
+
+ let mut count = block_on(mu.lock());
+ *count = 1;
+
+ // This should transfer all readers + one writer to the waiter queue.
+ cv.notify_one();
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ mem::drop(count);
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING),
+ HAS_WAITERS | WRITER_WAITING
+ );
+
+ for r in &mut readers {
+ if let Poll::Pending = r.as_mut().poll(&mut cx) {
+ panic!("reader unable to complete");
+ }
+ }
+
+ if let Poll::Pending = writer.as_mut().poll(&mut cx) {
+ panic!("writer unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.read_lock()), 0);
+ }
+
+ #[test]
+ fn transfer_waiters_when_unlocked() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures = [
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ ];
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ *block_on(mu.lock()) = futures.len();
+ cv.notify_all();
+
+ // Since the lock is not held, instead of transferring the waiters to the waiter list we
+ // should just wake them all up.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+
+ for f in &mut futures {
+ if let Poll::Pending = f.as_mut().poll(&mut cx) {
+                panic!("future unable to complete");
+ }
+ }
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn transfer_reader_writer() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+
+ // Yield once while holding the read lock, which should prevent the writer from waking
+ // up.
+ pending!();
+ }
+
+ async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ async fn lock(mu: Arc<Mutex<usize>>) {
+ mem::drop(mu.lock().await);
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(write(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+ const NUM_READERS: usize = 4;
+
+ let mut l = Box::pin(lock(mu.clone()));
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+
+ let mut count = block_on(mu.lock());
+ *count = 1;
+
+ // Now poll the lock function. Since the lock is held by us, it will get queued on the
+ // waiter list.
+ if let Poll::Ready(()) = l.as_mut().poll(&mut cx) {
+ panic!("lock() unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING),
+ HAS_WAITERS | WRITER_WAITING
+ );
+
+ // Wake up waiters while holding the lock. This should end with them transferred to the
+ // mutex's waiter list.
+ cv.notify_all();
+
+ // Drop the lock. This should wake up the lock function.
+ mem::drop(count);
+
+ if let Poll::Pending = l.as_mut().poll(&mut cx) {
+ panic!("lock() unable to complete");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING),
+ HAS_WAITERS | WRITER_WAITING
+ );
+
+ // Poll everything again. The readers should be able to make progress (but not complete) but
+ // the writer should be blocked.
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ READ_LOCK * NUM_READERS
+ );
+
+ // All the readers can now finish but the writer needs to be polled again.
+ let mut needs_poll = None;
+ for (i, r) in futures.iter_mut().enumerate() {
+ match r.as_mut().poll(&mut cx) {
+ Poll::Ready(()) => {}
+ Poll::Pending => {
+ if needs_poll.is_some() {
+ panic!("More than one future unable to complete");
+ }
+ needs_poll = Some(i);
+ }
+ }
+ }
+
+ if let Poll::Pending = futures[needs_poll.expect("Writer unexpectedly able to complete")]
+ .as_mut()
+ .poll(&mut cx)
+ {
+ panic!("Writer unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 0);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn transfer_readers_with_read_lock() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+
+ // Yield once while holding the read lock.
+ pending!();
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ // Increment the count and then grab a read lock.
+ *block_on(mu.lock()) = 1;
+
+ let g = block_on(mu.read_lock());
+
+ // Notify the condvar while holding the read lock. This should wake up all the waiters
+ // rather than just transferring them.
+ cv.notify_all();
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+
+ mem::drop(g);
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ READ_LOCK * futures.len()
+ );
+
+ for f in &mut futures {
+ if let Poll::Pending = f.as_mut().poll(&mut cx) {
+ panic!("future unable to complete");
+ }
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+}
diff --git a/src/sync/spin.rs b/src/sync/spin.rs
new file mode 100644
index 0000000..0687bb7
--- /dev/null
+++ b/src/sync/spin.rs
@@ -0,0 +1,271 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{spin_loop_hint, AtomicBool, Ordering};
+
+const UNLOCKED: bool = false;
+const LOCKED: bool = true;
+
+/// A primitive that provides safe, mutable access to a shared resource.
+///
+/// Unlike `Mutex`, a `SpinLock` will not voluntarily yield its CPU time until the resource is
+/// available and will instead keep spinning until the resource is acquired. For the vast majority
+/// of cases, `Mutex` is a better choice than `SpinLock`. If a `SpinLock` must be used, then users
+/// should try to do as little work as possible while holding the `SpinLock` and avoid any sort of
+/// blocking at all costs as it can severely penalize performance.
+///
+/// # Poisoning
+///
+/// This `SpinLock` does not implement lock poisoning so it is possible for threads to access
+/// poisoned data if a thread panics while holding the lock. If lock poisoning is needed, it can be
+/// implemented by wrapping the `SpinLock` in a new type that implements poisoning. See the
+/// implementation of `std::sync::Mutex` for an example of how to do this.
+pub struct SpinLock<T: ?Sized> {
+ lock: AtomicBool,
+ value: UnsafeCell<T>,
+}
+
+impl<T> SpinLock<T> {
+ /// Creates a new, unlocked `SpinLock` that's ready for use.
+ pub fn new(value: T) -> SpinLock<T> {
+ SpinLock {
+ lock: AtomicBool::new(UNLOCKED),
+ value: UnsafeCell::new(value),
+ }
+ }
+
+ /// Consumes the `SpinLock` and returns the value guarded by it. This method doesn't perform any
+ /// locking as the compiler guarantees that there are no references to `self`.
+ pub fn into_inner(self) -> T {
+ // No need to take the lock because the compiler can statically guarantee
+ // that there are no references to the SpinLock.
+ self.value.into_inner()
+ }
+}
+
+impl<T: ?Sized> SpinLock<T> {
+ /// Acquires exclusive, mutable access to the resource protected by the `SpinLock`, blocking the
+ /// current thread until it is able to do so. Upon returning, the current thread will be the
+ /// only thread with access to the resource. The `SpinLock` will be released when the returned
+ /// `SpinLockGuard` is dropped. Attempting to call `lock` while already holding the `SpinLock`
+ /// will cause a deadlock.
+ pub fn lock(&self) -> SpinLockGuard<T> {
+ while self
+ .lock
+ .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
+ .is_err()
+ {
+ spin_loop_hint();
+ }
+
+ SpinLockGuard {
+ lock: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ fn unlock(&self) {
+ // Don't need to compare and swap because we exclusively hold the lock.
+ self.lock.store(UNLOCKED, Ordering::Release);
+ }
+
+ /// Returns a mutable reference to the contained value. This method doesn't perform any locking
+ /// as the compiler will statically guarantee that there are no other references to `self`.
+ pub fn get_mut(&mut self) -> &mut T {
+ // Safe because the compiler can statically guarantee that there are no other references to
+ // `self`. This is also why we don't need to acquire the lock.
+ unsafe { &mut *self.value.get() }
+ }
+}
+
+unsafe impl<T: ?Sized + Send> Send for SpinLock<T> {}
+unsafe impl<T: ?Sized + Send> Sync for SpinLock<T> {}
+
+impl<T: ?Sized + Default> Default for SpinLock<T> {
+ fn default() -> Self {
+ Self::new(Default::default())
+ }
+}
+
+impl<T> From<T> for SpinLock<T> {
+ fn from(source: T) -> Self {
+ Self::new(source)
+ }
+}
+
+/// An RAII implementation of a "scoped lock" for a `SpinLock`. When this structure is dropped, the
+/// lock will be released. The resource protected by the `SpinLock` can be accessed via the `Deref`
+/// and `DerefMut` implementations of this structure.
+pub struct SpinLockGuard<'a, T: 'a + ?Sized> {
+ lock: &'a SpinLock<T>,
+ value: &'a mut T,
+}
+
+impl<'a, T: ?Sized> Deref for SpinLockGuard<'a, T> {
+ type Target = T;
+ fn deref(&self) -> &T {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> DerefMut for SpinLockGuard<'a, T> {
+ fn deref_mut(&mut self) -> &mut T {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for SpinLockGuard<'a, T> {
+ fn drop(&mut self) {
+ self.lock.unlock();
+ }
+}
+
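+// A minimal usage sketch, not part of the imported source: guard a shared counter with a
+// `SpinLock` across several threads. The names `counter` and `handles` are illustrative
+// assumptions.
+//
+//     use std::sync::Arc;
+//     use std::thread;
+//
+//     let counter = Arc::new(SpinLock::new(0u32));
+//     let handles: Vec<_> = (0..4)
+//         .map(|_| {
+//             let c = Arc::clone(&counter);
+//             thread::spawn(move || *c.lock() += 1)
+//         })
+//         .collect();
+//     for h in handles {
+//         h.join().unwrap();
+//     }
+//     assert_eq!(*counter.lock(), 4);
+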
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::mem;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::Arc;
+ use std::thread;
+
+ #[derive(PartialEq, Eq, Debug)]
+ struct NonCopy(u32);
+
+ #[test]
+ fn it_works() {
+ let sl = SpinLock::new(NonCopy(13));
+
+ assert_eq!(*sl.lock(), NonCopy(13));
+ }
+
+ #[test]
+ fn smoke() {
+ let sl = SpinLock::new(NonCopy(7));
+
+ mem::drop(sl.lock());
+ mem::drop(sl.lock());
+ }
+
+ #[test]
+ fn send() {
+ let sl = SpinLock::new(NonCopy(19));
+
+ thread::spawn(move || {
+ let value = sl.lock();
+ assert_eq!(*value, NonCopy(19));
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn high_contention() {
+ const THREADS: usize = 23;
+ const ITERATIONS: usize = 101;
+
+ let mut threads = Vec::with_capacity(THREADS);
+
+ let sl = Arc::new(SpinLock::new(0usize));
+ for _ in 0..THREADS {
+ let sl2 = sl.clone();
+ threads.push(thread::spawn(move || {
+ for _ in 0..ITERATIONS {
+ *sl2.lock() += 1;
+ }
+ }));
+ }
+
+ for t in threads.into_iter() {
+ t.join().unwrap();
+ }
+
+ assert_eq!(*sl.lock(), THREADS * ITERATIONS);
+ }
+
+ #[test]
+ fn get_mut() {
+ let mut sl = SpinLock::new(NonCopy(13));
+ *sl.get_mut() = NonCopy(17);
+
+ assert_eq!(sl.into_inner(), NonCopy(17));
+ }
+
+ #[test]
+ fn into_inner() {
+ let sl = SpinLock::new(NonCopy(29));
+ assert_eq!(sl.into_inner(), NonCopy(29));
+ }
+
+ #[test]
+ fn into_inner_drop() {
+ struct NeedsDrop(Arc<AtomicUsize>);
+ impl Drop for NeedsDrop {
+ fn drop(&mut self) {
+ self.0.fetch_add(1, Ordering::AcqRel);
+ }
+ }
+
+ let value = Arc::new(AtomicUsize::new(0));
+ let needs_drop = SpinLock::new(NeedsDrop(value.clone()));
+ assert_eq!(value.load(Ordering::Acquire), 0);
+
+ {
+ let inner = needs_drop.into_inner();
+ assert_eq!(inner.0.load(Ordering::Acquire), 0);
+ }
+
+ assert_eq!(value.load(Ordering::Acquire), 1);
+ }
+
+ #[test]
+ fn arc_nested() {
+        // Tests nested spin locks and access to the underlying data.
+ let sl = SpinLock::new(1);
+ let arc = Arc::new(SpinLock::new(sl));
+ thread::spawn(move || {
+ let nested = arc.lock();
+ let lock2 = nested.lock();
+ assert_eq!(*lock2, 1);
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_access_in_unwind() {
+ let arc = Arc::new(SpinLock::new(1));
+ let arc2 = arc.clone();
+ thread::spawn(move || {
+ struct Unwinder {
+ i: Arc<SpinLock<i32>>,
+ }
+ impl Drop for Unwinder {
+ fn drop(&mut self) {
+ *self.i.lock() += 1;
+ }
+ }
+ let _u = Unwinder { i: arc2 };
+ panic!();
+ })
+ .join()
+ .expect_err("thread did not panic");
+ let lock = arc.lock();
+ assert_eq!(*lock, 2);
+ }
+
+ #[test]
+ fn unsized_value() {
+        let sl: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]);
+        {
+            let b = &mut *sl.lock();
+            b[0] = 4;
+            b[2] = 5;
+        }
+        let expected: &[i32] = &[4, 2, 5];
+        assert_eq!(&*sl.lock(), expected);
+ }
+}
diff --git a/src/sync/waiter.rs b/src/sync/waiter.rs
new file mode 100644
index 0000000..48335c7
--- /dev/null
+++ b/src/sync/waiter.rs
@@ -0,0 +1,317 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::future::Future;
+use std::mem;
+use std::pin::Pin;
+use std::ptr::NonNull;
+use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll, Waker};
+
+use intrusive_collections::linked_list::{LinkedList, LinkedListOps};
+use intrusive_collections::{intrusive_adapter, DefaultLinkOps, LinkOps};
+
+use crate::sync::SpinLock;
+
+// An atomic version of a LinkedListLink. See https://github.com/Amanieu/intrusive-rs/issues/47 for
+// more details.
+pub struct AtomicLink {
+ prev: UnsafeCell<Option<NonNull<AtomicLink>>>,
+ next: UnsafeCell<Option<NonNull<AtomicLink>>>,
+ linked: AtomicBool,
+}
+
+impl AtomicLink {
+ fn new() -> AtomicLink {
+ AtomicLink {
+ linked: AtomicBool::new(false),
+ prev: UnsafeCell::new(None),
+ next: UnsafeCell::new(None),
+ }
+ }
+
+ fn is_linked(&self) -> bool {
+ self.linked.load(Ordering::Relaxed)
+ }
+}
+
+impl DefaultLinkOps for AtomicLink {
+ type Ops = AtomicLinkOps;
+
+ const NEW: Self::Ops = AtomicLinkOps;
+}
+
+// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods
+// are all unsafe and require that the caller has first called `acquire_link` (and had it return
+// true) to use them safely.
+unsafe impl Send for AtomicLink {}
+unsafe impl Sync for AtomicLink {}
+
+#[derive(Copy, Clone, Default)]
+pub struct AtomicLinkOps;
+
+unsafe impl LinkOps for AtomicLinkOps {
+ type LinkPtr = NonNull<AtomicLink>;
+
+ unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool {
+ !ptr.as_ref().linked.swap(true, Ordering::Acquire)
+ }
+
+ unsafe fn release_link(&mut self, ptr: Self::LinkPtr) {
+ ptr.as_ref().linked.store(false, Ordering::Release)
+ }
+}
+
+unsafe impl LinkedListOps for AtomicLinkOps {
+ unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
+ *ptr.as_ref().next.get()
+ }
+
+ unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
+ *ptr.as_ref().prev.get()
+ }
+
+ unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) {
+ *ptr.as_ref().next.get() = next;
+ }
+
+ unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) {
+ *ptr.as_ref().prev.get() = prev;
+ }
+}
+
+#[derive(Clone, Copy)]
+pub enum Kind {
+ Shared,
+ Exclusive,
+}
+
+enum State {
+ Init,
+ Waiting(Waker),
+ Woken,
+ Finished,
+ Processing,
+}
+
+// Indicates the queue to which the waiter belongs. It is the responsibility of the Mutex and
+// Condvar implementations to update this value when adding/removing a Waiter from their respective
+// waiter lists.
+#[repr(u8)]
+#[derive(Debug, Eq, PartialEq)]
+pub enum WaitingFor {
+ // The waiter is either not linked into a waiter list or it is linked into a temporary list.
+ None = 0,
+ // The waiter is linked into the Mutex's waiter list.
+ Mutex = 1,
+ // The waiter is linked into the Condvar's waiter list.
+ Condvar = 2,
+}
+
+// Internal struct used to keep track of the cancellation function.
+struct Cancel {
+ c: fn(usize, &Waiter, bool) -> bool,
+ data: usize,
+}
+
+// Represents a thread currently blocked on a Condvar or on acquiring a Mutex.
+pub struct Waiter {
+ link: AtomicLink,
+ state: SpinLock<State>,
+ cancel: SpinLock<Cancel>,
+ kind: Kind,
+ waiting_for: AtomicU8,
+}
+
+impl Waiter {
+ // Create a new, initialized Waiter.
+ //
+    // `kind` should indicate whether this waiter represents a thread that is waiting for a shared
+ // lock or an exclusive lock.
+ //
+ // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()`
+ // function) is dropped before it can complete. `cancel_data` is used as the first parameter of
+ // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third
+ // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it
+ // was polled to completion). The `cancel` function should return true if it was able to
+ // successfully process the cancellation. One reason why a `cancel` function may return false is
+ // if the `Waiter` was transferred to a different waiter list after the cancel function was
+ // called but before it was able to run. In this case, it is expected that the new waiter list
+ // updated the cancel function (by calling `set_cancel`) and the cancellation will be retried by
+ // fetching and calling the new cancellation function.
+ //
+ // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the
+ // documentation of the `WaitingFor` enum for the meaning of the different values.
+ pub fn new(
+ kind: Kind,
+ cancel: fn(usize, &Waiter, bool) -> bool,
+ cancel_data: usize,
+ waiting_for: WaitingFor,
+ ) -> Waiter {
+ Waiter {
+ link: AtomicLink::new(),
+ state: SpinLock::new(State::Init),
+ cancel: SpinLock::new(Cancel {
+ c: cancel,
+ data: cancel_data,
+ }),
+ kind,
+ waiting_for: AtomicU8::new(waiting_for as u8),
+ }
+ }
+
+ // The kind of lock that this `Waiter` is waiting to acquire.
+ pub fn kind(&self) -> Kind {
+ self.kind
+ }
+
+ // Returns true if this `Waiter` is currently linked into a waiter list.
+ pub fn is_linked(&self) -> bool {
+ self.link.is_linked()
+ }
+
+ // Indicates the waiter list to which this `Waiter` belongs.
+ pub fn is_waiting_for(&self) -> WaitingFor {
+ match self.waiting_for.load(Ordering::Acquire) {
+ 0 => WaitingFor::None,
+ 1 => WaitingFor::Mutex,
+ 2 => WaitingFor::Condvar,
+ v => panic!("Unknown value for `WaitingFor`: {}", v),
+ }
+ }
+
+    // Change the waiter list to which this `Waiter` belongs. The caller must ensure that the
+    // `Waiter` is not currently linked into a waiter list.
+ pub fn set_waiting_for(&self, waiting_for: WaitingFor) {
+ self.waiting_for.store(waiting_for as u8, Ordering::Release);
+ }
+
+ // Change the cancellation function that this `Waiter` should use. This will panic if called
+ // when the `Waiter` is still linked into a waiter list.
+ pub fn set_cancel(&self, c: fn(usize, &Waiter, bool) -> bool, data: usize) {
+ debug_assert!(
+ !self.is_linked(),
+ "Cannot change cancellation function while linked"
+ );
+ let mut cancel = self.cancel.lock();
+ cancel.c = c;
+ cancel.data = data;
+ }
+
+ // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a
+ // waiter list.
+ pub fn reset(&self, waiting_for: WaitingFor) {
+ debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked");
+ self.set_waiting_for(waiting_for);
+
+ let mut state = self.state.lock();
+ if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) {
+ mem::drop(state);
+ mem::drop(waker);
+ }
+ }
+
+ // Wait until woken up by another thread.
+ pub fn wait(&self) -> WaitFuture<'_> {
+ WaitFuture { waiter: self }
+ }
+
+    // Wake up the thread associated with this `Waiter`. Panics if `is_waiting_for()` does not return
+ // `WaitingFor::None` or if `is_linked()` returns true.
+ pub fn wake(&self) {
+ debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked");
+ debug_assert_eq!(self.is_waiting_for(), WaitingFor::None);
+
+ let mut state = self.state.lock();
+
+ if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) {
+ mem::drop(state);
+ waker.wake();
+ }
+ }
+}
+
+pub struct WaitFuture<'w> {
+ waiter: &'w Waiter,
+}
+
+impl<'w> Future for WaitFuture<'w> {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ let mut state = self.waiter.state.lock();
+
+ match mem::replace(&mut *state, State::Processing) {
+ State::Init => {
+ *state = State::Waiting(cx.waker().clone());
+
+ Poll::Pending
+ }
+ State::Waiting(old_waker) => {
+ *state = State::Waiting(cx.waker().clone());
+ mem::drop(state);
+ mem::drop(old_waker);
+
+ Poll::Pending
+ }
+ State::Woken => {
+ *state = State::Finished;
+ Poll::Ready(())
+ }
+ State::Finished => {
+ panic!("Future polled after returning Poll::Ready");
+ }
+ State::Processing => {
+ panic!("Unexpected waker state");
+ }
+ }
+ }
+}
+
+impl<'w> Drop for WaitFuture<'w> {
+ fn drop(&mut self) {
+ let state = self.waiter.state.lock();
+
+ match *state {
+ State::Finished => {}
+ State::Processing => panic!("Unexpected waker state"),
+ State::Woken => {
+ mem::drop(state);
+
+ // We were woken but not polled. Wake up the next waiter.
+ let mut success = false;
+ while !success {
+ let cancel = self.waiter.cancel.lock();
+ let c = cancel.c;
+ let data = cancel.data;
+
+ mem::drop(cancel);
+
+ success = c(data, self.waiter, true);
+ }
+ }
+ _ => {
+ mem::drop(state);
+
+ // Not woken. No need to wake up any waiters.
+ let mut success = false;
+ while !success {
+ let cancel = self.waiter.cancel.lock();
+ let c = cancel.c;
+ let data = cancel.data;
+
+ mem::drop(cancel);
+
+ success = c(data, self.waiter, false);
+ }
+ }
+ }
+ }
+}
+
+intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink });
+
+pub type WaiterList = LinkedList<WaiterAdapter>;
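+
+// A minimal sketch, not part of the imported source, of how a lock implementation might drive a
+// `Waiter` directly. The `noop_cancel` function, the standalone `WaiterList`, and the use of
+// `futures::executor::block_on` are assumptions made purely for illustration.
+//
+//     use std::sync::Arc;
+//     use futures::executor::block_on;
+//
+//     // Cancellation hook called if the `WaitFuture` is dropped early; there is nothing to
+//     // clean up in this sketch, so report success.
+//     fn noop_cancel(_data: usize, _waiter: &Waiter, _woken: bool) -> bool {
+//         true
+//     }
+//
+//     let waiter = Arc::new(Waiter::new(Kind::Exclusive, noop_cancel, 0, WaitingFor::None));
+//
+//     // A Mutex or Condvar would normally park the waiter in its waiter list...
+//     let mut waiters: WaiterList = WaiterList::new(WaiterAdapter::new());
+//     waiters.push_back(waiter);
+//
+//     // ...and later unlink and wake it once the resource becomes available.
+//     let w = waiters.pop_front().expect("waiter list is empty");
+//     w.wake();
+//     block_on(w.wait()); // completes immediately because the waiter was already woken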
diff --git a/src/syslog.rs b/src/syslog.rs
new file mode 100644
index 0000000..0a12ae2
--- /dev/null
+++ b/src/syslog.rs
@@ -0,0 +1,73 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::ffi::CStr;
+
+use log::{self, Level, LevelFilter, Metadata, Record, SetLoggerError};
+
+static LOGGER: SyslogLogger = SyslogLogger;
+
+struct SyslogLogger;
+
+impl log::Log for SyslogLogger {
+ fn enabled(&self, metadata: &Metadata) -> bool {
+ if cfg!(debug_assertions) {
+ metadata.level() <= Level::Debug
+ } else {
+ metadata.level() <= Level::Info
+ }
+ }
+
+ fn log(&self, record: &Record) {
+ if !self.enabled(&record.metadata()) {
+ return;
+ }
+
+ let level = match record.level() {
+ Level::Error => libc::LOG_ERR,
+ Level::Warn => libc::LOG_WARNING,
+ Level::Info => libc::LOG_INFO,
+ Level::Debug => libc::LOG_DEBUG,
+ Level::Trace => libc::LOG_DEBUG,
+ };
+
+ let msg = format!("{}\0", record.args());
+ let cmsg = if let Ok(m) = CStr::from_bytes_with_nul(msg.as_bytes()) {
+ m
+ } else {
+ // For now we just drop messages with interior nuls.
+ return;
+ };
+
+        // Passing the message as a `%s` argument rather than as the format string keeps
+        // user-controlled text from being interpreted as format directives. Safe because this
+        // doesn't modify any memory. There's not much use in checking the return value because
+        // this _is_ the logging function so there's no way for us to tell anyone about the error.
+        unsafe {
+            libc::syslog(level, "%s\0".as_ptr() as *const libc::c_char, cmsg.as_ptr());
+        }
+ }
+ fn flush(&self) {}
+}
+
+/// Initializes the logger to send log messages to syslog.
+pub fn init(ident: &'static CStr) -> Result<(), SetLoggerError> {
+ // Safe because this only modifies libc's internal state and is safe to call
+ // multiple times.
+ unsafe {
+ libc::openlog(
+ ident.as_ptr(),
+ libc::LOG_NDELAY | libc::LOG_PID,
+ libc::LOG_USER,
+ )
+ };
+ log::set_logger(&LOGGER)?;
+ let level = if cfg!(debug_assertions) {
+ LevelFilter::Debug
+ } else {
+ LevelFilter::Info
+ };
+ log::set_max_level(level);
+
+ Ok(())
+}
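+
+// A minimal usage sketch, not part of the imported source; the identifier bytes and the log
+// message are illustrative assumptions.
+//
+//     use std::ffi::CStr;
+//
+//     let ident = CStr::from_bytes_with_nul(b"my-daemon\0").expect("identifier has interior nul");
+//     init(ident).expect("failed to install the syslog logger");
+//     log::info!("syslog logger initialized");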
diff --git a/src/vsock.rs b/src/vsock.rs
new file mode 100644
index 0000000..634591e
--- /dev/null
+++ b/src/vsock.rs
@@ -0,0 +1,491 @@
+// Copyright 2018 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//! Support for virtual sockets.
+use std::fmt;
+use std::io;
+use std::mem::{self, size_of};
+use std::num::ParseIntError;
+use std::os::raw::{c_uchar, c_uint, c_ushort};
+use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
+use std::result;
+use std::str::FromStr;
+
+use libc::{
+ self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
+ VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
+};
+
+// The domain for vsock sockets.
+const AF_VSOCK: sa_family_t = 40;
+
+// Vsock loopback address.
+const VMADDR_CID_LOCAL: c_uint = 1;
+
+/// Vsock equivalent of binding on port 0. Binds to a random port.
+pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value();
+
+// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
+// from linux/vm_sockets.h.
+const PADDING: usize = size_of::<sockaddr>()
+ - size_of::<sa_family_t>()
+ - size_of::<c_ushort>()
+ - (2 * size_of::<c_uint>());
+
+#[repr(C)]
+#[derive(Default)]
+struct sockaddr_vm {
+ svm_family: sa_family_t,
+ svm_reserved1: c_ushort,
+ svm_port: c_uint,
+ svm_cid: c_uint,
+ svm_zero: [c_uchar; PADDING],
+}
+
+#[derive(Debug)]
+pub struct AddrParseError;
+
+impl fmt::Display for AddrParseError {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "failed to parse vsock address")
+ }
+}
+
+/// The vsock equivalent of an IP address.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub enum VsockCid {
+ /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
+ Any,
+ /// An address that refers to the bare-metal machine that serves as the hypervisor.
+ Hypervisor,
+ /// The loopback address.
+ Local,
+ /// The parent machine. It may not be the hypervisor for nested VMs.
+ Host,
+ /// An assigned CID that serves as the address for VSOCK.
+ Cid(c_uint),
+}
+
+impl fmt::Display for VsockCid {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ match &self {
+ VsockCid::Any => write!(fmt, "Any"),
+ VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
+ VsockCid::Local => write!(fmt, "Local"),
+ VsockCid::Host => write!(fmt, "Host"),
+ VsockCid::Cid(c) => write!(fmt, "'{}'", c),
+ }
+ }
+}
+
+impl From<c_uint> for VsockCid {
+ fn from(c: c_uint) -> Self {
+ match c {
+ VMADDR_CID_ANY => VsockCid::Any,
+ VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
+ VMADDR_CID_LOCAL => VsockCid::Local,
+ VMADDR_CID_HOST => VsockCid::Host,
+ _ => VsockCid::Cid(c),
+ }
+ }
+}
+
+impl FromStr for VsockCid {
+ type Err = ParseIntError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let c: c_uint = s.parse()?;
+ Ok(c.into())
+ }
+}
+
+impl From<VsockCid> for c_uint {
+    fn from(cid: VsockCid) -> c_uint {
+        match cid {
+            VsockCid::Any => VMADDR_CID_ANY,
+            VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
+            VsockCid::Local => VMADDR_CID_LOCAL,
+            VsockCid::Host => VMADDR_CID_HOST,
+            VsockCid::Cid(c) => c,
+        }
+    }
+}
+
+/// An address associated with a virtual socket.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub struct SocketAddr {
+ pub cid: VsockCid,
+ pub port: c_uint,
+}
+
+pub trait ToSocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
+}
+
+impl ToSocketAddr for SocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ Ok(*self)
+ }
+}
+
+impl ToSocketAddr for str {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ self.parse()
+ }
+}
+
+impl ToSocketAddr for (VsockCid, c_uint) {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ let (cid, port) = *self;
+ Ok(SocketAddr { cid, port })
+ }
+}
+
+impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ (**self).to_socket_addr()
+ }
+}
+
+impl FromStr for SocketAddr {
+ type Err = AddrParseError;
+
+ /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
+ /// "vsock:cid:port".
+ fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
+ let components: Vec<&str> = s.split(':').collect();
+ if components.len() != 3 || components[0] != "vsock" {
+ return Err(AddrParseError);
+ }
+
+ Ok(SocketAddr {
+ cid: components[1].parse().map_err(|_| AddrParseError)?,
+ port: components[2].parse().map_err(|_| AddrParseError)?,
+ })
+ }
+}
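+
+// A minimal parsing sketch, not part of the imported source; the address string and the expected
+// values are illustrative assumptions.
+//
+//     let addr: SocketAddr = "vsock:2:5555".parse().expect("failed to parse vsock address");
+//     assert_eq!(addr.cid, VsockCid::Host); // CID 2 is VMADDR_CID_HOST
+//     assert_eq!(addr.port, 5555);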
+
+impl fmt::Display for SocketAddr {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "{}:{}", self.cid, self.port)
+ }
+}
+
+/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
+/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
+unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
+ let flags = libc::fcntl(fd, F_GETFL, 0);
+ if flags < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ let flags = if nonblocking {
+ flags | O_NONBLOCK
+ } else {
+ flags & !O_NONBLOCK
+ };
+
+ let ret = libc::fcntl(fd, F_SETFL, flags);
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(())
+}
+
+/// A virtual socket.
+///
+/// Do not use this type unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead use either `VsockStream` or
+/// `VsockListener`.
+#[derive(Debug)]
+pub struct VsockSocket {
+ fd: RawFd,
+}
+
+impl VsockSocket {
+ pub fn new() -> io::Result<Self> {
+ let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(VsockSocket { fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ // The compiler should optimize this out since these are both compile-time constants.
+ assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
+
+ let mut svm: sockaddr_vm = Default::default();
+ svm.svm_family = AF_VSOCK;
+ svm.svm_cid = sockaddr.cid.into();
+ svm.svm_port = sockaddr.port;
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::bind(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ let mut svm: sockaddr_vm = Default::default();
+ svm.svm_family = AF_VSOCK;
+ svm.svm_cid = sockaddr.cid.into();
+ svm.svm_port = sockaddr.port;
+
+ // Safe because this just connects a vsock socket, and the return value is checked.
+ let ret = unsafe {
+ libc::connect(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ Ok(VsockStream { sock: self })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<VsockListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ return Err(listen_err);
+ }
+ Ok(VsockListener { sock: self })
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u32> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
+
+ Ok(svm.svm_port)
+ }
+ }
+
+ pub fn try_clone(&self) -> io::Result<Self> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
+ if dup_fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(Self { fd: dup_fd })
+ }
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ // Safe because the fd is valid and owned by this stream.
+ unsafe { set_nonblocking(self.fd, nonblocking) }
+ }
+}
+
+impl IntoRawFd for VsockSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for VsockSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for VsockSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
+
+/// A virtual stream socket.
+#[derive(Debug)]
+pub struct VsockStream {
+ sock: VsockSocket,
+}
+
+impl VsockStream {
+ pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
+ let sock = VsockSocket::new()?;
+ sock.connect(addr)
+ }
+
+ /// Returns the port that this stream is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ pub fn try_clone(&self) -> io::Result<VsockStream> {
+ self.sock.try_clone().map(|f| VsockStream { sock: f })
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl io::Read for VsockStream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ // Safe because this will only modify the contents of |buf| and we check the return value.
+ let ret = unsafe {
+ libc::read(
+ self.sock.as_raw_fd(),
+ buf as *mut [u8] as *mut c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+}
+
+impl io::Write for VsockStream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::write(
+ self.sock.as_raw_fd(),
+ buf as *const [u8] as *const c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ // No buffered data so nothing to do.
+ Ok(())
+ }
+}
+
+impl AsRawFd for VsockStream {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+impl IntoRawFd for VsockStream {
+ fn into_raw_fd(self) -> RawFd {
+ self.sock.into_raw_fd()
+ }
+}
+
+/// Represents a virtual socket server.
+#[derive(Debug)]
+pub struct VsockListener {
+ sock: VsockSocket,
+}
+
+impl VsockListener {
+ /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
+ /// endpoint.
+ pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
+ let mut sock = VsockSocket::new()?;
+ sock.bind(addr)?;
+ sock.listen()
+ }
+
+ /// Returns the port that this listener is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
+ /// new connection is established. When established, returns the corresponding `VsockStream`
+ /// and the remote peer's address.
+ pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because this will only modify |svm| and we check the return value.
+ let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
+ let fd = unsafe {
+ libc::accept4(
+ self.sock.as_raw_fd(),
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut socklen as *mut socklen_t,
+ libc::SOCK_CLOEXEC,
+ )
+ };
+ if fd < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ if svm.svm_family != AF_VSOCK {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("unexpected address family: {}", svm.svm_family),
+ ));
+ }
+
+ Ok((
+ VsockStream {
+ sock: VsockSocket { fd },
+ },
+ SocketAddr {
+ cid: svm.svm_cid.into(),
+ port: svm.svm_port,
+ },
+ ))
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl AsRawFd for VsockListener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
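+
+// A minimal end-to-end sketch, not part of the imported source. It only succeeds on a kernel with
+// vsock support, and the port handling and message bytes are illustrative assumptions.
+//
+//     use std::io::Write;
+//
+//     // Server side: bind to a random port on the local endpoint and wait for a peer.
+//     let listener = VsockListener::bind((VsockCid::Any, VMADDR_PORT_ANY))
+//         .expect("failed to bind vsock listener");
+//     let port = listener.local_port().expect("failed to query bound port");
+//     let (_stream, peer) = listener.accept().expect("failed to accept connection");
+//     println!("accepted connection from {}", peer);
+//
+//     // Client side (for example, in a guest VM): connect to the host on the advertised port.
+//     let mut stream = VsockStream::connect((VsockCid::Host, port))
+//         .expect("failed to connect to vsock listener");
+//     stream.write_all(b"hello").expect("failed to write");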