diff options
author | Victor Hsieh <victorhsieh@google.com> | 2021-01-08 19:33:49 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-01-08 19:33:49 +0000 |
commit | 22cc7969553c6f6218b27889b2b9abbc895896f5 (patch) | |
tree | adf0ed37ac7d7272ca7ba5c1ed11a742e0f3bed6 | |
parent | 4061c0fb78928db1e424afbaf808d36800fbbfe7 (diff) | |
parent | 9d31620918b44d2f3f75b9618c4fa812951a4b61 (diff) | |
download | libchromeos-rs-22cc7969553c6f6218b27889b2b9abbc895896f5.tar.gz |
Import libchromeos-rs am: 5cc8a92ff3 am: 674821bae2 am: 9d31620918
Original change: https://android-review.googlesource.com/c/platform/external/libchromeos-rs/+/1530501
MUST ONLY BE SUBMITTED BY AUTOMERGER
Change-Id: I11eff92a8115426a26a822165fdfd5071a06e7ec
-rw-r--r-- | Android.bp | 92 | ||||
-rw-r--r-- | Cargo.lock | 284 | ||||
-rw-r--r-- | Cargo.toml | 19 | ||||
-rw-r--r-- | LICENSE | 27 | ||||
-rw-r--r-- | METADATA | 14 | ||||
-rw-r--r-- | MODULE_LICENSE_BSD | 0 | ||||
-rw-r--r-- | OWNERS | 5 | ||||
-rw-r--r-- | README.md | 8 | ||||
-rw-r--r-- | TEST_MAPPING | 13 | ||||
-rw-r--r-- | patches/Android.bp.patch | 31 | ||||
-rw-r--r-- | src/lib.rs | 25 | ||||
-rw-r--r-- | src/linux.rs | 18 | ||||
-rw-r--r-- | src/net.rs | 269 | ||||
-rw-r--r-- | src/read_dir.rs | 150 | ||||
-rw-r--r-- | src/scoped_path.rs | 115 | ||||
-rw-r--r-- | src/sync.rs | 14 | ||||
-rw-r--r-- | src/sync/blocking.rs | 192 | ||||
-rw-r--r-- | src/sync/cv.rs | 1251 | ||||
-rw-r--r-- | src/sync/mu.rs | 2400 | ||||
-rw-r--r-- | src/sync/spin.rs | 271 | ||||
-rw-r--r-- | src/sync/waiter.rs | 317 | ||||
-rw-r--r-- | src/syslog.rs | 73 | ||||
-rw-r--r-- | src/vsock.rs | 491 |
23 files changed, 6079 insertions, 0 deletions
diff --git a/Android.bp b/Android.bp new file mode 100644 index 0000000..63b9fdd --- /dev/null +++ b/Android.bp @@ -0,0 +1,92 @@ +// This file is generated by cargo2android.py --run --device --tests --dependencies --patch=patches/Android.bp.patch. + +rust_defaults { + name: "libchromeos-rs_defaults", + crate_name: "libchromeos", + srcs: ["src/lib.rs"], + test_suites: ["general-tests"], + auto_gen_config: true, + edition: "2018", + rustlibs: [ + "libdata_model", + "libfutures", + "libfutures_executor", + "libfutures_util", + "libintrusive_collections", + "liblibc", + "liblog_rust", + "libprotobuf", + ], +} + +rust_test_host { + name: "libchromeos-rs_host_test_src_lib", + defaults: ["libchromeos-rs_defaults"], +} + +rust_test { + name: "libchromeos-rs_device_test_src_lib", + defaults: ["libchromeos-rs_defaults"], + // Manually limit to 64-bit to avoid depending on non-existing 32-bit build + // of libdata_model currently. + compile_multilib: "64", +} + +rust_library { + name: "liblibchromeos", + host_supported: true, + crate_name: "libchromeos", + srcs: ["src/lib.rs"], + edition: "2018", + rustlibs: [ + "libdata_model", + "libfutures", + "libintrusive_collections", + "liblibc", + "liblog_rust", + "libprotobuf", + ], + apex_available: [ + "//apex_available:platform", + "com.android.virt", + ], + // This library depends on libdata_model that is is part of crosvm project. + // Projects within crosvm on Android have only 64-bit target build enabled. + // As a result, we need to manually limit this build to 64-bit only, too. + // This is fine because this library is only used by crosvm now (thus 64-bit + // only). + compile_multilib: "64", +} + +// dependent_library ["feature_list"] +// ../crosvm/assertions/src/lib.rs +// ../crosvm/data_model/src/lib.rs +// autocfg-1.0.1 +// cfg-if-0.1.10 +// futures-0.3.8 "alloc,async-await,default,executor,futures-executor,std" +// futures-channel-0.3.8 "alloc,futures-sink,sink,std" +// futures-core-0.3.8 "alloc,std" +// futures-executor-0.3.8 "default,num_cpus,std,thread-pool" +// futures-io-0.3.8 "std" +// futures-macro-0.3.8 +// futures-sink-0.3.8 "alloc,std" +// futures-task-0.3.8 "alloc,once_cell,std" +// futures-util-0.3.8 "alloc,async-await,async-await-macro,channel,default,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std" +// intrusive-collections-0.9.0 "alloc,default" +// libc-0.2.82 "default,std" +// log-0.4.11 +// memchr-2.3.4 "default,std" +// memoffset-0.5.6 "default" +// num_cpus-1.13.0 +// once_cell-1.5.2 "alloc,std" +// pin-project-1.0.3 +// pin-project-internal-1.0.3 +// pin-utils-0.1.0 +// proc-macro-hack-0.5.19 +// proc-macro-nested-0.1.6 +// proc-macro2-1.0.24 "default,proc-macro" +// protobuf-2.20.0 +// quote-1.0.8 "default,proc-macro" +// slab-0.4.2 +// syn-1.0.58 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" +// unicode-xid-0.2.1 "default" diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..4a3acbe --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,284 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +[[package]] +name = "assertions" +version = "0.1.0" + +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + +[[package]] +name = "cfg-if" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" + +[[package]] +name = "data_model" +version = "0.1.0" +dependencies = [ + "assertions", + "libc", +] + +[[package]] +name = "futures" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b3b0c040a1fe6529d30b3c5944b280c7f0dcb2930d2c3062bca967b602583d0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b7109687aa4e177ef6fe84553af6280ef2778bdb7783ba44c9dc3399110fe64" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "847ce131b72ffb13b6109a221da9ad97a64cbe48feb1028356b836b47b8f1748" + +[[package]] +name = "futures-executor" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4caa2b2b68b880003057c1dd49f1ed937e38f22fcf6c212188a121f08cf40a65" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", + "num_cpus", +] + +[[package]] +name = "futures-io" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "611834ce18aaa1bd13c4b374f5d653e1027cf99b6b502584ff8c9a64413b30bb" + +[[package]] +name = "futures-macro" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77408a692f1f97bcc61dc001d752e00643408fbc922e4d634c655df50d595556" +dependencies = [ + "proc-macro-hack", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f878195a49cee50e006b02b93cf7e0a95a38ac7b776b4c4d9cc1207cd20fcb3d" + +[[package]] +name = "futures-task" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c554eb5bf48b2426c4771ab68c6b14468b6e76cc90996f528c3338d761a4d0d" +dependencies = [ + "once_cell", +] + +[[package]] +name = "futures-util" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d304cff4a7b99cfb7986f7d43fbe93d175e72e704a8860787cc95e9ffd85cbd2" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project", + "pin-utils", + "proc-macro-hack", + "proc-macro-nested", + "slab", +] + +[[package]] +name = "hermit-abi" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aca5565f760fb5b220e499d72710ed156fdb74e631659e99377d9ebfbd13ae8" +dependencies = [ + "libc", +] + +[[package]] +name = "intrusive-collections" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca8c0bb831cd60d4dda79a58e3705ca6eb47efb65d665651a8d672213ec3db" +dependencies = [ + "memoffset", +] + +[[package]] +name = "libc" +version = "0.2.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58d1b70b004888f764dfbf6a26a3b0342a1632d33968e4a179d8011c760614" + +[[package]] +name = "libchromeos" +version = "0.1.0" +dependencies = [ + "data_model", + "futures", + "futures-executor", + "futures-util", + "intrusive-collections", + "libc", + "log", + "protobuf", +] + +[[package]] +name = "log" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fabed175da42fed1fa0746b0ea71f412aa9d35e76e95e59b192c64b9dc2bf8b" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "memchr" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ee1c47aaa256ecabcaea351eae4a9b01ef39ed810004e298d2511ed284b1525" + +[[package]] +name = "memoffset" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13bd41f508810a131401606d54ac32a467c97172d74ba7662562ebba5ad07fa0" + +[[package]] +name = "pin-project" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ccc2237c2c489783abd8c4c80e5450fc0e98644555b1364da68cc29aa151ca7" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8e8d2bf0b23038a4424865103a4df472855692821aab4e4f5c3312d461d9e5f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro-hack" +version = "0.5.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbf0c48bc1d91375ae5c3cd81e3722dff1abcf81a30960240640d223f59fe0e5" + +[[package]] +name = "proc-macro-nested" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eba180dafb9038b050a4c280019bbedf9f2467b61e5d892dcad585bb57aadc5a" + +[[package]] +name = "proc-macro2" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0704ee1a7e00d7bb417d0770ea303c1bccbabf0ef1667dae92b5967f5f8a71" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "protobuf" +version = "2.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da78e04bc0e40f36df43ecc6575e4f4b180e8156c4efd73f13d5619479b05696" + +[[package]] +name = "quote" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa563d17ecb180e500da1cfd2b028310ac758de548efdd203e18f283af693f37" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "slab" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8" + +[[package]] +name = "syn" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "443b4178719c5a851e1bde36ce12da21d74a0e60b4d982ec3385a933c812f0f6" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "unicode-xid" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7fe0bb3479651439c9112f72b6c505038574c9fbb575ed1bf3b797fa39dd564" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..d0de4ce --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "libchromeos" +version = "0.1.0" +authors = ["The Chromium OS Authors"] +edition = "2018" + +[dependencies] +data_model = { path = "../crosvm/data_model" } # provided by ebuild +libc = "0.2" +log = "0.4" +protobuf = "2.1" +intrusive-collections = "0.9" +futures = { version = "0.3", default-features = false, features = ["alloc"] } + +[dev-dependencies] +futures = { version = "0.3", features = ["async-await"] } +futures-executor = { version = "0.3", features = ["thread-pool"] } +futures-util = "0.3" + @@ -0,0 +1,27 @@ +// Copyright 2014 The Chromium OS Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/METADATA b/METADATA new file mode 100644 index 0000000..c4cea7d --- /dev/null +++ b/METADATA @@ -0,0 +1,14 @@ +name: "libchromeos-rs" +description: + "libchromeos-rs contains Rust code that can be reused across any Chrome OS " + "project. The crate is copied from a bigger platform2 repo." + +third_party { + url { + type: GIT + value: "https://chromium.googlesource.com/chromiumos/platform2" + } + version: "a39ed76e487e191a009cba4676bdd9a33006cbc0" + last_upgrade_date { year: 2020 month: 12 day: 8 } + license_type: NOTICE +} diff --git a/MODULE_LICENSE_BSD b/MODULE_LICENSE_BSD new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/MODULE_LICENSE_BSD @@ -0,0 +1,5 @@ +# ANDROID: Remove chromium owners to not confuse Gerrit with unrecognized +# owners. +#chirantan@chromium.org +#smbarber@chromium.org +#zachr@chromium.org diff --git a/README.md b/README.md new file mode 100644 index 0000000..02d55d2 --- /dev/null +++ b/README.md @@ -0,0 +1,8 @@ +# libchromeos-rs - The Rust crate for common Chrome OS code + +`libchromeos-rs` contains Rust code that can be reused across any Chrome OS +project. It's the Rust equivalent of [libbrillo](../libbrillo/). + +Current modules include: +* `syslog` - an adaptor for using the generic `log` crate with syslog +* `vsock` - wrappers for dealing with AF_VSOCK sockets diff --git a/TEST_MAPPING b/TEST_MAPPING new file mode 100644 index 0000000..9882312 --- /dev/null +++ b/TEST_MAPPING @@ -0,0 +1,13 @@ +// Generated by cargo2android.py for tests in Android.bp +{ + "presubmit": [ + { + "host": true, + "name": "libchromeos-rs_host_test_src_lib" +// Presubmit tries to run x86, but we only support 64-bit builds. +// }, +// { +// "name": "libchromeos-rs_device_test_src_lib" + } + ] +} diff --git a/patches/Android.bp.patch b/patches/Android.bp.patch new file mode 100644 index 0000000..32234ed --- /dev/null +++ b/patches/Android.bp.patch @@ -0,0 +1,31 @@ +diff --git a/Android.bp b/Android.bp +index fc2c8b8..c4e3c5b 100644 +--- a/Android.bp ++++ b/Android.bp +@@ -27,6 +27,9 @@ rust_test_host { + rust_test { + name: "libchromeos-rs_device_test_src_lib", + defaults: ["libchromeos-rs_defaults"], ++ // Manually limit to 64-bit to avoid depending on non-existing 32-bit build ++ // of libdata_model currently. ++ compile_multilib: "64", + } + + rust_library { +@@ -43,6 +46,16 @@ rust_library { + "liblog_rust", + "libprotobuf", + ], ++ apex_available: [ ++ "//apex_available:platform", ++ "com.android.virt", ++ ], ++ // This library depends on libdata_model that is is part of crosvm project. ++ // Projects within crosvm on Android have only 64-bit target build enabled. ++ // As a result, we need to manually limit this build to 64-bit only, too. ++ // This is fine because this library is only used by crosvm now (thus 64-bit ++ // only). ++ compile_multilib: "64", + } + + // dependent_library ["feature_list"] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9d34580 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,25 @@ +// Copyright 2019 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +pub mod linux; +pub mod net; +mod read_dir; +pub mod scoped_path; +pub mod sync; +pub mod syslog; +pub mod vsock; + +pub use read_dir::*; + +#[macro_export] +macro_rules! syscall { + ($e:expr) => {{ + let res = $e; + if res < 0 { + Err(::std::io::Error::last_os_error()) + } else { + Ok(res) + } + }}; +} diff --git a/src/linux.rs b/src/linux.rs new file mode 100644 index 0000000..4826730 --- /dev/null +++ b/src/linux.rs @@ -0,0 +1,18 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Provides safe implementations of common low level functions that assume a Linux environment. +use libc::{syscall, SYS_gettid}; + +pub type Pid = libc::pid_t; + +pub fn getpid() -> Pid { + // Calling getpid() is always safe. + unsafe { libc::getpid() } +} + +pub fn gettid() -> Pid { + // Calling the gettid() sycall is always safe. + unsafe { syscall(SYS_gettid) as Pid } +} diff --git a/src/net.rs b/src/net.rs new file mode 100644 index 0000000..be1cd25 --- /dev/null +++ b/src/net.rs @@ -0,0 +1,269 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Structs to supplement std::net. +use std::io; +use std::mem::{self, size_of}; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; + +use libc::{ + c_int, in6_addr, in_addr, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, socklen_t, AF_INET, + AF_INET6, SOCK_CLOEXEC, SOCK_STREAM, +}; + +/// Assist in handling both IP version 4 and IP version 6. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum InetVersion { + V4, + V6, +} + +impl InetVersion { + pub fn from_sockaddr(s: &SocketAddr) -> Self { + match s { + SocketAddr::V4(_) => InetVersion::V4, + SocketAddr::V6(_) => InetVersion::V6, + } + } +} + +impl Into<sa_family_t> for InetVersion { + fn into(self) -> sa_family_t { + match self { + InetVersion::V4 => AF_INET as sa_family_t, + InetVersion::V6 => AF_INET6 as sa_family_t, + } + } +} + +fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in { + sockaddr_in { + sin_family: AF_INET as sa_family_t, + sin_port: s.port().to_be(), + sin_addr: in_addr { + s_addr: u32::from_ne_bytes(s.ip().octets()), + }, + sin_zero: [0; 8], + } +} + +fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 { + sockaddr_in6 { + sin6_family: AF_INET6 as sa_family_t, + sin6_port: s.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: in6_addr { + s6_addr: s.ip().octets(), + }, + sin6_scope_id: 0, + } +} + +/// A TCP socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either TcpStream or +/// TcpListener. +#[derive(Debug)] +pub struct TcpSocket { + inet_version: InetVersion, + fd: RawFd, +} + +impl TcpSocket { + pub fn new(inet_version: InetVersion) -> io::Result<Self> { + let fd = unsafe { + libc::socket( + Into::<sa_family_t>::into(inet_version) as c_int, + SOCK_STREAM | SOCK_CLOEXEC, + 0, + ) + }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(TcpSocket { inet_version, fd }) + } + } + + pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::<sockaddr_in6>() as socklen_t, + ) + } + } + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::connect( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::connect( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + }; + + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpStream. + Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) }) + } + } + + pub fn listen(self) -> io::Result<TcpListener> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + Err(listen_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpListener. + Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) }) + } + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result<u16> { + match self.inet_version { + InetVersion::V4 => { + let mut sin = sockaddr_in { + sin_family: 0, + sin_port: 0, + sin_addr: in_addr { s_addr: 0 }, + sin_zero: [0; 8], + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_in>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin as *mut sockaddr_in as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::<sockaddr_in>()); + + Ok(sin.sin_port) + } + } + InetVersion::V6 => { + let mut sin6 = sockaddr_in6 { + sin6_family: 0, + sin6_port: 0, + sin6_flowinfo: 0, + sin6_addr: in6_addr { s6_addr: [0; 16] }, + sin6_scope_id: 0, + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_in6>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin6 as *mut sockaddr_in6 as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::<sockaddr_in>()); + + Ok(sin6.sin6_port) + } + } + } + } +} + +impl IntoRawFd for TcpSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for TcpSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} diff --git a/src/read_dir.rs b/src/read_dir.rs new file mode 100644 index 0000000..eb8a549 --- /dev/null +++ b/src/read_dir.rs @@ -0,0 +1,150 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::ffi::CStr; +use std::io; +use std::mem::size_of; +use std::os::unix::io::AsRawFd; + +use data_model::DataInit; + +use crate::syscall; + +#[repr(C, packed)] +#[derive(Clone, Copy)] +struct LinuxDirent64 { + d_ino: libc::ino64_t, + d_off: libc::off64_t, + d_reclen: libc::c_ushort, + d_ty: libc::c_uchar, +} +unsafe impl DataInit for LinuxDirent64 {} + +pub struct DirEntry<'r> { + pub ino: libc::ino64_t, + pub offset: u64, + pub type_: u8, + pub name: &'r CStr, +} + +pub struct ReadDir<'d, D> { + buf: [u8; 256], + dir: &'d mut D, + current: usize, + end: usize, +} + +impl<'d, D: AsRawFd> ReadDir<'d, D> { + /// Return the next directory entry. This is implemented as a separate method rather than via + /// the `Iterator` trait because rust doesn't currently support generic associated types. + #[allow(clippy::should_implement_trait)] + pub fn next(&mut self) -> Option<io::Result<DirEntry>> { + if self.current >= self.end { + let res = syscall!(unsafe { + libc::syscall( + libc::SYS_getdents64, + self.dir.as_raw_fd(), + self.buf.as_mut_ptr() as *mut LinuxDirent64, + self.buf.len() as libc::c_int, + ) + }); + match res { + Ok(end) => { + self.current = 0; + self.end = end as usize; + } + Err(e) => return Some(Err(e)), + } + } + + let rem = &self.buf[self.current..self.end]; + if rem.is_empty() { + return None; + } + + // We only use debug asserts here because these values are coming from the kernel and we + // trust them implicitly. + debug_assert!( + rem.len() >= size_of::<LinuxDirent64>(), + "not enough space left in `rem`" + ); + + let (front, back) = rem.split_at(size_of::<LinuxDirent64>()); + + let dirent64 = + LinuxDirent64::from_slice(front).expect("unable to get LinuxDirent64 from slice"); + + let namelen = dirent64.d_reclen as usize - size_of::<LinuxDirent64>(); + debug_assert!(namelen <= back.len(), "back is smaller than `namelen`"); + + // The kernel will pad the name with additional nul bytes until it is 8-byte aligned so + // we need to strip those off here. + let name = strip_padding(&back[..namelen]); + let entry = DirEntry { + ino: dirent64.d_ino, + offset: dirent64.d_off as u64, + type_: dirent64.d_ty, + name, + }; + + debug_assert!( + rem.len() >= dirent64.d_reclen as usize, + "rem is smaller than `d_reclen`" + ); + self.current += dirent64.d_reclen as usize; + Some(Ok(entry)) + } +} + +pub fn read_dir<D: AsRawFd>(dir: &mut D, offset: libc::off64_t) -> io::Result<ReadDir<D>> { + // Safe because this doesn't modify any memory and we check the return value. + syscall!(unsafe { libc::lseek64(dir.as_raw_fd(), offset, libc::SEEK_SET) })?; + + Ok(ReadDir { + buf: [0u8; 256], + dir, + current: 0, + end: 0, + }) +} + +// Like `CStr::from_bytes_with_nul` but strips any bytes after the first '\0'-byte. Panics if `b` +// doesn't contain any '\0' bytes. +fn strip_padding(b: &[u8]) -> &CStr { + // It would be nice if we could use memchr here but that's locked behind an unstable gate. + let pos = b + .iter() + .position(|&c| c == 0) + .expect("`b` doesn't contain any nul bytes"); + + // Safe because we are creating this string with the first nul-byte we found so we can + // guarantee that it is nul-terminated and doesn't contain any interior nuls. + unsafe { CStr::from_bytes_with_nul_unchecked(&b[..pos + 1]) } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn padded_cstrings() { + assert_eq!(strip_padding(b".\0\0\0\0\0\0\0").to_bytes(), b"."); + assert_eq!(strip_padding(b"..\0\0\0\0\0\0").to_bytes(), b".."); + assert_eq!( + strip_padding(b"normal cstring\0").to_bytes(), + b"normal cstring" + ); + assert_eq!(strip_padding(b"\0\0\0\0").to_bytes(), b""); + assert_eq!( + strip_padding(b"interior\0nul bytes\0\0\0").to_bytes(), + b"interior" + ); + } + + #[test] + #[should_panic(expected = "`b` doesn't contain any nul bytes")] + fn no_nul_byte() { + strip_padding(b"no nul bytes in string"); + } +} diff --git a/src/scoped_path.rs b/src/scoped_path.rs new file mode 100644 index 0000000..22d0ad9 --- /dev/null +++ b/src/scoped_path.rs @@ -0,0 +1,115 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::env::{current_exe, temp_dir}; +use std::fs::{create_dir_all, remove_dir_all}; +use std::ops::Deref; +use std::path::{Path, PathBuf}; +use std::thread::panicking; + +use super::linux::{getpid, gettid}; + +/// Returns a stable path based on the label, pid, and tid. If the label isn't provided the +/// current_exe is used instead. +pub fn get_temp_path(label: Option<&str>) -> PathBuf { + if let Some(label) = label { + temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid())) + } else { + get_temp_path(Some(current_exe().unwrap().to_str().unwrap())) + } +} + +/// Automatically deletes the path it contains when it goes out of scope unless it is a test and +/// drop is called after a panic!. +/// +/// This is particularly useful for creating temporary directories for use with tests. +pub struct ScopedPath<P: AsRef<Path>>(P); + +impl<P: AsRef<Path>> ScopedPath<P> { + pub fn create(p: P) -> Result<Self, std::io::Error> { + create_dir_all(p.as_ref())?; + Ok(ScopedPath(p)) + } +} + +impl<P: AsRef<Path>> AsRef<Path> for ScopedPath<P> { + fn as_ref(&self) -> &Path { + self.0.as_ref() + } +} + +impl<P: AsRef<Path>> Deref for ScopedPath<P> { + type Target = Path; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl<P: AsRef<Path>> Drop for ScopedPath<P> { + fn drop(&mut self) { + // Leave the files on a failed test run for debugging. + if panicking() && cfg!(test) { + eprintln!("NOTE: Not removing {}", self.display()); + return; + } + if let Err(e) = remove_dir_all(&**self) { + eprintln!("Failed to remove {}: {}", self.display(), e); + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + use std::panic::catch_unwind; + + #[test] + fn gettemppath() { + assert_ne!("", get_temp_path(None).to_string_lossy()); + assert_eq!( + get_temp_path(None), + get_temp_path(Some(current_exe().unwrap().to_str().unwrap())) + ); + assert_ne!( + get_temp_path(Some("label")), + get_temp_path(Some(current_exe().unwrap().to_str().unwrap())) + ); + } + + #[test] + fn scopedpath_exists() { + let tmp_path = get_temp_path(None); + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_notexists() { + let tmp_path = get_temp_path(None); + { + let _scoped_path = ScopedPath(&tmp_path); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_panic() { + let tmp_path = get_temp_path(None); + assert!(catch_unwind(|| { + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + panic!() + } + }) + .is_err()); + assert!(tmp_path.exists()); + remove_dir_all(&tmp_path).unwrap(); + } +} diff --git a/src/sync.rs b/src/sync.rs new file mode 100644 index 0000000..80dec38 --- /dev/null +++ b/src/sync.rs @@ -0,0 +1,14 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod blocking; +mod cv; +mod mu; +mod spin; +mod waiter; + +pub use blocking::block_on; +pub use cv::Condvar; +pub use mu::Mutex; +pub use spin::SpinLock; diff --git a/src/sync/blocking.rs b/src/sync/blocking.rs new file mode 100644 index 0000000..e95694d --- /dev/null +++ b/src/sync/blocking.rs @@ -0,0 +1,192 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::future::Future; +use std::ptr; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::pin_mut; +use futures::task::{waker_ref, ArcWake}; + +// Randomly generated values to indicate the state of the current thread. +const WAITING: i32 = 0x25de_74d1; +const WOKEN: i32 = 0x72d3_2c9f; + +const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG; +const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG; + +thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING)))); + +#[repr(transparent)] +struct Waker(AtomicI32); + +extern { + #[cfg_attr(target_os = "android", link_name = "__errno")] + #[cfg_attr(target_os = "linux", link_name = "__errno_location")] + fn errno_location() -> *mut libc::c_int; +} + +impl ArcWake for Waker { + fn wake_by_ref(arc_self: &Arc<Self>) { + let state = arc_self.0.swap(WOKEN, Ordering::Release); + if state == WAITING { + // The thread hasn't already been woken up so wake it up now. Safe because this doesn't + // modify any memory and we check the return value. + let res = unsafe { + libc::syscall( + libc::SYS_futex, + &arc_self.0, + FUTEX_WAKE_PRIVATE, + libc::INT_MAX, // val + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0 as libc::c_int, // val3 + ) + }; + if res < 0 { + panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe { + *errno_location() + }); + } + } + } +} + +/// Run a future to completion on the current thread. +/// +/// This method will block the current thread until `f` completes. Useful when you need to call an +/// async fn from a non-async context. +pub fn block_on<F: Future>(f: F) -> F::Output { + pin_mut!(f); + + PER_THREAD_WAKER.with(|thread_waker| { + let waker = waker_ref(thread_waker); + let mut cx = Context::from_waker(&waker); + + loop { + if let Poll::Ready(t) = f.as_mut().poll(&mut cx) { + return t; + } + + let state = thread_waker.0.swap(WAITING, Ordering::Acquire); + if state == WAITING { + // If we weren't already woken up then wait until we are. Safe because this doesn't + // modify any memory and we check the return value. + let res = unsafe { + libc::syscall( + libc::SYS_futex, + &thread_waker.0, + FUTEX_WAIT_PRIVATE, + state, + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0 as libc::c_int, // val3 + ) + }; + + if res < 0 { + // Safe because libc guarantees that this is a valid pointer. + match unsafe { *errno_location() } { + libc::EAGAIN | libc::EINTR => {} + e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e), + } + } + + // Clear the state to prevent unnecessary extra loop iterations and also to allow + // nested usage of `block_on`. + thread_waker.0.store(WAITING, Ordering::Release); + } + } + }) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::pin::Pin; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll, Waker}; + use std::thread; + use std::time::Duration; + + use crate::sync::SpinLock; + + struct TimerState { + fired: bool, + waker: Option<Waker>, + } + struct Timer { + state: Arc<SpinLock<TimerState>>, + } + + impl Future for Timer { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + let mut state = self.state.lock(); + if state.fired { + return Poll::Ready(()); + } + + state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer { + let state = Arc::new(SpinLock::new(TimerState { + fired: false, + waker: None, + })); + + let thread_state = Arc::clone(&state); + thread::spawn(move || { + thread::sleep(dur); + let mut ts = thread_state.lock(); + ts.fired = true; + if let Some(waker) = ts.waker.take() { + waker.wake(); + } + drop(ts); + + if let Some(tx) = notify { + tx.send(()).expect("Failed to send completion notification"); + } + }); + + Timer { state } + } + + #[test] + fn it_works() { + block_on(start_timer(Duration::from_millis(100), None)); + } + + #[test] + fn nested() { + async fn inner() { + block_on(start_timer(Duration::from_millis(100), None)); + } + + block_on(inner()); + } + + #[test] + fn ready_before_poll() { + let (tx, rx) = channel(); + + let timer = start_timer(Duration::from_millis(50), Some(tx)); + + rx.recv() + .expect("Failed to receive completion notification"); + + // We know the timer has already fired so the poll should complete immediately. + block_on(timer); + } +} diff --git a/src/sync/cv.rs b/src/sync/cv.rs new file mode 100644 index 0000000..714c6d6 --- /dev/null +++ b/src/sync/cv.rs @@ -0,0 +1,1251 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::mem; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; +use std::sync::Arc; + +use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex}; +use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; + +const SPINLOCK: usize = 1 << 0; +const HAS_WAITERS: usize = 1 << 1; + +/// A primitive to wait for an event to occur without consuming CPU time. +/// +/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some +/// condition to become true. The condition must always be verified while holding the `Mutex` lock. +/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on +/// the `Condvar`. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use libchromeos::sync::{block_on, Condvar, Mutex}; +/// +/// const N: usize = 13; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let all threads waiting on the Condvar know once the increments are done. +/// let data = Arc::new(Mutex::new(0)); +/// let cv = Arc::new(Condvar::new()); +/// +/// for _ in 0..N { +/// let (data, cv) = (data.clone(), cv.clone()); +/// thread::spawn(move || { +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// cv.notify_all(); +/// } +/// }); +/// } +/// +/// let mut val = block_on(data.lock()); +/// while *val != N { +/// val = block_on(cv.wait(val)); +/// } +/// ``` +pub struct Condvar { + state: AtomicUsize, + waiters: UnsafeCell<WaiterList>, + mu: UnsafeCell<usize>, +} + +impl Condvar { + /// Creates a new condition variable ready to be waited on and notified. + pub fn new() -> Condvar { + Condvar { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + mu: UnsafeCell::new(0), + } + } + + /// Block the current thread until this `Condvar` is notified by another thread. + /// + /// This method will atomically unlock the `Mutex` held by `guard` and then block the current + /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up + /// the thread. + /// + /// To allow for more efficient scheduling, this call may return even when the programmer + /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a + /// loop that checks the predicate before continuing. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Condvar` is notified. + /// + /// # Panics + /// + /// This method will panic if used with more than one `Mutex` at the same time. + /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # use std::thread; + /// + /// # use libchromeos::sync::{block_on, Condvar, Mutex}; + /// + /// # let mu = Arc::new(Mutex::new(false)); + /// # let cv = Arc::new(Condvar::new()); + /// # let (mu2, cv2) = (mu.clone(), cv.clone()); + /// + /// # let t = thread::spawn(move || { + /// # *block_on(mu2.lock()) = true; + /// # cv2.notify_all(); + /// # }); + /// + /// let mut ready = block_on(mu.lock()); + /// while !*ready { + /// ready = block_on(cv.wait(ready)); + /// } + /// + /// # t.join().expect("failed to join thread"); + /// ``` + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.lock_from_cv().await + } + + /// Like `wait()` but takes and returns a `MutexReadGuard` instead. + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.read_lock_from_cv().await + } + + fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) { + // Acquire the spin lock. + let mut oldstate = self.state.load(Ordering::Relaxed); + while (oldstate & SPINLOCK) != 0 + || self.state.compare_and_swap( + oldstate, + oldstate | SPINLOCK | HAS_WAITERS, + Ordering::Acquire, + ) != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let mu = unsafe { &mut *self.mu.get() }; + let muptr = raw_mutex as *const RawMutex as usize; + + match *mu { + 0 => *mu = muptr, + p if p == muptr => {} + _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"), + } + + // Safe because the spin lock guarantees exclusive access. + unsafe { (*self.waiters.get()).push_back(waiter) }; + + // Release the spin lock. Use a direct store here because no other thread can modify + // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier + // because we just added a waiter. + self.state.store(HAS_WAITERS, Ordering::Release); + } + + /// Notify at most one thread currently waiting on the `Condvar`. + /// + /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to + /// `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_one(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. + return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) + != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let (mut wake_list, all_readers) = get_wake_list(waiters); + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + let newstate = if waiters.is_empty() { + // Also clear the mutex associated with this Condvar since there are no longer any + // waiters. Safe because the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // We are releasing the spin lock and there are no more waiters so we can clear all bits + // in `self.state`. + 0 + } else { + // There are still waiters so we need to keep the HAS_WAITERS bit in the state. + HAS_WAITERS + }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; + } + + // Release the spin lock. + self.state.store(newstate, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + } + + /// Notify all threads currently waiting on the `Condvar`. + /// + /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_all(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. + return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_and_swap(oldstate, oldstate | SPINLOCK, Ordering::Acquire) + != oldstate + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access to `self.waiters`. + let mut wake_list = unsafe { (*self.waiters.get()).take() }; + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, false) }; + } + + // Mark any waiters left as no longer waiting for the Condvar. + for w in &wake_list { + w.set_waiting_for(WaitingFor::None); + } + + // Release the spin lock. We can clear all bits in the state since we took all the waiters. + self.state.store(0, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let waiting_for = waiter.is_waiting_for(); + if waiting_for == WaitingFor::Mutex { + // The waiter was moved to the mutex's list. Retry the cancel. + let set_on_release = if waiters.is_empty() { + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + 0 + } else { + HAS_WAITERS + }; + + self.state.store(set_on_release, Ordering::Release); + + false + } else { + // Don't drop the old waiter now as we're still holding the spin lock. + let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar { + // Safe because we know that the waiter is still linked and is waiting for the Condvar, + // which guarantees that it is still in `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let (mut wake_list, all_readers) = if wake_next || waiting_for == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the condvar's waiter + // list and is going to be woken. Either way, we need to wake up another thread. + get_wake_list(waiters) + } else { + (WaiterList::new(WaiterAdapter::new()), false) + }; + + // Safe because the spin lock guarantees exclusive access. + let muptr = unsafe { (*self.mu.get()) as *const RawMutex }; + + // Try to transfer waiters before releasing the spin lock. + if !wake_list.is_empty() { + // Safe because there was a waiter in the queue and the thread that owns the waiter also + // owns a reference to the Mutex, guaranteeing that the pointer is valid. + unsafe { (*muptr).transfer_waiters(&mut wake_list, all_readers) }; + } + + let set_on_release = if waiters.is_empty() { + // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe + // because we the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + 0 + } else { + HAS_WAITERS + }; + + self.state.store(set_on_release, Ordering::Release); + + // Now wake any waiters still left in the wake list. + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + true + } + } +} + +unsafe impl Send for Condvar {} +unsafe impl Sync for Condvar {} + +impl Default for Condvar { + fn default() -> Self { + Self::new() + } +} + +// Scan `waiters` and return all waiters that should be woken up. If all waiters in the returned +// wait list are readers then the returned bool will be true. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// the rest of the list is not scanned. +fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, bool) { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + let mut all_readers = true; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + // Also mark the waiter as having been removed from the Condvar's waiter list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + debug_assert!(waking_readers); + if all_readers { + // We are waking readers but we need to ensure that at least one writer is woken + // up. Since we haven't yet woken up a writer, wake up this one. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + } else { + // We are waking readers and have already woken one writer. Skip this one. + cursor.move_next(); + } + } + } + } + + (to_wake, all_readers) +} + +fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) -> bool { + let condvar = cv as *const Condvar; + + // Safe because the thread that owns the waiter being canceled must also own a reference to the + // Condvar, which guarantees that this pointer is valid. + unsafe { (*condvar).cancel_waiter(waiter, wake_next) } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::mem; + use std::ptr; + use std::rc::Rc; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use std::thread::{self, JoinHandle}; + use std::time::Duration; + + use futures::channel::oneshot; + use futures::task::{waker_ref, ArcWake}; + use futures::{select, FutureExt}; + use futures_executor::{LocalPool, LocalSpawner, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use crate::sync::{block_on, Mutex}; + + // Dummy waker used when we want to manually drive futures. + struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + #[test] + fn smoke() { + let cv = Condvar::new(); + cv.notify_one(); + cv.notify_all(); + } + + #[test] + fn notify_one() { + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let mu2 = mu.clone(); + let cv2 = cv.clone(); + + let guard = block_on(mu.lock()); + thread::spawn(move || { + let _g = block_on(mu2.lock()); + cv2.notify_one(); + }); + + let guard = block_on(cv.wait(guard)); + mem::drop(guard); + } + + #[test] + fn multi_mutex() { + const NUM_THREADS: usize = 5; + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + + let mut threads = Vec::with_capacity(NUM_THREADS); + for _ in 0..NUM_THREADS { + let mu = mu.clone(); + let cv = cv.clone(); + + threads.push(thread::spawn(move || { + let mut ready = block_on(mu.lock()); + while !*ready { + ready = block_on(cv.wait(ready)); + } + })); + } + + let mut g = block_on(mu.lock()); + *g = true; + mem::drop(g); + cv.notify_all(); + + threads + .into_iter() + .map(JoinHandle::join) + .collect::<thread::Result<()>>() + .expect("Failed to join threads"); + + // Now use the Condvar with a different mutex. + let alt_mu = Arc::new(Mutex::new(None)); + let alt_mu2 = alt_mu.clone(); + let cv2 = cv.clone(); + let handle = thread::spawn(move || { + let mut g = block_on(alt_mu2.lock()); + while g.is_none() { + g = block_on(cv2.wait(g)); + } + }); + + let mut alt_g = block_on(alt_mu.lock()); + *alt_g = Some(()); + mem::drop(alt_g); + cv.notify_all(); + + handle + .join() + .expect("Failed to join thread alternate mutex"); + } + + #[test] + fn notify_one_single_thread_async() { + async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mu2 = Rc::clone(&mu); + let cv2 = Rc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + spawner + .spawn_local(notify(mu2, cv2)) + .expect("Failed to spawn `notify` task"); + let _g = cv.wait(g).await; + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(())); + let cv = Rc::new(Condvar::new()); + + spawner + .spawn_local(wait(mu, cv, spawner.clone())) + .expect("Failed to spawn `wait` task"); + + ex.run(); + } + + #[test] + fn notify_one_multi_thread_async() { + async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) { + let mu2 = Arc::clone(&mu); + let cv2 = Arc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + pool.spawn_ok(notify(mu2, cv2)); + let _g = cv.wait(g).await; + + tx.send(()).expect("Failed to send completion notification"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + ex.spawn_ok(wait(mu, cv, tx, ex.clone())); + + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + + #[test] + fn notify_one_with_cancel() { + const TASKS: usize = 17; + const OBSERVERS: usize = 7; + const ITERATIONS: usize = 103; + + async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + let _ = unsafe { ptr::read_volatile(&*count as *const usize) }; + } + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * OBSERVERS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_one(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn observe_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = observe(&mu, &cv).fuse() => {}, + () = observe(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + for _ in 0..OBSERVERS { + ex.spawn_ok(observe_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + OBSERVERS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS) + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_all_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_all(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + #[test] + fn notify_all() { + const THREADS: usize = 13; + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + let (tx, rx) = channel(); + + let mut threads = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + let mu2 = mu.clone(); + let cv2 = cv.clone(); + let tx2 = tx.clone(); + + threads.push(thread::spawn(move || { + let mut count = block_on(mu2.lock()); + *count += 1; + if *count == THREADS { + tx2.send(()).unwrap(); + } + + while *count != 0 { + count = block_on(cv2.wait(count)); + } + })); + } + + mem::drop(tx); + + // Wait till all threads have started. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + + let mut count = block_on(mu.lock()); + *count = 0; + mem::drop(count); + cv.notify_all(); + + for t in threads { + t.join().unwrap(); + } + } + + #[test] + fn notify_all_single_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + spawner + .spawn_local(reset(mu.clone(), cv.clone())) + .expect("Failed to spawn reset task"); + } + + while *count != 0 { + count = cv.wait(count).await; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0)); + let cv = Rc::new(Condvar::new()); + + for _ in 0..TASKS { + spawner + .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone())) + .expect("Failed to spawn watcher task"); + } + + ex.run(); + } + + #[test] + fn notify_all_multi_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + pool: ThreadPool, + tx: Sender<()>, + ) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + pool.spawn_ok(reset(mu.clone(), cv.clone())); + } + + while *count != 0 { + count = cv.wait(count).await; + } + + tx.send(()).expect("Failed to send completion notification"); + } + + let pool = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) { + let mut ready = mu.read_lock().await; + while !*ready { + ready = cv.wait_read(ready).await; + } + } + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First have all the readers wait on the Condvar. + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Now make the condition true and notify the condvar. Even though we will call notify_one, + // all the readers should be woken up. + *block_on(mu.lock()) = true; + cv.notify_one(); + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + // All readers should now be able to complete. + for r in &mut readers { + if let Poll::Pending = r.as_mut().poll(&mut cx) { + panic!("reader unable to complete"); + } + } + } + + #[test] + fn cancel_before_notify() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + // Drop fut1 before notifying the cv. + mem::drop(fut1); + cv.notify_one(); + + // fut2 should now be ready to complete. + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + cv.notify_one(); + + // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2. + mem::drop(fut1); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_transfer() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. Only transfer one waiter. + cv.notify_one(); + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Drop the lock and then the future. This should not cause fut2 to become runnable as it + // should still be in the Condvar's wait queue. + mem::drop(count); + mem::drop(fut1); + + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + + // Now wake up fut2. Since the lock isn't held, it should wake up immediately. + cv.notify_one(); + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_transfer_and_wake() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. This should transfer both waiters to the mutex's + // wait queue. + cv.notify_all(); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + mem::drop(count); + + mem::drop(fut1); + + if let Poll::Pending = fut2.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn timed_wait() { + async fn wait_deadline( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + timeout: oneshot::Receiver<()>, + ) { + let mut count = mu.lock().await; + + if *count == 0 { + let mut rx = timeout.fuse(); + + while *count == 0 { + select! { + res = rx => { + if let Err(e) = res { + panic!("Error while receiving timeout notification: {}", e); + } + + return; + }, + c = cv.wait(count).fuse() => count = c, + } + } + } + + *count += 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let (tx, rx) = oneshot::channel(); + let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx)); + + if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unexpectedly ready"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS); + + // Signal the channel, which should cancel the wait. + tx.send(()).expect("Failed to send wakeup"); + + // Wait for the timer to run out. + if let Poll::Pending = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unable to complete in time"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(*block_on(mu.lock()), 0); + } +} diff --git a/src/sync/mu.rs b/src/sync/mu.rs new file mode 100644 index 0000000..26735d1 --- /dev/null +++ b/src/sync/mu.rs @@ -0,0 +1,2400 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread::yield_now; + +use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; + +// Set when the mutex is exclusively locked. +const LOCKED: usize = 1 << 0; +// Set when there are one or more threads waiting to acquire the lock. +const HAS_WAITERS: usize = 1 << 1; +// Set when a thread has been woken up from the wait queue. Cleared when that thread either acquires +// the lock or adds itself back into the wait queue. Used to prevent unnecessary wake ups when a +// thread has been removed from the wait queue but has not gotten CPU time yet. +const DESIGNATED_WAKER: usize = 1 << 2; +// Used to provide exclusive access to the `waiters` field in `Mutex`. Should only be held while +// modifying the waiter list. +const SPINLOCK: usize = 1 << 3; +// Set when a thread that wants an exclusive lock adds itself to the wait queue. New threads +// attempting to acquire a shared lock will be preventing from getting it when this bit is set. +// However, this bit is ignored once a thread has gone through the wait queue at least once. +const WRITER_WAITING: usize = 1 << 4; +// Set when a thread has gone through the wait queue many times but has failed to acquire the lock +// every time it is woken up. When this bit is set, all other threads are prevented from acquiring +// the lock until the thread that set the `LONG_WAIT` bit has acquired the lock. +const LONG_WAIT: usize = 1 << 5; +// The bit that is added to the mutex state in order to acquire a shared lock. Since more than one +// thread can acquire a shared lock, we cannot use a single bit. Instead we use all the remaining +// bits in the state to track the number of threads that have acquired a shared lock. +const READ_LOCK: usize = 1 << 8; +// Mask used for checking if any threads currently hold a shared lock. +const READ_MASK: usize = !0xff; +// Mask used to check if the lock is held in either shared or exclusive mode. +const ANY_LOCK: usize = LOCKED | READ_MASK; + +// The number of times the thread should just spin and attempt to re-acquire the lock. +const SPIN_THRESHOLD: usize = 7; + +// The number of times the thread needs to go through the wait queue before it sets the `LONG_WAIT` +// bit and forces all other threads to wait for it to acquire the lock. This value is set relatively +// high so that we don't lose the benefit of having running threads unless it is absolutely +// necessary. +const LONG_WAIT_THRESHOLD: usize = 19; + +// Common methods between shared and exclusive locks. +trait Kind { + // The bits that must be zero for the thread to acquire this kind of lock. If any of these bits + // are not zero then the thread will first spin and retry a few times before adding itself to + // the wait queue. + fn zero_to_acquire() -> usize; + + // The bit that must be added in order to acquire this kind of lock. This should either be + // `LOCKED` or `READ_LOCK`. + fn add_to_acquire() -> usize; + + // The bits that should be set when a thread adds itself to the wait queue while waiting to + // acquire this kind of lock. + fn set_when_waiting() -> usize; + + // The bits that should be cleared when a thread acquires this kind of lock. + fn clear_on_acquire() -> usize; + + // The waiter that a thread should use when waiting to acquire this kind of lock. + fn new_waiter(raw: &RawMutex) -> Arc<Waiter>; +} + +// A lock type for shared read-only access to the data. More than one thread may hold this kind of +// lock simultaneously. +struct Shared; + +impl Kind for Shared { + fn zero_to_acquire() -> usize { + LOCKED | WRITER_WAITING | LONG_WAIT + } + + fn add_to_acquire() -> usize { + READ_LOCK + } + + fn set_when_waiting() -> usize { + 0 + } + + fn clear_on_acquire() -> usize { + 0 + } + + fn new_waiter(raw: &RawMutex) -> Arc<Waiter> { + Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// A lock type for mutually exclusive read-write access to the data. Only one thread can hold this +// kind of lock at a time. +struct Exclusive; + +impl Kind for Exclusive { + fn zero_to_acquire() -> usize { + LOCKED | READ_MASK | LONG_WAIT + } + + fn add_to_acquire() -> usize { + LOCKED + } + + fn set_when_waiting() -> usize { + WRITER_WAITING + } + + fn clear_on_acquire() -> usize { + WRITER_WAITING + } + + fn new_waiter(raw: &RawMutex) -> Arc<Waiter> { + Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// Scan `waiters` and return the ones that should be woken up. Also returns any bits that should be +// set in the mutex state when the current thread releases the spin lock protecting the waiter list. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. If any waiters waiting for an exclusive lock are +// found when iterating through the list, then the returned `usize` contains the `WRITER_WAITING` +// bit, which should be set when the thread releases the spin lock. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// no bits are set in the returned `usize`. +fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, usize) { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut set_on_release = 0; + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + // We found a writer while looking for more readers to wake up. Set the + // `WRITER_WAITING` bit to prevent any new readers from acquiring the lock. All + // readers currently in the wait list will ignore this bit since they already waited + // once. + set_on_release |= WRITER_WAITING; + cursor.move_next(); + } + } + } + + (to_wake, set_on_release) +} + +#[inline] +fn cpu_relax(iterations: usize) { + for _ in 0..iterations { + spin_loop_hint(); + } +} + +pub(crate) struct RawMutex { + state: AtomicUsize, + waiters: UnsafeCell<WaiterList>, +} + +impl RawMutex { + pub fn new() -> RawMutex { + RawMutex { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + } + } + + #[inline] + pub async fn lock(&self) { + match self + .state + .compare_exchange_weak(0, LOCKED, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + // If any bits that should be zero are not zero or if we fail to acquire the lock + // with a single compare_exchange then go through the slow path. + if (oldstate & Exclusive::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Exclusive::add_to_acquire()) + & !Exclusive::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::<Exclusive>(0, 0).await; + } + } + } + } + + #[inline] + pub async fn read_lock(&self) { + match self + .state + .compare_exchange_weak(0, READ_LOCK, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + if (oldstate & Shared::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Shared::add_to_acquire()) & !Shared::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::<Shared>(0, 0).await; + } + } + } + } + + // Slow path for acquiring the lock. `clear` should contain any bits that need to be cleared + // when the lock is acquired. Any bits set in `zero_mask` are cleared from the bits returned by + // `K::zero_to_acquire()`. + #[cold] + async fn lock_slow<K: Kind>(&self, mut clear: usize, zero_mask: usize) { + let mut zero_to_acquire = K::zero_to_acquire() & !zero_mask; + + let mut spin_count = 0; + let mut wait_count = 0; + let mut waiter = None; + loop { + let oldstate = self.state.load(Ordering::Relaxed); + // If all the bits in `zero_to_acquire` are actually zero then try to acquire the lock + // directly. + if (oldstate & zero_to_acquire) == 0 { + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate + K::add_to_acquire()) & !(clear | K::clear_on_acquire()), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + return; + } + } else if (oldstate & SPINLOCK) == 0 { + // The mutex is locked and the spin lock is available. Try to add this thread + // to the waiter queue. + let w = waiter.get_or_insert_with(|| K::new_waiter(self)); + w.reset(WaitingFor::Mutex); + + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate | SPINLOCK | HAS_WAITERS | K::set_when_waiting()) & !clear, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + let mut set_on_release = 0; + + // Safe because we have acquired the spin lock and it provides exclusive + // access to the waiter queue. + if wait_count < LONG_WAIT_THRESHOLD { + // Add the waiter to the back of the queue. + unsafe { (*self.waiters.get()).push_back(w.clone()) }; + } else { + // This waiter has gone through the queue too many times. Put it in the + // front of the queue and block all other threads from acquiring the lock + // until this one has acquired it at least once. + unsafe { (*self.waiters.get()).push_front(w.clone()) }; + + // Set the LONG_WAIT bit to prevent all other threads from acquiring the + // lock. + set_on_release |= LONG_WAIT; + + // Make sure we clear the LONG_WAIT bit when we do finally get the lock. + clear |= LONG_WAIT; + + // Since we set the LONG_WAIT bit we shouldn't allow that bit to prevent us + // from acquiring the lock. + zero_to_acquire &= !LONG_WAIT; + } + + // Release the spin lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !SPINLOCK, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wait until we are woken. + w.wait().await; + + // The `DESIGNATED_WAKER` bit gets set when this thread is woken up by the + // thread that originally held the lock. While this bit is set, no other waiters + // will be woken up so it's important to clear it the next time we try to + // acquire the main lock or the spin lock. + clear |= DESIGNATED_WAKER; + + // Now that the thread has waited once, we no longer care if there is a writer + // waiting. Only the limits of mutual exclusion can prevent us from acquiring + // the lock. + zero_to_acquire &= !WRITER_WAITING; + + // Reset the spin count since we just went through the wait queue. + spin_count = 0; + + // Increment the wait count since we went through the wait queue. + wait_count += 1; + + // Skip the `cpu_relax` below. + continue; + } + } + + // Both the lock and the spin lock are held by one or more other threads. First, we'll + // spin a few times in case we can acquire the lock or the spin lock. If that fails then + // we yield because we might be preventing the threads that do hold the 2 locks from + // getting cpu time. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + #[inline] + pub fn unlock(&self) { + // Fast path, if possible. We can directly clear the locked bit since we have exclusive + // access to the mutex. + let oldstate = self.state.fetch_sub(LOCKED, Ordering::Release); + + // Panic if we just tried to unlock a mutex that wasn't held by this thread. This shouldn't + // really be possible since `unlock` is not a public method. + debug_assert_eq!( + oldstate & READ_MASK, + 0, + "`unlock` called on mutex held in read-mode" + ); + debug_assert_ne!( + oldstate & LOCKED, + 0, + "`unlock` called on mutex not held in write-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 && (oldstate & DESIGNATED_WAKER) == 0 { + // The oldstate has waiters but no designated waker has been chosen yet. + self.unlock_slow(); + } + } + + #[inline] + pub fn read_unlock(&self) { + // Fast path, if possible. We can directly subtract the READ_LOCK bit since we had + // previously added it. + let oldstate = self.state.fetch_sub(READ_LOCK, Ordering::Release); + + debug_assert_eq!( + oldstate & LOCKED, + 0, + "`read_unlock` called on mutex held in write-mode" + ); + debug_assert_ne!( + oldstate & READ_MASK, + 0, + "`read_unlock` called on mutex not held in read-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 + && (oldstate & DESIGNATED_WAKER) == 0 + && (oldstate & READ_MASK) == READ_LOCK + { + // There are waiters, no designated waker has been chosen yet, and the last reader is + // unlocking so we have to take the slow path. + self.unlock_slow(); + } + } + + #[cold] + fn unlock_slow(&self) { + let mut spin_count = 0; + + loop { + let oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 || (oldstate & DESIGNATED_WAKER) != 0 { + // No more waiters or a designated waker has been chosen. Nothing left for us to do. + return; + } else if (oldstate & SPINLOCK) == 0 { + // The spin lock is not held by another thread. Try to acquire it. Also set the + // `DESIGNATED_WAKER` bit since we are likely going to wake up one or more threads. + if self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | DESIGNATED_WAKER, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + // Acquired the spinlock. Try to wake a waiter. We may also end up wanting to + // clear the HAS_WAITER and DESIGNATED_WAKER bits so start collecting the bits + // to be cleared. + let mut clear = SPINLOCK; + + // Safe because the spinlock guarantees exclusive access to the waiter list and + // the reference does not escape this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let (wake_list, set_on_release) = get_wake_list(waiters); + + // If the waiter list is now empty, clear the HAS_WAITERS bit. + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // Since we are not going to wake any waiters clear the DESIGNATED_WAKER bit + // that we set when we acquired the spin lock. + clear |= DESIGNATED_WAKER; + } + + // Release the spin lock and clear any other bits as necessary. Also, set any + // bits returned by `get_wake_list`. For now, this is just the `WRITER_WAITING` + // bit, which needs to be set when we are waking up a bunch of readers and there + // are still writers in the wait queue. This will prevent any readers that + // aren't in `wake_list` from acquiring the read lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !clear, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wake the waiters, if any. + for w in wake_list { + w.wake(); + } + + // We're done. + return; + } + } + + // Spin and try again. It's ok to block here as we have already released the lock. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + // Transfer waiters from the `Condvar` wait list to the `Mutex` wait list. `all_readers` may + // be set to true if all waiters are waiting to acquire a shared lock but should not be true if + // there is even one waiter waiting on an exclusive lock. + // + // This is similar to what the `FUTEX_CMP_REQUEUE` flag does on linux. + pub fn transfer_waiters(&self, new_waiters: &mut WaiterList, all_readers: bool) { + if new_waiters.is_empty() { + return; + } + + let mut oldstate = self.state.load(Ordering::Relaxed); + let can_acquire_read_lock = (oldstate & Shared::zero_to_acquire()) == 0; + + // The lock needs to be held in some mode or else the waiters we transfer now may never get + // woken up. Additionally, if all the new waiters are readers and can acquire the lock now + // then we can just wake them up. + if (oldstate & ANY_LOCK) == 0 || (all_readers && can_acquire_read_lock) { + // Nothing to do here. The Condvar will wake up all the waiters left in `new_waiters`. + return; + } + + if (oldstate & SPINLOCK) == 0 + && self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | HAS_WAITERS, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + let mut transferred_writer = false; + + // Safe because the spin lock guarantees exclusive access and the reference does not + // escape this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let mut current = new_waiters.front_mut(); + while let Some(w) = current.get() { + match w.kind() { + WaiterKind::Shared => { + if can_acquire_read_lock { + current.move_next(); + } else { + // We need to update the cancellation function since we're moving this + // waiter into our queue. Also update the waiting to indicate that it is + // now in the Mutex's waiter list. + let w = current.remove().unwrap(); + w.set_cancel(cancel_waiter, self as *const RawMutex as usize); + w.set_waiting_for(WaitingFor::Mutex); + waiters.push_back(w); + } + } + WaiterKind::Exclusive => { + transferred_writer = true; + // We need to update the cancellation function since we're moving this + // waiter into our queue. Also update the waiting to indicate that it is + // now in the Mutex's waiter list. + let w = current.remove().unwrap(); + w.set_cancel(cancel_waiter, self as *const RawMutex as usize); + w.set_waiting_for(WaitingFor::Mutex); + waiters.push_back(w); + } + } + } + + let set_on_release = if transferred_writer { + WRITER_WAITING + } else { + 0 + }; + + // If we didn't actually transfer any waiters, clear the HAS_WAITERS bit that we set + // earlier when we acquired the spin lock. + let clear = if waiters.is_empty() { + SPINLOCK | HAS_WAITERS + } else { + SPINLOCK + }; + + while self + .state + .compare_exchange_weak( + oldstate, + (oldstate | set_on_release) & !clear, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + } + + // The Condvar will wake up any waiters still left in the queue. + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) -> bool { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let mut clear = SPINLOCK; + + // If we are about to remove the first waiter in the wait list, then clear the LONG_WAIT + // bit. Also clear the bit if we are going to be waking some other waiters. In this case the + // waiter that set the bit may have already been removed from the waiter list (and could be + // the one that is currently being dropped). If it is still in the waiter list then clearing + // this bit may starve it for one more iteration through the lock_slow() loop, whereas not + // clearing this bit could cause a deadlock if the waiter that set it is the one that is + // being dropped. + if wake_next + || waiters + .front() + .get() + .map(|front| front as *const Waiter == waiter as *const Waiter) + .unwrap_or(false) + { + clear |= LONG_WAIT; + } + + // Don't drop the old waiter while holding the spin lock. + let old_waiter = if waiter.is_linked() && waiter.is_waiting_for() == WaitingFor::Mutex { + // We know that the waitir is still linked and is waiting for the mutex, which + // guarantees that it is still linked into `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let (wake_list, set_on_release) = + if wake_next || waiter.is_waiting_for() == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the mutex's waiter + // list and is going to be woken. Either way, we need to wake up another thread. + get_wake_list(waiters) + } else { + (WaiterList::new(WaiterAdapter::new()), 0) + }; + + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // We're not waking any other threads so clear the DESIGNATED_WAKER bit. In the worst + // case this leads to an additional thread being woken up but we risk a deadlock if we + // don't clear it. + clear |= DESIGNATED_WAKER; + } + + if let WaiterKind::Exclusive = waiter.kind() { + // The waiter being dropped is a writer so clear the writer waiting bit for now. If we + // found more writers in the list while fetching waiters to wake up then this bit will + // be set again via `set_on_release`. + clear |= WRITER_WAITING; + } + + while self + .state + .compare_exchange_weak( + oldstate, + (oldstate & !clear) | set_on_release, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + + // Canceling a waker is always successful. + true + } +} + +unsafe impl Send for RawMutex {} +unsafe impl Sync for RawMutex {} + +fn cancel_waiter(raw: usize, waiter: &Waiter, wake_next: bool) -> bool { + let raw_mutex = raw as *const RawMutex; + + // Safe because the thread that owns the waiter that is being canceled must + // also own a reference to the mutex, which ensures that this pointer is + // valid. + unsafe { (*raw_mutex).cancel_waiter(waiter, wake_next) } +} + +/// A high-level primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike more traditional mutexes, `Mutex` can safely provide both shared, immutable access (via +/// `read_lock()`) as well as exclusive, mutable access (via `lock()`) to an underlying resource +/// with no loss of performance. +/// +/// # Poisoning +/// +/// `Mutex` does not support lock poisoning so if a thread panics while holding the lock, the +/// poisoned data will be accessible by other threads in your program. If you need to guarantee that +/// other threads cannot access poisoned data then you may wish to wrap this `Mutex` inside another +/// type that provides the poisoning feature. See the implementation of `std::sync::Mutex` for an +/// example of this. +/// +/// +/// # Fairness +/// +/// This `Mutex` implementation does not guarantee that threads will acquire the lock in the same +/// order that they call `lock()` or `read_lock()`. However it will attempt to prevent long-term +/// starvation: if a thread repeatedly fails to acquire the lock beyond a threshold then all other +/// threads will fail to acquire the lock until the starved thread has acquired it. +/// +/// Similarly, this `Mutex` will attempt to balance reader and writer threads: once there is a +/// writer thread waiting to acquire the lock no new reader threads will be allowed to acquire it. +/// However, any reader threads that were already waiting will still be allowed to acquire it. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use libchromeos::sync::{block_on, Mutex}; +/// +/// const N: usize = 10; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let the main thread know once all increments are done. +/// // +/// // Here we're using an Arc to share memory among threads, and the data inside +/// // the Arc is protected with a mutex. +/// let data = Arc::new(Mutex::new(0)); +/// +/// let (tx, rx) = channel(); +/// for _ in 0..N { +/// let (data, tx) = (Arc::clone(&data), tx.clone()); +/// thread::spawn(move || { +/// // The shared state can only be accessed once the lock is held. +/// // Our non-atomic increment is safe because we're the only thread +/// // which can access the shared state when the lock is held. +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// tx.send(()).unwrap(); +/// } +/// // the lock is unlocked here when `data` goes out of scope. +/// }); +/// } +/// +/// rx.recv().unwrap(); +/// ``` +pub struct Mutex<T: ?Sized> { + raw: RawMutex, + value: UnsafeCell<T>, +} + +impl<T> Mutex<T> { + /// Create a new, unlocked `Mutex` ready for use. + pub fn new(v: T) -> Mutex<T> { + Mutex { + raw: RawMutex::new(), + value: UnsafeCell::new(v), + } + } + + /// Consume the `Mutex` and return the contained value. This method does not perform any locking + /// as the compiler will guarantee that there are no other references to `self` and the caller + /// owns the `Mutex`. + pub fn into_inner(self) -> T { + // Don't need to acquire the lock because the compiler guarantees that there are + // no references to `self`. + self.value.into_inner() + } +} + +impl<T: ?Sized> Mutex<T> { + /// Acquires exclusive, mutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning, the current thread will be the + /// only thread with access to the resource. The `Mutex` will be released when the returned + /// `MutexGuard` is dropped. + /// + /// Calling `lock()` while holding a `MutexGuard` or a `MutexReadGuard` will cause a deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn lock(&self) -> MutexGuard<'_, T> { + self.raw.lock().await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + /// Acquires shared, immutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning there may be other threads that + /// also have immutable access to the resource but there will not be any threads that have + /// mutable access to the resource. When the returned `MutexReadGuard` is dropped the thread + /// releases its access to the resource. + /// + /// Calling `read_lock()` while holding a `MutexReadGuard` may deadlock. Calling `read_lock()` + /// while holding a `MutexGuard` will deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn read_lock(&self) -> MutexReadGuard<'_, T> { + self.raw.read_lock().await; + + // Safe because we have shared read-only access to `self.value`. + MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + // Called from `Condvar::wait` when the thread wants to reacquire the lock. Since we may + // directly transfer waiters from the `Condvar` wait list to the `Mutex` wait list (see + // `transfer_all` below), we cannot call `Mutex::lock` as we also need to clear the + // `DESIGNATED_WAKER` bit when acquiring the lock. Not doing so will prevent us from waking up + // any other threads in the wait list. + #[inline] + pub(crate) async fn lock_from_cv(&self) -> MutexGuard<'_, T> { + self.raw.lock_slow::<Exclusive>(DESIGNATED_WAKER, 0).await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + // Like `lock_from_cv` but for acquiring a shared lock. + #[inline] + pub(crate) async fn read_lock_from_cv(&self) -> MutexReadGuard<'_, T> { + // Threads that have waited in the Condvar's waiter list don't have to care if there is a + // writer waiting. This also prevents a deadlock in the following case: + // + // * Thread A holds a write lock. + // * Thread B is in the mutex's waiter list, also waiting on a write lock. + // * Threads C, D, and E are in the condvar's waiter list. C and D want a read lock; E + // wants a write lock. + // * A calls `cv.notify_all()` while still holding the lock, which transfers C, D, and E + // onto the mutex's wait list. + // * A releases the lock, which wakes up B. + // * B acquires the lock, does some work, and releases the lock. This wakes up C and D. + // However, when iterating through the waiter list we find E, which is waiting for a + // write lock so we set the WRITER_WAITING bit. + // * C and D go through this function to acquire the lock. If we didn't clear the + // WRITER_WAITING bit from the zero_to_acquire set then it would prevent C and D from + // acquiring the lock and they would add themselves back into the waiter list. + // * Now C, D, and E will sit in the waiter list indefinitely unless some other thread + // comes along and acquires the lock. On release, it would wake up E and everything would + // go back to normal. + self.raw + .lock_slow::<Shared>(DESIGNATED_WAKER, WRITER_WAITING) + .await; + + // Safe because we have exclusive access to `self.value`. + MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + #[inline] + fn unlock(&self) { + self.raw.unlock(); + } + + #[inline] + fn read_unlock(&self) { + self.raw.read_unlock(); + } + + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler statically guarantees that are no other references to `self`. + // This is also why we don't need to acquire the lock first. + unsafe { &mut *self.value.get() } + } +} + +unsafe impl<T: ?Sized + Send> Send for Mutex<T> {} +unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {} + +impl<T: ?Sized + Default> Default for Mutex<T> { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl<T> From<T> for Mutex<T> { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped exclusive lock" for a `Mutex`. When this structure is +/// dropped, the lock will be released. The resource protected by the `Mutex` can be accessed via +/// the `Deref` and `DerefMut` implementations of this structure. +pub struct MutexGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex<T>, + value: &'a mut T, +} + +impl<'a, T: ?Sized> MutexGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex<T> { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.mu.unlock() + } +} + +/// An RAII implementation of a "scoped shared lock" for a `Mutex`. When this structure is dropped, +/// the lock will be released. The resource protected by the `Mutex` can be accessed via the `Deref` +/// implementation of this structure. +pub struct MutexReadGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex<T>, + value: &'a T, +} + +impl<'a, T: ?Sized> MutexReadGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex<T> { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexReadGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexReadGuard<'a, T> { + fn drop(&mut self) { + self.mu.read_unlock() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::mem; + use std::pin::Pin; + use std::rc::Rc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll, Waker}; + use std::thread; + use std::time::Duration; + + use futures::channel::oneshot; + use futures::task::{waker_ref, ArcWake}; + use futures::{pending, select, FutureExt}; + use futures_executor::{LocalPool, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use crate::sync::{block_on, Condvar, SpinLock}; + + #[derive(Debug, Eq, PartialEq)] + struct NonCopy(u32); + + // Dummy waker used when we want to manually drive futures. + struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + #[test] + fn it_works() { + let mu = Mutex::new(NonCopy(13)); + + assert_eq!(*block_on(mu.lock()), NonCopy(13)); + } + + #[test] + fn smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn rw_smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.read_lock())); + mem::drop((block_on(mu.read_lock()), block_on(mu.read_lock()))); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn async_smoke() { + async fn lock(mu: Rc<Mutex<NonCopy>>) { + mu.lock().await; + } + + async fn read_lock(mu: Rc<Mutex<NonCopy>>) { + mu.read_lock().await; + } + + async fn double_read_lock(mu: Rc<Mutex<NonCopy>>) { + let first = mu.read_lock().await; + mu.read_lock().await; + + // Make sure first lives past the second read lock. + first.as_raw_mutex(); + } + + let mu = Rc::new(Mutex::new(NonCopy(7))); + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(double_read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + + ex.run(); + } + + #[test] + fn send() { + let mu = Mutex::new(NonCopy(19)); + + thread::spawn(move || { + let value = block_on(mu.lock()); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_nested() { + // Tests nested mutexes and access to underlying data. + let mu = Mutex::new(1); + let arc = Arc::new(Mutex::new(mu)); + thread::spawn(move || { + let nested = block_on(arc.lock()); + let lock2 = block_on(nested.lock()); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(Mutex::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc<Mutex<i32>>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *block_on(self.i.lock()) += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = block_on(arc.lock()); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]); + { + let b = &mut *block_on(mutex.lock()); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*block_on(mutex.lock()), expected); + } + #[test] + fn high_contention() { + const THREADS: usize = 17; + const ITERATIONS: usize = 103; + + let mut threads = Vec::with_capacity(THREADS); + + let mu = Arc::new(Mutex::new(0usize)); + for _ in 0..THREADS { + let mu2 = mu.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *block_on(mu2.lock()) += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*block_on(mu.read_lock()), THREADS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn high_contention_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn increment(mu: Arc<Mutex<usize>>, alt_mu: Arc<Mutex<usize>>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + select! { + mut count = mu.lock().fuse() => *count += 1, + mut count = alt_mu.lock().fuse() => *count += 1, + } + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&alt_mu), tx.clone())); + } + + for _ in 0..TASKS { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn single_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. + async fn increment(mu: Rc<Mutex<usize>>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0usize)); + for _ in 0..TASKS { + spawner + .spawn_local(increment(Rc::clone(&mu))) + .expect("Failed to spawn task"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn multi_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. + async fn increment(mu: Arc<Mutex<usize>>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion signal"); + } + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn get_mut() { + let mut mu = Mutex::new(NonCopy(13)); + *mu.get_mut() = NonCopy(17); + + assert_eq!(mu.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let mu = Mutex::new(NonCopy(29)); + assert_eq!(mu.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc<AtomicUsize>); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = Arc::new(AtomicUsize::new(0)); + let needs_drop = Mutex::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn rw_arc() { + const THREADS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let mu2 = mu.clone(); + + let (tx, rx) = channel(); + thread::spawn(move || { + let mut guard = block_on(mu2.lock()); + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + tx.send(()).unwrap(); + }); + + let mut readers = Vec::with_capacity(10); + for _ in 0..THREADS { + let mu3 = mu.clone(); + let handle = thread::spawn(move || { + let guard = block_on(mu3.read_lock()); + assert!(*guard >= 0); + }); + + readers.push(handle); + } + + // Wait for the readers to finish their checks. + for r in readers { + r.join().expect("One or more readers saw a negative value"); + } + + // Wait for the writer to finish. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_single_thread_async() { + // A Future that returns `Poll::pending` the first time it is polled and `Poll::Ready` every + // time after that. + struct TestFuture { + polled: bool, + waker: Arc<SpinLock<Option<Waker>>>, + } + + impl Future for TestFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + if self.polled { + Poll::Ready(()) + } else { + self.polled = true; + *self.waker.lock() = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + fn wake_future(waker: Arc<SpinLock<Option<Waker>>>) { + loop { + if let Some(waker) = waker.lock().take() { + waker.wake(); + return; + } else { + thread::sleep(Duration::from_millis(10)); + } + } + } + + async fn writer(mu: Rc<Mutex<isize>>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + let waker = Arc::new(SpinLock::new(None)); + let waker2 = Arc::clone(&waker); + thread::spawn(move || wake_future(waker2)); + let fut = TestFuture { + polled: false, + waker, + }; + fut.await; + *guard = tmp + 1; + } + } + + async fn reader(mu: Rc<Mutex<isize>>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Rc::new(Mutex::new(0isize)); + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(writer(Rc::clone(&mu))) + .expect("Failed to spawn writer"); + + for _ in 0..TASKS { + spawner + .spawn_local(reader(Rc::clone(&mu))) + .expect("Failed to spawn reader"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_multi_thread_async() { + async fn writer(mu: Arc<Mutex<isize>>, tx: Sender<()>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + + mem::drop(guard); + tx.send(()).unwrap(); + } + + async fn reader(mu: Arc<Mutex<isize>>, tx: Sender<()>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + + mem::drop(guard); + tx.send(()).expect("Failed to send completion message"); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let (txw, rxw) = channel(); + ex.spawn_ok(writer(Arc::clone(&mu), txw)); + + let (txr, rxr) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(reader(Arc::clone(&mu), txr.clone())); + } + + // Wait for the readers to finish their checks. + for _ in 0..TASKS { + rxr.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from reader"); + } + + // Wait for the writer to finish. + rxw.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from writer"); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc<Mutex<()>>) { + let g = mu.read_lock().await; + pending!(); + mem::drop(g); + } + + async fn write(mu: Arc<Mutex<()>>) { + mu.lock().await; + } + + let mu = Arc::new(Mutex::new(())); + let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [ + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(write(mu.clone())), + Box::pin(read(mu.clone())), + ]; + const NUM_READERS: usize = 4; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock so that the futures cannot get it. + let g = block_on(mu.lock()); + + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should allow all readers to make progress. Since they already waited + // once they should ignore the WRITER_WAITING bit that is currently set. + mem::drop(g); + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Check that all readers were able to acquire the lock. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + let mut needs_poll = None; + + // All the readers can now finish but the writer needs to be polled again. + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if let Poll::Pending = futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + { + panic!("Writer unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn long_wait() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Which will finally allow the mark_ready function to make progress. + if let Poll::Pending = mark.as_mut().poll(&mut cx) { + panic!("mark_ready not able to make progress"); + } + + // Now the tight loop will finish. + if let Poll::Pending = tl.as_mut().poll(&mut cx) { + panic!("tight_loop not able to finish"); + } + + assert!(*block_on(mu.lock())); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_before_wake() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. + mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), LOCKED); + + mem::drop(tl); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_after_wake() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. + mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & LONG_WAIT, 0); + + // Since the lock is not held, we should be able to spawn a future to set the ready flag. + block_on(mark_ready(mu.clone())); + + // Now the tight loop will finish. + if let Poll::Pending = tl.as_mut().poll(&mut cx) { + panic!("tight_loop not able to finish"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn designated_waker() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut futures = [ + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + // Poll 2 futures. Since neither will be able to acquire the lock, they should get added to + // the waiter list. + if let Poll::Ready(()) = futures[0].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = futures[1].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING, + ); + + // Now drop the lock. This should set the DESIGNATED_WAKER bit and wake up the first future + // in the wait list. + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + DESIGNATED_WAKER | HAS_WAITERS | WRITER_WAITING, + ); + + // Now poll the third future. It should be able to acquire the lock immediately. + if let Poll::Pending = futures[2].as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 1); + + // There should still be a waiter in the wait list and the DESIGNATED_WAKER bit should still + // be set. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Now let the future that was woken up run. + if let Poll::Pending = futures[0].as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 2); + + if let Poll::Pending = futures[1].as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 3); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_designated_waker() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + if let Poll::Ready(()) = fut.as_mut().poll(&mut cx) { + panic!("Future unexpectedly ready when lock is held"); + } + + // Drop the lock. This will wake up the future. + mem::drop(count); + + // Now drop the future without polling. This should clear all the state in the mutex. + mem::drop(fut); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_before_wake() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop fut1. This should remove it from the waiter list but shouldn't wake fut2. + mem::drop(fut1); + + // There should be no designated waker. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, 0); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + // Now drop the lock. This should mark fut2 as ready to make progress. + mem::drop(count); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_after_wake() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should mark fut1 as ready to make progress. + mem::drop(count); + + // Now drop fut1. This should make fut2 ready to make progress. + mem::drop(fut1); + + // Since there was still another waiter in the list we shouldn't have cleared the + // DESIGNATED_WAKER bit. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn timeout() { + async fn timed_lock(timer: oneshot::Receiver<()>, mu: Arc<Mutex<()>>) { + select! { + res = timer.fuse() => { + match res { + Ok(()) => {}, + Err(e) => panic!("Timer unexpectedly canceled: {}", e), + } + } + _ = mu.lock().fuse() => panic!("Successfuly acquired lock"), + } + } + + let mu = Arc::new(Mutex::new(())); + let (tx, rx) = oneshot::channel(); + + let mut timeout = Box::pin(timed_lock(rx, mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock. + let g = block_on(mu.lock()); + + // Poll the future. + if let Poll::Ready(()) = timeout.as_mut().poll(&mut cx) { + panic!("timed_lock unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Signal the channel, which should cancel the lock. + tx.send(()).expect("Failed to send wakeup"); + + // Now the future should have completed without acquiring the lock. + if let Poll::Pending = timeout.as_mut().poll(&mut cx) { + panic!("timed_lock not ready after timeout"); + } + + // The mutex state should not show any waiters. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + mem::drop(g); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn writer_waiting() { + async fn read_zero(mu: Arc<Mutex<usize>>) { + let val = mu.read_lock().await; + pending!(); + + assert_eq!(*val, 0); + } + + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + async fn read_one(mu: Arc<Mutex<usize>>) { + let val = mu.read_lock().await; + + assert_eq!(*val, 1); + } + + let mu = Arc::new(Mutex::new(0)); + + let mut r1 = Box::pin(read_zero(mu.clone())); + let mut r2 = Box::pin(read_zero(mu.clone())); + + let mut w = Box::pin(inc(mu.clone())); + let mut r3 = Box::pin(read_one(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + if let Poll::Ready(()) = r1.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + if let Poll::Ready(()) = r2.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if let Poll::Ready(()) = w.as_mut().poll(&mut cx) { + panic!("inc unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // The WRITER_WAITING bit should prevent the next reader from acquiring the lock. + if let Poll::Ready(()) = r3.as_mut().poll(&mut cx) { + panic!("read_one unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if let Poll::Pending = r1.as_mut().poll(&mut cx) { + panic!("read_zero unable to complete"); + } + if let Poll::Pending = r2.as_mut().poll(&mut cx) { + panic!("read_zero unable to complete"); + } + if let Poll::Pending = w.as_mut().poll(&mut cx) { + panic!("inc unable to complete"); + } + if let Poll::Pending = r3.as_mut().poll(&mut cx) { + panic!("read_one unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn transfer_notify_one() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + } + + async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + let mut writer = Box::pin(write(mu.clone(), cv.clone())); + + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + if let Poll::Ready(()) = writer.as_mut().poll(&mut cx) { + panic!("writer unexpectedly ready"); + } + + let mut count = block_on(mu.lock()); + *count = 1; + + // This should transfer all readers + one writer to the waiter queue. + cv.notify_one(); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + for r in &mut readers { + if let Poll::Pending = r.as_mut().poll(&mut cx) { + panic!("reader unable to complete"); + } + } + + if let Poll::Pending = writer.as_mut().poll(&mut cx) { + panic!("writer unable to complete"); + } + + assert_eq!(*block_on(mu.read_lock()), 0); + } + + #[test] + fn transfer_waiters_when_unlocked() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + *block_on(mu.lock()) = futures.len(); + cv.notify_all(); + + // Since the lock is not held, instead of transferring the waiters to the waiter list we + // should just wake them all up. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + for f in &mut futures { + if let Poll::Pending = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn transfer_reader_writer() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock, which should prevent the writer from waking + // up. + pending!(); + } + + async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + async fn lock(mu: Arc<Mutex<usize>>) { + mem::drop(mu.lock().await); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(write(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + const NUM_READERS: usize = 4; + + let mut l = Box::pin(lock(mu.clone())); + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + + let mut count = block_on(mu.lock()); + *count = 1; + + // Now poll the lock function. Since the lock is held by us, it will get queued on the + // waiter list. + if let Poll::Ready(()) = l.as_mut().poll(&mut cx) { + panic!("lock() unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + // Wake up waiters while holding the lock. This should end with them transferred to the + // mutex's waiter list. + cv.notify_all(); + + // Drop the lock. This should wake up the lock function. + mem::drop(count); + + if let Poll::Pending = l.as_mut().poll(&mut cx) { + panic!("lock() unable to complete"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + // Poll everything again. The readers should be able to make progress (but not complete) but + // the writer should be blocked. + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + + // All the readers can now finish but the writer needs to be polled again. + let mut needs_poll = None; + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if let Poll::Pending = futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + { + panic!("Writer unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 0); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn transfer_readers_with_read_lock() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock. + pending!(); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Increment the count and then grab a read lock. + *block_on(mu.lock()) = 1; + + let g = block_on(mu.read_lock()); + + // Notify the condvar while holding the read lock. This should wake up all the waiters + // rather than just transferring them. + cv.notify_all(); + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + mem::drop(g); + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + READ_LOCK * futures.len() + ); + + for f in &mut futures { + if let Poll::Pending = f.as_mut().poll(&mut cx) { + panic!("future unable to complete"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } +} diff --git a/src/sync/spin.rs b/src/sync/spin.rs new file mode 100644 index 0000000..0687bb7 --- /dev/null +++ b/src/sync/spin.rs @@ -0,0 +1,271 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{spin_loop_hint, AtomicBool, Ordering}; + +const UNLOCKED: bool = false; +const LOCKED: bool = true; + +/// A primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike `Mutex`, a `SpinLock` will not voluntarily yield its CPU time until the resource is +/// available and will instead keep spinning until the resource is acquired. For the vast majority +/// of cases, `Mutex` is a better choice than `SpinLock`. If a `SpinLock` must be used then users +/// should try to do as little work as possible while holding the `SpinLock` and avoid any sort of +/// blocking at all costs as it can severely penalize performance. +/// +/// # Poisoning +/// +/// This `SpinLock` does not implement lock poisoning so it is possible for threads to access +/// poisoned data if a thread panics while holding the lock. If lock poisoning is needed, it can be +/// implemented by wrapping the `SpinLock` in a new type that implements poisoning. See the +/// implementation of `std::sync::Mutex` for an example of how to do this. +pub struct SpinLock<T: ?Sized> { + lock: AtomicBool, + value: UnsafeCell<T>, +} + +impl<T> SpinLock<T> { + /// Creates a new, unlocked `SpinLock` that's ready for use. + pub fn new(value: T) -> SpinLock<T> { + SpinLock { + lock: AtomicBool::new(UNLOCKED), + value: UnsafeCell::new(value), + } + } + + /// Consumes the `SpinLock` and returns the value guarded by it. This method doesn't perform any + /// locking as the compiler guarantees that there are no references to `self`. + pub fn into_inner(self) -> T { + // No need to take the lock because the compiler can statically guarantee + // that there are no references to the SpinLock. + self.value.into_inner() + } +} + +impl<T: ?Sized> SpinLock<T> { + /// Acquires exclusive, mutable access to the resource protected by the `SpinLock`, blocking the + /// current thread until it is able to do so. Upon returning, the current thread will be the + /// only thread with access to the resource. The `SpinLock` will be released when the returned + /// `SpinLockGuard` is dropped. Attempting to call `lock` while already holding the `SpinLock` + /// will cause a deadlock. + pub fn lock(&self) -> SpinLockGuard<T> { + while self + .lock + .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed) + .is_err() + { + spin_loop_hint(); + } + + SpinLockGuard { + lock: self, + value: unsafe { &mut *self.value.get() }, + } + } + + fn unlock(&self) { + // Don't need to compare and swap because we exclusively hold the lock. + self.lock.store(UNLOCKED, Ordering::Release); + } + + /// Returns a mutable reference to the contained value. This method doesn't perform any locking + /// as the compiler will statically guarantee that there are no other references to `self`. + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler can statically guarantee that there are no other references to + // `self`. This is also why we don't need to acquire the lock. + unsafe { &mut *self.value.get() } + } +} + +unsafe impl<T: ?Sized + Send> Send for SpinLock<T> {} +unsafe impl<T: ?Sized + Send> Sync for SpinLock<T> {} + +impl<T: ?Sized + Default> Default for SpinLock<T> { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl<T> From<T> for SpinLock<T> { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped lock" for a `SpinLock`. When this structure is dropped, the +/// lock will be released. The resource protected by the `SpinLock` can be accessed via the `Deref` +/// and `DerefMut` implementations of this structure. +pub struct SpinLockGuard<'a, T: 'a + ?Sized> { + lock: &'a SpinLock<T>, + value: &'a mut T, +} + +impl<'a, T: ?Sized> Deref for SpinLockGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for SpinLockGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.value + } +} + +impl<'a, T: ?Sized> Drop for SpinLockGuard<'a, T> { + fn drop(&mut self) { + self.lock.unlock(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::mem; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::thread; + + #[derive(PartialEq, Eq, Debug)] + struct NonCopy(u32); + + #[test] + fn it_works() { + let sl = SpinLock::new(NonCopy(13)); + + assert_eq!(*sl.lock(), NonCopy(13)); + } + + #[test] + fn smoke() { + let sl = SpinLock::new(NonCopy(7)); + + mem::drop(sl.lock()); + mem::drop(sl.lock()); + } + + #[test] + fn send() { + let sl = SpinLock::new(NonCopy(19)); + + thread::spawn(move || { + let value = sl.lock(); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn high_contention() { + const THREADS: usize = 23; + const ITERATIONS: usize = 101; + + let mut threads = Vec::with_capacity(THREADS); + + let sl = Arc::new(SpinLock::new(0usize)); + for _ in 0..THREADS { + let sl2 = sl.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *sl2.lock() += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*sl.lock(), THREADS * ITERATIONS); + } + + #[test] + fn get_mut() { + let mut sl = SpinLock::new(NonCopy(13)); + *sl.get_mut() = NonCopy(17); + + assert_eq!(sl.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let sl = SpinLock::new(NonCopy(29)); + assert_eq!(sl.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc<AtomicUsize>); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = Arc::new(AtomicUsize::new(0)); + let needs_drop = SpinLock::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn arc_nested() { + // Tests nested sltexes and access to underlying data. + let sl = SpinLock::new(1); + let arc = Arc::new(SpinLock::new(sl)); + thread::spawn(move || { + let nested = arc.lock(); + let lock2 = nested.lock(); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(SpinLock::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc<SpinLock<i32>>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *self.i.lock() += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = arc.lock(); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let sltex: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]); + { + let b = &mut *sltex.lock(); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*sltex.lock(), expected); + } +} diff --git a/src/sync/waiter.rs b/src/sync/waiter.rs new file mode 100644 index 0000000..48335c7 --- /dev/null +++ b/src/sync/waiter.rs @@ -0,0 +1,317 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use intrusive_collections::linked_list::{LinkedList, LinkedListOps}; +use intrusive_collections::{intrusive_adapter, DefaultLinkOps, LinkOps}; + +use crate::sync::SpinLock; + +// An atomic version of a LinkedListLink. See https://github.com/Amanieu/intrusive-rs/issues/47 for +// more details. +pub struct AtomicLink { + prev: UnsafeCell<Option<NonNull<AtomicLink>>>, + next: UnsafeCell<Option<NonNull<AtomicLink>>>, + linked: AtomicBool, +} + +impl AtomicLink { + fn new() -> AtomicLink { + AtomicLink { + linked: AtomicBool::new(false), + prev: UnsafeCell::new(None), + next: UnsafeCell::new(None), + } + } + + fn is_linked(&self) -> bool { + self.linked.load(Ordering::Relaxed) + } +} + +impl DefaultLinkOps for AtomicLink { + type Ops = AtomicLinkOps; + + const NEW: Self::Ops = AtomicLinkOps; +} + +// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods +// are all unsafe and require that the caller has first called `acquire_link` (and had it return +// true) to use them safely. +unsafe impl Send for AtomicLink {} +unsafe impl Sync for AtomicLink {} + +#[derive(Copy, Clone, Default)] +pub struct AtomicLinkOps; + +unsafe impl LinkOps for AtomicLinkOps { + type LinkPtr = NonNull<AtomicLink>; + + unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool { + !ptr.as_ref().linked.swap(true, Ordering::Acquire) + } + + unsafe fn release_link(&mut self, ptr: Self::LinkPtr) { + ptr.as_ref().linked.store(false, Ordering::Release) + } +} + +unsafe impl LinkedListOps for AtomicLinkOps { + unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> { + *ptr.as_ref().next.get() + } + + unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> { + *ptr.as_ref().prev.get() + } + + unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) { + *ptr.as_ref().next.get() = next; + } + + unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) { + *ptr.as_ref().prev.get() = prev; + } +} + +#[derive(Clone, Copy)] +pub enum Kind { + Shared, + Exclusive, +} + +enum State { + Init, + Waiting(Waker), + Woken, + Finished, + Processing, +} + +// Indicates the queue to which the waiter belongs. It is the responsibility of the Mutex and +// Condvar implementations to update this value when adding/removing a Waiter from their respective +// waiter lists. +#[repr(u8)] +#[derive(Debug, Eq, PartialEq)] +pub enum WaitingFor { + // The waiter is either not linked into a waiter list or it is linked into a temporary list. + None = 0, + // The waiter is linked into the Mutex's waiter list. + Mutex = 1, + // The waiter is linked into the Condvar's waiter list. + Condvar = 2, +} + +// Internal struct used to keep track of the cancellation function. +struct Cancel { + c: fn(usize, &Waiter, bool) -> bool, + data: usize, +} + +// Represents a thread currently blocked on a Condvar or on acquiring a Mutex. +pub struct Waiter { + link: AtomicLink, + state: SpinLock<State>, + cancel: SpinLock<Cancel>, + kind: Kind, + waiting_for: AtomicU8, +} + +impl Waiter { + // Create a new, initialized Waiter. + // + // `kind` should indicate whether this waiter represent a thread that is waiting for a shared + // lock or an exclusive lock. + // + // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()` + // function) is dropped before it can complete. `cancel_data` is used as the first parameter of + // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third + // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it + // was polled to completion). The `cancel` function should return true if it was able to + // successfully process the cancellation. One reason why a `cancel` function may return false is + // if the `Waiter` was transferred to a different waiter list after the cancel function was + // called but before it was able to run. In this case, it is expected that the new waiter list + // updated the cancel function (by calling `set_cancel`) and the cancellation will be retried by + // fetching and calling the new cancellation function. + // + // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the + // documentation of the `WaitingFor` enum for the meaning of the different values. + pub fn new( + kind: Kind, + cancel: fn(usize, &Waiter, bool) -> bool, + cancel_data: usize, + waiting_for: WaitingFor, + ) -> Waiter { + Waiter { + link: AtomicLink::new(), + state: SpinLock::new(State::Init), + cancel: SpinLock::new(Cancel { + c: cancel, + data: cancel_data, + }), + kind, + waiting_for: AtomicU8::new(waiting_for as u8), + } + } + + // The kind of lock that this `Waiter` is waiting to acquire. + pub fn kind(&self) -> Kind { + self.kind + } + + // Returns true if this `Waiter` is currently linked into a waiter list. + pub fn is_linked(&self) -> bool { + self.link.is_linked() + } + + // Indicates the waiter list to which this `Waiter` belongs. + pub fn is_waiting_for(&self) -> WaitingFor { + match self.waiting_for.load(Ordering::Acquire) { + 0 => WaitingFor::None, + 1 => WaitingFor::Mutex, + 2 => WaitingFor::Condvar, + v => panic!("Unknown value for `WaitingFor`: {}", v), + } + } + + // Change the waiter list to which this `Waiter` belongs. This will panic if called when the + // `Waiter` is still linked into a waiter list. + pub fn set_waiting_for(&self, waiting_for: WaitingFor) { + self.waiting_for.store(waiting_for as u8, Ordering::Release); + } + + // Change the cancellation function that this `Waiter` should use. This will panic if called + // when the `Waiter` is still linked into a waiter list. + pub fn set_cancel(&self, c: fn(usize, &Waiter, bool) -> bool, data: usize) { + debug_assert!( + !self.is_linked(), + "Cannot change cancellation function while linked" + ); + let mut cancel = self.cancel.lock(); + cancel.c = c; + cancel.data = data; + } + + // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a + // waiter list. + pub fn reset(&self, waiting_for: WaitingFor) { + debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked"); + self.set_waiting_for(waiting_for); + + let mut state = self.state.lock(); + if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) { + mem::drop(state); + mem::drop(waker); + } + } + + // Wait until woken up by another thread. + pub fn wait(&self) -> WaitFuture<'_> { + WaitFuture { waiter: self } + } + + // Wake up the thread associated with this `Waiter`. Panics if `waiting_for()` does not return + // `WaitingFor::None` or if `is_linked()` returns true. + pub fn wake(&self) { + debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked"); + debug_assert_eq!(self.is_waiting_for(), WaitingFor::None); + + let mut state = self.state.lock(); + + if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) { + mem::drop(state); + waker.wake(); + } + } +} + +pub struct WaitFuture<'w> { + waiter: &'w Waiter, +} + +impl<'w> Future for WaitFuture<'w> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + let mut state = self.waiter.state.lock(); + + match mem::replace(&mut *state, State::Processing) { + State::Init => { + *state = State::Waiting(cx.waker().clone()); + + Poll::Pending + } + State::Waiting(old_waker) => { + *state = State::Waiting(cx.waker().clone()); + mem::drop(state); + mem::drop(old_waker); + + Poll::Pending + } + State::Woken => { + *state = State::Finished; + Poll::Ready(()) + } + State::Finished => { + panic!("Future polled after returning Poll::Ready"); + } + State::Processing => { + panic!("Unexpected waker state"); + } + } + } +} + +impl<'w> Drop for WaitFuture<'w> { + fn drop(&mut self) { + let state = self.waiter.state.lock(); + + match *state { + State::Finished => {} + State::Processing => panic!("Unexpected waker state"), + State::Woken => { + mem::drop(state); + + // We were woken but not polled. Wake up the next waiter. + let mut success = false; + while !success { + let cancel = self.waiter.cancel.lock(); + let c = cancel.c; + let data = cancel.data; + + mem::drop(cancel); + + success = c(data, self.waiter, true); + } + } + _ => { + mem::drop(state); + + // Not woken. No need to wake up any waiters. + let mut success = false; + while !success { + let cancel = self.waiter.cancel.lock(); + let c = cancel.c; + let data = cancel.data; + + mem::drop(cancel); + + success = c(data, self.waiter, false); + } + } + } + } +} + +intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink }); + +pub type WaiterList = LinkedList<WaiterAdapter>; diff --git a/src/syslog.rs b/src/syslog.rs new file mode 100644 index 0000000..0a12ae2 --- /dev/null +++ b/src/syslog.rs @@ -0,0 +1,73 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::ffi::CStr; + +use log::{self, Level, LevelFilter, Metadata, Record, SetLoggerError}; + +static LOGGER: SyslogLogger = SyslogLogger; + +struct SyslogLogger; + +impl log::Log for SyslogLogger { + fn enabled(&self, metadata: &Metadata) -> bool { + if cfg!(debug_assertions) { + metadata.level() <= Level::Debug + } else { + metadata.level() <= Level::Info + } + } + + fn log(&self, record: &Record) { + if !self.enabled(&record.metadata()) { + return; + } + + let level = match record.level() { + Level::Error => libc::LOG_ERR, + Level::Warn => libc::LOG_WARNING, + Level::Info => libc::LOG_INFO, + Level::Debug => libc::LOG_DEBUG, + Level::Trace => libc::LOG_DEBUG, + }; + + let msg = format!("{}\0", record.args()); + let cmsg = if let Ok(m) = CStr::from_bytes_with_nul(msg.as_bytes()) { + m + } else { + // For now we just drop messages with interior nuls. + return; + }; + + // Safe because this doesn't modify any memory. There's not much use + // in checking the return value because this _is_ the logging function + // so there's no way for us to tell anyone about the error. + unsafe { + libc::syslog(level, cmsg.as_ptr()); + } + } + fn flush(&self) {} +} + +/// Initializes the logger to send log messages to syslog. +pub fn init(ident: &'static CStr) -> Result<(), SetLoggerError> { + // Safe because this only modifies libc's internal state and is safe to call + // multiple times. + unsafe { + libc::openlog( + ident.as_ptr(), + libc::LOG_NDELAY | libc::LOG_PID, + libc::LOG_USER, + ) + }; + log::set_logger(&LOGGER)?; + let level = if cfg!(debug_assertions) { + LevelFilter::Debug + } else { + LevelFilter::Info + }; + log::set_max_level(level); + + Ok(()) +} diff --git a/src/vsock.rs b/src/vsock.rs new file mode 100644 index 0000000..634591e --- /dev/null +++ b/src/vsock.rs @@ -0,0 +1,491 @@ +// Copyright 2018 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Support for virtual sockets. +use std::fmt; +use std::io; +use std::mem::{self, size_of}; +use std::num::ParseIntError; +use std::os::raw::{c_uchar, c_uint, c_ushort}; +use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd}; +use std::result; +use std::str::FromStr; + +use libc::{ + self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, + VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, +}; + +// The domain for vsock sockets. +const AF_VSOCK: sa_family_t = 40; + +// Vsock loopback address. +const VMADDR_CID_LOCAL: c_uint = 1; + +/// Vsock equivalent of binding on port 0. Binds to a random port. +pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value(); + +// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly +// from linux/vm_sockets.h. +const PADDING: usize = size_of::<sockaddr>() + - size_of::<sa_family_t>() + - size_of::<c_ushort>() + - (2 * size_of::<c_uint>()); + +#[repr(C)] +#[derive(Default)] +struct sockaddr_vm { + svm_family: sa_family_t, + svm_reserved1: c_ushort, + svm_port: c_uint, + svm_cid: c_uint, + svm_zero: [c_uchar; PADDING], +} + +#[derive(Debug)] +pub struct AddrParseError; + +impl fmt::Display for AddrParseError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "failed to parse vsock address") + } +} + +/// The vsock equivalent of an IP address. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub enum VsockCid { + /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint. + Any, + /// An address that refers to the bare-metal machine that serves as the hypervisor. + Hypervisor, + /// The loopback address. + Local, + /// The parent machine. It may not be the hypervisor for nested VMs. + Host, + /// An assigned CID that serves as the address for VSOCK. + Cid(c_uint), +} + +impl fmt::Display for VsockCid { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match &self { + VsockCid::Any => write!(fmt, "Any"), + VsockCid::Hypervisor => write!(fmt, "Hypervisor"), + VsockCid::Local => write!(fmt, "Local"), + VsockCid::Host => write!(fmt, "Host"), + VsockCid::Cid(c) => write!(fmt, "'{}'", c), + } + } +} + +impl From<c_uint> for VsockCid { + fn from(c: c_uint) -> Self { + match c { + VMADDR_CID_ANY => VsockCid::Any, + VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor, + VMADDR_CID_LOCAL => VsockCid::Local, + VMADDR_CID_HOST => VsockCid::Host, + _ => VsockCid::Cid(c), + } + } +} + +impl FromStr for VsockCid { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let c: c_uint = s.parse()?; + Ok(c.into()) + } +} + +impl Into<c_uint> for VsockCid { + fn into(self) -> c_uint { + match self { + VsockCid::Any => VMADDR_CID_ANY, + VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR, + VsockCid::Local => VMADDR_CID_LOCAL, + VsockCid::Host => VMADDR_CID_HOST, + VsockCid::Cid(c) => c, + } + } +} + +/// An address associated with a virtual socket. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub struct SocketAddr { + pub cid: VsockCid, + pub port: c_uint, +} + +pub trait ToSocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>; +} + +impl ToSocketAddr for SocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + Ok(*self) + } +} + +impl ToSocketAddr for str { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + self.parse() + } +} + +impl ToSocketAddr for (VsockCid, c_uint) { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + let (cid, port) = *self; + Ok(SocketAddr { cid, port }) + } +} + +impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + (**self).to_socket_addr() + } +} + +impl FromStr for SocketAddr { + type Err = AddrParseError; + + /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form + /// "vsock:cid:port". + fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> { + let components: Vec<&str> = s.split(':').collect(); + if components.len() != 3 || components[0] != "vsock" { + return Err(AddrParseError); + } + + Ok(SocketAddr { + cid: components[1].parse().map_err(|_| AddrParseError)?, + port: components[2].parse().map_err(|_| AddrParseError)?, + }) + } +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}:{}", self.cid, self.port) + } +} + +/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the +/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets. +unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> { + let flags = libc::fcntl(fd, F_GETFL, 0); + if flags < 0 { + return Err(io::Error::last_os_error()); + } + + let flags = if nonblocking { + flags | O_NONBLOCK + } else { + flags & !O_NONBLOCK + }; + + let ret = libc::fcntl(fd, F_SETFL, flags); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +/// A virtual socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either VsockStream or +/// VsockListener. +#[derive(Debug)] +pub struct VsockSocket { + fd: RawFd, +} + +impl VsockSocket { + pub fn new() -> io::Result<Self> { + let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(VsockSocket { fd }) + } + } + + pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + // The compiler should optimize this out since these are both compile-time constants. + assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>()); + + let mut svm: sockaddr_vm = Default::default(); + svm.svm_family = AF_VSOCK; + svm.svm_cid = sockaddr.cid.into(); + svm.svm_port = sockaddr.port; + + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::bind( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + let mut svm: sockaddr_vm = Default::default(); + svm.svm_family = AF_VSOCK; + svm.svm_cid = sockaddr.cid.into(); + svm.svm_port = sockaddr.port; + + // Safe because this just connects a vsock socket, and the return value is checked. + let ret = unsafe { + libc::connect( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + Ok(VsockStream { sock: self }) + } + } + + pub fn listen(self) -> io::Result<VsockListener> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + return Err(listen_err); + } + Ok(VsockListener { sock: self }) + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result<u32> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_vm>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::<sockaddr_vm>()); + + Ok(svm.svm_port) + } + } + + pub fn try_clone(&self) -> io::Result<Self> { + // Safe because this doesn't modify any memory and we check the return value. + let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) }; + if dup_fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { fd: dup_fd }) + } + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + // Safe because the fd is valid and owned by this stream. + unsafe { set_nonblocking(self.fd, nonblocking) } + } +} + +impl IntoRawFd for VsockSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for VsockSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for VsockSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + +/// A virtual stream socket. +#[derive(Debug)] +pub struct VsockStream { + sock: VsockSocket, +} + +impl VsockStream { + pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> { + let sock = VsockSocket::new()?; + sock.connect(addr) + } + + /// Returns the port that this stream is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + pub fn try_clone(&self) -> io::Result<VsockStream> { + self.sock.try_clone().map(|f| VsockStream { sock: f }) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl io::Read for VsockStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + // Safe because this will only modify the contents of |buf| and we check the return value. + let ret = unsafe { + libc::read( + self.sock.as_raw_fd(), + buf as *mut [u8] as *mut c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } +} + +impl io::Write for VsockStream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::write( + self.sock.as_raw_fd(), + buf as *const [u8] as *const c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } + + fn flush(&mut self) -> io::Result<()> { + // No buffered data so nothing to do. + Ok(()) + } +} + +impl AsRawFd for VsockStream { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +impl IntoRawFd for VsockStream { + fn into_raw_fd(self) -> RawFd { + self.sock.into_raw_fd() + } +} + +/// Represents a virtual socket server. +#[derive(Debug)] +pub struct VsockListener { + sock: VsockSocket, +} + +impl VsockListener { + /// Creates a new `VsockListener` bound to the specified port on the current virtual socket + /// endpoint. + pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> { + let mut sock = VsockSocket::new()?; + sock.bind(addr)?; + sock.listen() + } + + /// Returns the port that this listener is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + /// Accepts a new incoming connection on this listener. Blocks the calling thread until a + /// new connection is established. When established, returns the corresponding `VsockStream` + /// and the remote peer's address. + pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because this will only modify |svm| and we check the return value. + let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t; + let fd = unsafe { + libc::accept4( + self.sock.as_raw_fd(), + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut socklen as *mut socklen_t, + libc::SOCK_CLOEXEC, + ) + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + if svm.svm_family != AF_VSOCK { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected address family: {}", svm.svm_family), + )); + } + + Ok(( + VsockStream { + sock: VsockSocket { fd }, + }, + SocketAddr { + cid: svm.svm_cid.into(), + port: svm.svm_port, + }, + )) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl AsRawFd for VsockListener { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} |