diff options
author | Jorge E. Moreira <jemoreira@google.com> | 2021-04-21 20:32:53 -0700 |
---|---|---|
committer | Jorge E. Moreira <jemoreira@google.com> | 2021-04-21 20:32:53 -0700 |
commit | 67c7636ec8364d9d92a96bceae6441ba01461eb2 (patch) | |
tree | 93c4786948168cfe49e3f47fa3ed94924d3b2e99 | |
parent | 38a78326ce3cfe429a8344297abfb162d47e6204 (diff) | |
parent | 5be4f273e87bc55dfc1ed3f4a6126f4d9f02e797 (diff) | |
download | crosvm-67c7636ec8364d9d92a96bceae6441ba01461eb2.tar.gz |
Merge remote-tracking branch 'aosp/upstream-main'
Bug: 185155959
Test: locally with following change
Change-Id: I9580972149384e197e57abb09d480d8997f527e5
294 files changed, 14679 insertions, 5797 deletions
diff --git a/.gitignore b/.gitignore index 1acdbc7a0..bfe747156 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ target/ !/Cargo.lock lcov.info .idea +.vscode diff --git a/Android.bp b/Android.bp index 3c8a620a8..5ea3df5d5 100644 --- a/Android.bp +++ b/Android.bp @@ -85,7 +85,6 @@ rust_binary { name: "crosvm", defaults: ["crosvm_defaults"], host_supported: true, - prefer_rlib: true, crate_name: "crosvm", srcs: ["src/main.rs"], @@ -94,15 +93,7 @@ rust_binary { relative_install_path: "aarch64-linux-bionic", }, linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], relative_install_path: "x86_64-linux-gnu", - rustlibs: [ - "libgdbstub", - "libthiserror", - ], }, darwin: { enabled: false, @@ -183,18 +174,6 @@ rust_defaults { rustlibs: ["libaarch64"], }, }, - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - "libthiserror", - ], - }, - }, features: [ "default", ], @@ -342,16 +321,6 @@ rust_library { "gfxstream", ], }, - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - "libthiserror", - ], - }, }, arch: { x86_64: { diff --git a/Cargo.lock b/Cargo.lock index 9d9e76995..87cdedda2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -53,7 +53,6 @@ dependencies = [ "kernel_cmdline", "libc", "minijail", - "msg_socket", "power_monitor", "resources", "sync", @@ -93,6 +92,12 @@ dependencies = [ [[package]] name = "autocfg" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2" + +[[package]] +name = "autocfg" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" @@ -104,8 +109,12 @@ dependencies = [ "cros_async", "data_model", "libc", + "serde", + "serde_json", + "smallvec", "sync", "sys_util", + "thiserror", ] [[package]] @@ -143,6 +152,15 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" [[package]] +name = "cloudabi" +version = "0.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f" +dependencies = [ + "bitflags", +] + +[[package]] name = "cras-sys" version = "0.1.0" dependencies = [ @@ -158,6 +176,7 @@ dependencies = [ "async-trait", "data_model", "futures", + "intrusive-collections", "io_uring", "libc", "paste", @@ -165,11 +184,17 @@ dependencies = [ "slab", "sync", "sys_util", - "syscall_defines", "thiserror", ] [[package]] +name = "cros_fuzz" +version = "0.1.0" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] name = "crosvm" version = "0.1.0" dependencies = [ @@ -194,7 +219,6 @@ dependencies = [ "libc", "libcras", "minijail", - "msg_socket", "net_util", "p9", "protobuf", @@ -213,6 +237,24 @@ dependencies = [ ] [[package]] +name = "crosvm-fuzz" +version = "0.0.1" +dependencies = [ + "base", + "cros_fuzz", + "data_model", + "devices", + "disk", + "fuse", + "kernel_loader", + "libc", + "rand", + "tempfile", + "usb_util", + "vm_memory", +] + +[[package]] name = "crosvm_plugin" version = "0.17.0" dependencies = [ @@ -261,13 +303,10 @@ dependencies = [ "hypervisor", "kvm_sys", "libc", - "libchromeos", "libcras", "libvda", "linux_input_sys", "minijail", - "msg_on_socket_derive", - "msg_socket", "net_sys", "net_util", "p9", @@ -277,9 +316,10 @@ dependencies = [ "remain", "resources", "rutabaga_gfx", + "serde", + "smallvec", "sync", "sys_util", - "syscall_defines", "tempfile", "thiserror", "tpm2", @@ -289,6 +329,7 @@ dependencies = [ "virtio_sys", "vm_control", "vm_memory", + "vmm_vhost", ] [[package]] @@ -324,6 +365,12 @@ dependencies = [ ] [[package]] +name = "fuchsia-cprng" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba" + +[[package]] name = "fuse" version = "0.1.0" dependencies = [ @@ -343,7 +390,6 @@ checksum = "b6f16056ecbb57525ff698bb955162d0cd03bee84e6241c27ff75c08d8ca5987" dependencies = [ "futures-channel", "futures-core", - "futures-executor", "futures-io", "futures-sink", "futures-task", @@ -367,35 +413,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "79564c427afefab1dfb3298535b21eda083ef7935b4f0ecbfcb121f0aec10866" [[package]] -name = "futures-executor" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e274736563f686a837a0568b478bdabfeaec2dca794b5649b04e2fe1627c231" -dependencies = [ - "futures-core", - "futures-task", - "futures-util", -] - -[[package]] name = "futures-io" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e676577d229e70952ab25f3945795ba5b16d63ca794ca9d2c860e5595d20b5ff" [[package]] -name = "futures-macro" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e7c56c15537adb4f76d0b7a76ad131cb4d2f4f32d3b0bcabcbe1c7c5e87764" -dependencies = [ - "proc-macro-hack", - "proc-macro2", - "quote", - "syn", -] - -[[package]] name = "futures-sink" version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -416,13 +439,10 @@ dependencies = [ "futures-channel", "futures-core", "futures-io", - "futures-macro", "futures-sink", "futures-task", "memchr", "pin-utils", - "proc-macro-hack", - "proc-macro-nested", "slab", ] @@ -471,7 +491,7 @@ dependencies = [ "kvm", "kvm_sys", "libc", - "msg_socket", + "serde", "sync", "vm_memory", ] @@ -505,10 +525,15 @@ dependencies = [ "libc", "sync", "sys_util", - "syscall_defines", ] [[package]] +name = "itoa" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" + 
+[[package]] name = "kernel_cmdline" version = "0.1.0" dependencies = [ @@ -533,7 +558,6 @@ dependencies = [ "data_model", "kvm_sys", "libc", - "msg_socket", "sync", "vm_memory", ] @@ -549,9 +573,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.81" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1482821306169ec4d07f6aca392a4681f66c75c9918aa49641a2595db64053cb" +checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41" [[package]] name = "libchromeos" @@ -559,11 +583,12 @@ version = "0.1.0" dependencies = [ "data_model", "futures", - "intrusive-collections", "libc", "log", "protobuf", + "sys_util", "thiserror", + "zeroize", ] [[package]] @@ -578,6 +603,15 @@ dependencies = [ ] [[package]] +name = "libcrosvm_control" +version = "0.1.0" +dependencies = [ + "base", + "libc", + "vm_control", +] + +[[package]] name = "libdbus-sys" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -631,7 +665,7 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa" dependencies = [ - "autocfg", + "autocfg 1.0.1", ] [[package]] @@ -651,29 +685,6 @@ dependencies = [ ] [[package]] -name = "msg_on_socket_derive" -version = "0.1.0" -dependencies = [ - "base", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "msg_socket" -version = "0.1.0" -dependencies = [ - "base", - "cros_async", - "data_model", - "futures", - "libc", - "msg_on_socket_derive", - "sync", -] - -[[package]] name = "net_sys" version = "0.1.0" dependencies = [ @@ -696,7 +707,7 @@ version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" dependencies = [ - "autocfg", + "autocfg 1.0.1", ] [[package]] @@ -755,23 +766,6 @@ dependencies = [ ] [[package]] -name = "proc-macro-hack" -version = 
"0.5.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ecd45702f76d6d3c75a80564378ae228a85f0b59d2f3ed43c91b4a69eb2ebfc5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "proc-macro-nested" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "369a6ed065f249a159e06c45752c780bda2fb53c995718f9e484d08daa9eb42e" - -[[package]] name = "proc-macro2" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -845,10 +839,125 @@ dependencies = [ ] [[package]] +name = "rand" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d71dacdc3c88c1fde3885a3be3fbab9f35724e6ce99467f7d9c5026132184ca" +dependencies = [ + "autocfg 0.1.7", + "libc", + "rand_chacha", + "rand_core 0.4.2", + "rand_hc", + "rand_isaac", + "rand_jitter", + "rand_os", + "rand_pcg", + "rand_xorshift", + "winapi", +] + +[[package]] +name = "rand_chacha" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "556d3a1ca6600bfcbab7c7c91ccb085ac7fbbcd70e008a98742e7847f4f7bcef" +dependencies = [ + "autocfg 0.1.7", + "rand_core 0.3.1", +] + +[[package]] +name = "rand_core" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b" +dependencies = [ + "rand_core 0.4.2", +] + +[[package]] +name = "rand_core" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc" + +[[package]] +name = "rand_hc" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b40677c7be09ae76218dc623efbf7b18e34bced3f38883af07bb75630a21bc4" +dependencies = [ + "rand_core 0.3.1", +] + +[[package]] +name = "rand_isaac" +version = "0.1.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded997c9d5f13925be2a6fd7e66bf1872597f759fd9dd93513dd7e92e5a5ee08" +dependencies = [ + "rand_core 0.3.1", +] + +[[package]] name = "rand_ish" version = "0.1.0" [[package]] +name = "rand_jitter" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1166d5c91dc97b88d1decc3285bb0a99ed84b05cfd0bc2341bdf2d43fc41e39b" +dependencies = [ + "libc", + "rand_core 0.4.2", + "winapi", +] + +[[package]] +name = "rand_os" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b75f676a1e053fc562eafbb47838d67c84801e38fc1ba459e8f180deabd5071" +dependencies = [ + "cloudabi", + "fuchsia-cprng", + "libc", + "rand_core 0.4.2", + "rdrand", + "winapi", +] + +[[package]] +name = "rand_pcg" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abf9b09b01790cfe0364f52bf32995ea3c39f4d2dd011eac241d2914146d0b44" +dependencies = [ + "autocfg 0.1.7", + "rand_core 0.4.2", +] + +[[package]] +name = "rand_xorshift" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf7e9e623549b0e21f6e97cf8ecf247c1a8fd2e8a992ae265314300b2455d5c" +dependencies = [ + "rand_core 0.3.1", +] + +[[package]] +name = "rdrand" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2" +dependencies = [ + "rand_core 0.3.1", +] + +[[package]] name = "remain" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -865,7 +974,7 @@ version = "0.1.0" dependencies = [ "base", "libc", - "msg_socket", + "serde", ] [[package]] @@ -879,6 +988,12 @@ dependencies = [ ] [[package]] +name = "ryu" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" + 
+[[package]] name = "serde" version = "1.0.121" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -899,12 +1014,29 @@ dependencies = [ ] [[package]] +name = "serde_json" +version = "1.0.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] name = "slab" version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8" [[package]] +name = "smallvec" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" + +[[package]] name = "syn" version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -920,6 +1052,18 @@ name = "sync" version = "0.1.0" [[package]] +name = "synstructure" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "unicode-xid", +] + +[[package]] name = "sys_util" version = "0.1.0" dependencies = [ @@ -927,16 +1071,13 @@ dependencies = [ "data_model", "libc", "poll_token_derive", + "serde", + "serde_json", "sync", - "syscall_defines", "tempfile", ] [[package]] -name = "syscall_defines" -version = "0.1.0" - -[[package]] name = "tempfile" version = "3.0.7" dependencies = [ @@ -1044,9 +1185,9 @@ dependencies = [ "gdbstub", "hypervisor", "libc", - "msg_socket", "resources", "rutabaga_gfx", + "serde", "sync", "vm_memory", ] @@ -1056,13 +1197,45 @@ name = "vm_memory" version = "0.1.0" dependencies = [ "base", + "bitflags", "cros_async", "data_model", "libc", - "syscall_defines", ] [[package]] +name = "vmm_vhost" +version = "0.1.0" +dependencies = [ + "bitflags", + "libc", + "sys_util", + 
"tempfile", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] name = "wire_format_derive" version = "0.1.0" dependencies = [ @@ -1087,10 +1260,30 @@ dependencies = [ "kernel_loader", "libc", "minijail", - "msg_socket", "remain", "resources", "sync", "vm_control", "vm_memory", ] + +[[package]] +name = "zeroize" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81a974bcdd357f0dca4d41677db03436324d45a4c9ed2d0b873a5a360ce41c36" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] diff --git a/Cargo.toml b/Cargo.toml index 5ff22bde9..8c969aca8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ name = "crosvm" version = "0.1.0" authors = ["The Chromium OS Authors"] edition = "2018" +default-run = "crosvm" [lib] path = "src/crosvm.rs" @@ -11,14 +12,21 @@ path = "src/crosvm.rs" name = "crosvm" path = "src/main.rs" +[[bin]] +name = "crosvm-direct" +path = "src/main.rs" +required-features = [ "direct" ] + [profile.release] panic = 'abort' overflow-checks = true [workspace] members = [ + 
"fuzz", "qcow_utils", "integration_tests", + "libcrosvm_control", ] exclude = [ "assertions", @@ -28,7 +36,6 @@ exclude = [ "rand_ish", "sync", "sys_util", - "syscall_defines", "tempfile", "vm_memory", ] @@ -37,6 +44,7 @@ exclude = [ default = ["audio", "gpu"] chromeos = ["base/chromeos"] default-no-sandbox = [] +direct = ["devices/direct"] audio = ["devices/audio"] gpu = ["devices/gpu"] plugin = ["protos/plugin", "crosvm_plugin", "kvm", "kvm_sys", "protobuf"] @@ -70,10 +78,9 @@ kernel_cmdline = { path = "kernel_cmdline" } kernel_loader = { path = "kernel_loader" } kvm = { path = "kvm", optional = true } kvm_sys = { path = "kvm_sys", optional = true } -libc = "0.2.65" +libc = "0.2.93" libcras = "*" minijail = "*" # provided by ebuild -msg_socket = { path = "msg_socket" } net_util = { path = "net_util" } p9 = { path = "../vm_tools/p9" } protobuf = { version = "2.3", optional = true } @@ -102,13 +109,14 @@ base = "*" assertions = { path = "assertions" } audio_streams = { path = "../adhd/audio_streams" } # ignored by ebuild base = { path = "base" } +cros_fuzz = { path = "../../platform2/cros-fuzz" } # ignored by ebuild data_model = { path = "data_model" } libchromeos = { path = "../libchromeos-rs" } # ignored by ebuild libcras = { path = "../adhd/cras/client/libcras" } # ignored by ebuild minijail = { path = "../minijail/rust/minijail" } # ignored by ebuild p9 = { path = "../vm_tools/p9" } # ignored by ebuild sync = { path = "sync" } -syscall_defines = { path = "syscall_defines" } sys_util = { path = "sys_util" } tempfile = { path = "tempfile" } wire_format_derive = { path = "../vm_tools/p9/wire_format_derive" } # ignored by ebuild +vmm_vhost = { path = "../rust/crates/vhost", features = ["vhost-user-master"] } # ignored by ebuild @@ -1,8 +1,5 @@ adelva@google.com chirantan@google.com dgreid@google.com -smbarber@chromium.org -zachr@chromium.org - -# So any team members can +2 -* +dverkamp@google.com +zachr@google.com diff --git a/aarch64/src/lib.rs 
b/aarch64/src/lib.rs index 4acef3dc2..4a3a35ecf 100644 --- a/aarch64/src/lib.rs +++ b/aarch64/src/lib.rs @@ -4,10 +4,8 @@ use std::collections::BTreeMap; use std::error::Error as StdError; -use std::ffi::{CStr, CString}; use std::fmt::{self, Display}; -use std::fs::File; -use std::io::{self, Seek}; +use std::io::{self}; use std::sync::Arc; use arch::{ @@ -15,13 +13,8 @@ use arch::{ VmComponents, VmImage, }; use base::Event; -use devices::{ - Bus, BusError, IrqChip, IrqChipAArch64, PciAddress, PciConfigMmio, PciDevice, PciInterruptPin, - ProtectionType, -}; -use hypervisor::{ - DeviceKind, Hypervisor, HypervisorCap, PsciVersion, VcpuAArch64, VcpuFeature, VmAArch64, -}; +use devices::{Bus, BusError, IrqChip, IrqChipAArch64, PciConfigMmio, PciDevice, ProtectionType}; +use hypervisor::{DeviceKind, Hypervisor, HypervisorCap, VcpuAArch64, VcpuFeature, VmAArch64}; use minijail::Minijail; use remain::sorted; use resources::SystemAllocator; @@ -222,13 +215,19 @@ pub struct AArch64; impl arch::LinuxArch for AArch64 { type Error = Error; - fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>( + fn guest_memory_layout( + components: &VmComponents, + ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error> { + Ok(arch_memory_regions(components.memory_size)) + } + + fn build_vm<V, Vcpu, I, FD, FI, E1, E2>( mut components: VmComponents, serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>, serial_jail: Option<Minijail>, _battery: (&Option<BatteryType>, Option<Minijail>), + mut vm: V, create_devices: FD, - create_vm: FV, create_irq_chip: FI, ) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error> where @@ -241,20 +240,17 @@ impl arch::LinuxArch for AArch64 { &mut SystemAllocator, &Event, ) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>, - FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>, - FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>, + FI: FnOnce(&V, /* vcpu_count: */ usize) -> 
std::result::Result<I, E2>, E1: StdError + 'static, E2: StdError + 'static, - E3: StdError + 'static, { let has_bios = match components.vm_image { VmImage::Bios(_) => true, _ => false, }; + let mem = vm.get_memory().clone(); let mut resources = Self::get_resource_allocator(components.memory_size); - let mem = Self::setup_memory(components.memory_size)?; - let mut vm = create_vm(mem.clone()).map_err(|e| Error::CreateVm(Box::new(e)))?; if components.protected_vm == ProtectionType::Protected { vm.enable_protected_vm( @@ -270,7 +266,7 @@ impl arch::LinuxArch for AArch64 { let vcpu_count = components.vcpu_count; let mut vcpus = Vec::with_capacity(vcpu_count); for vcpu_id in 0..vcpu_count { - let vcpu = *vm + let vcpu: Vcpu = *vm .create_vcpu(vcpu_id) .map_err(Error::CreateVcpu)? .downcast::<Vcpu>() @@ -429,12 +425,6 @@ impl arch::LinuxArch for AArch64 { } impl AArch64 { - fn setup_memory(mem_size: u64) -> Result<GuestMemory> { - let arch_mem_regions = arch_memory_regions(mem_size); - let mem = GuestMemory::new(&arch_mem_regions).map_err(Error::SetupGuestMemory)?; - Ok(mem) - } - fn get_high_mmio_base_size(mem_size: u64) -> (u64, u64) { let base = AARCH64_PHYS_MEM_START + mem_size; let size = u64::max_value() - base; @@ -507,31 +497,27 @@ impl AArch64 { } vcpu.init(&features).map_err(Error::VcpuInit)?; - // set up registers - let mut data: u64; - let mut reg_id: u64; - // All interrupts masked - data = PSR_D_BIT | PSR_A_BIT | PSR_I_BIT | PSR_F_BIT | PSR_MODE_EL1H; - reg_id = arm64_core_reg!(pstate); - vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?; + let pstate = PSR_D_BIT | PSR_A_BIT | PSR_I_BIT | PSR_F_BIT | PSR_MODE_EL1H; + vcpu.set_one_reg(arm64_core_reg!(pstate), pstate) + .map_err(Error::SetReg)?; // Other cpus are powered off initially if vcpu_id == 0 { - if has_bios { - data = AARCH64_PHYS_MEM_START + AARCH64_BIOS_OFFSET; + let entry_addr = if has_bios { + AARCH64_PHYS_MEM_START + AARCH64_BIOS_OFFSET } else { - data = AARCH64_PHYS_MEM_START + 
AARCH64_KERNEL_OFFSET; - } - reg_id = arm64_core_reg!(pc); - vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?; + AARCH64_PHYS_MEM_START + AARCH64_KERNEL_OFFSET + }; + vcpu.set_one_reg(arm64_core_reg!(pc), entry_addr) + .map_err(Error::SetReg)?; /* X0 -- fdt address */ let mem_size = guest_mem.memory_size(); - data = (AARCH64_PHYS_MEM_START + fdt_offset(mem_size, has_bios)) as u64; + let fdt_addr = (AARCH64_PHYS_MEM_START + fdt_offset(mem_size, has_bios)) as u64; // hack -- can't get this to do offsetof(regs[0]) but luckily it's at offset 0 - reg_id = arm64_core_reg!(regs); - vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?; + vcpu.set_one_reg(arm64_core_reg!(regs), fdt_addr) + .map_err(Error::SetReg)?; } Ok(()) diff --git a/arch/Android.bp b/arch/Android.bp index 5e5c40ddd..943e0750e 100644 --- a/arch/Android.bp +++ b/arch/Android.bp @@ -1,5 +1,4 @@ -// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb. -// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this +// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace. 
package { // See: http://go/android-license-faq @@ -18,17 +17,6 @@ rust_defaults { test_suites: ["general-tests"], auto_gen_config: true, edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - ], - }, - }, rustlibs: [ "libacpi_tables", "libbase_rust", @@ -70,17 +58,6 @@ rust_library { crate_name: "arch", srcs: ["src/lib.rs"], edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - ], - }, - }, rustlibs: [ "libacpi_tables", "libbase_rust", @@ -149,11 +126,10 @@ rust_library { // ../vm_control/src/lib.rs // ../vm_memory/src/lib.rs // async-task-4.0.3 "default,std" -// async-trait-0.1.48 +// async-trait-0.1.45 // autocfg-1.0.1 // base-0.1.0 // bitflags-1.2.1 "default" -// cfg-if-0.1.10 // cfg-if-1.0.0 // downcast-rs-1.2.0 "default,std" // futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std" @@ -165,15 +141,12 @@ rust_library { // futures-sink-0.3.13 "alloc,std" // futures-task-0.3.13 "alloc,std" // futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std" -// gdbstub-0.4.4 "alloc,default,std" // getrandom-0.2.2 "std" // intrusive-collections-0.9.0 "alloc,default" -// libc-0.2.88 "default,std" +// libc-0.2.87 "default,std" // log-0.4.14 -// managed-0.8.0 "alloc" // memchr-2.3.4 "default,std" // memoffset-0.5.6 "default" -// num-traits-0.2.14 // paste-1.0.4 // pin-project-lite-0.2.6 // pin-utils-0.1.0 @@ -189,10 +162,10 @@ rust_library { // rand_core-0.6.2 "alloc,getrandom,std" // remain-0.2.2 // remove_dir_all-0.5.3 -// serde-1.0.124 "default,derive,serde_derive,std" -// serde_derive-1.0.124 "default" +// serde-1.0.123 "default,derive,serde_derive,std" +// serde_derive-1.0.123 "default" // slab-0.4.2 -// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" 
+// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" // tempfile-3.2.0 // thiserror-1.0.24 // thiserror-impl-1.0.24 diff --git a/arch/Cargo.toml b/arch/Cargo.toml index b72d11741..095cc480c 100644 --- a/arch/Cargo.toml +++ b/arch/Cargo.toml @@ -16,7 +16,6 @@ hypervisor = { path = "../hypervisor" } kernel_cmdline = { path = "../kernel_cmdline" } libc = "*" minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild -msg_socket = { path = "../msg_socket" } resources = { path = "../resources" } sync = { path = "../sync" } base = { path = "../base" } diff --git a/arch/src/fdt.rs b/arch/src/fdt.rs index f5c82ddd6..944e5025a 100644 --- a/arch/src/fdt.rs +++ b/arch/src/fdt.rs @@ -3,7 +3,7 @@ // found in the LICENSE file. //! This module writes Flattened Devicetree blobs as defined here: -//! https://devicetree-specification.readthedocs.io/en/stable/flattened-format.html +//! <https://devicetree-specification.readthedocs.io/en/stable/flattened-format.html> use std::collections::BTreeMap; use std::convert::TryInto; @@ -48,7 +48,7 @@ const FDT_END_NODE: u32 = 0x00000002; const FDT_PROP: u32 = 0x00000003; const FDT_END: u32 = 0x00000009; -/// Interface for writing a Flattened Devicetree (FDT) and emitting a Devicetree Blob (FDT). +/// Interface for writing a Flattened Devicetree (FDT) and emitting a Devicetree Blob (DTB). 
/// /// # Example /// diff --git a/arch/src/lib.rs b/arch/src/lib.rs index c59a2cd13..5070e728c 100644 --- a/arch/src/lib.rs +++ b/arch/src/lib.rs @@ -17,7 +17,7 @@ use std::sync::Arc; use acpi_tables::aml::Aml; use acpi_tables::sdt::SDT; -use base::{syslog, AsRawDescriptor, Event}; +use base::{syslog, AsRawDescriptor, Event, Tube}; use devices::virtio::VirtioDevice; use devices::{ Bus, BusDevice, BusError, IrqChip, PciAddress, PciDevice, PciDeviceError, PciInterruptPin, @@ -27,11 +27,7 @@ use hypervisor::{IoEventAddress, Vm}; use minijail::Minijail; use resources::{MmioType, SystemAllocator}; use sync::Mutex; -#[cfg(all(target_arch = "x86_64", feature = "gdb"))] -use vm_control::VmControlRequestSocket; -use vm_control::{ - BatControl, BatControlCommand, BatControlRequestSocket, BatControlResult, BatteryType, -}; +use vm_control::{BatControl, BatteryType}; use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError}; #[cfg(all(target_arch = "x86_64", feature = "gdb"))] @@ -83,6 +79,7 @@ pub struct VmComponents { pub vcpu_count: usize, pub vcpu_affinity: Option<VcpuAffinity>, pub no_smt: bool, + pub hugepages: bool, pub vm_image: VmImage, pub android_fstab: Option<File>, pub pstore: Option<Pstore>, @@ -93,7 +90,8 @@ pub struct VmComponents { pub rt_cpus: Vec<usize>, pub protected_vm: ProtectionType, #[cfg(all(target_arch = "x86_64", feature = "gdb"))] - pub gdb: Option<(u32, VmControlRequestSocket)>, // port and control socket. + pub gdb: Option<(u32, Tube)>, // port and control tube. + pub dmi_path: Option<PathBuf>, } /// Holds the elements needed to run a Linux VM. Created by `build_vm`. @@ -116,7 +114,7 @@ pub struct RunnableLinuxVm<V: VmArch, Vcpu: VcpuArch, I: IrqChipArch> { pub rt_cpus: Vec<usize>, pub bat_control: Option<BatControl>, #[cfg(all(target_arch = "x86_64", feature = "gdb"))] - pub gdb: Option<(u32, VmControlRequestSocket)>, + pub gdb: Option<(u32, Tube)>, } /// The device and optional jail. 
@@ -130,6 +128,16 @@ pub struct VirtioDeviceStub { pub trait LinuxArch { type Error: StdError; + /// Returns a Vec of the valid memory addresses as pairs of address and length. These should be + /// used to configure the `GuestMemory` structure for the platform. + /// + /// # Arguments + /// + /// * `components` - Parts used to determine the memory layout. + fn guest_memory_layout( + components: &VmComponents, + ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error>; + /// Takes `VmComponents` and generates a `RunnableLinuxVm`. /// /// # Arguments @@ -138,15 +146,14 @@ pub trait LinuxArch { /// * `serial_parameters` - definitions for how the serial devices should be configured. /// * `battery` - defines what battery device will be created. /// * `create_devices` - Function to generate a list of devices. - /// * `create_vm` - Function to generate a VM. /// * `create_irq_chip` - Function to generate an IRQ chip. - fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>( + fn build_vm<V, Vcpu, I, FD, FI, E1, E2>( components: VmComponents, serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>, serial_jail: Option<Minijail>, battery: (&Option<BatteryType>, Option<Minijail>), + vm: V, create_devices: FD, - create_vm: FV, create_irq_chip: FI, ) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error> where @@ -159,11 +166,9 @@ pub trait LinuxArch { &mut SystemAllocator, &Event, ) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>, - FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>, - FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>, + FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E2>, E1: StdError + 'static, - E2: StdError + 'static, - E3: StdError + 'static; + E2: StdError + 'static; /// Configures the vcpu and should be called once per vcpu from the vcpu's thread. 
/// @@ -240,8 +245,8 @@ pub enum DeviceRegistrationError { CreatePipe(base::Error), // Unable to create serial device from serial parameters CreateSerialDevice(serial::Error), - // Unable to create socket - CreateSocket(io::Error), + // Unable to create tube + CreateTube(base::TubeError), /// Could not clone an event. EventClone(base::Error), /// Could not create an event. @@ -279,7 +284,7 @@ impl Display for DeviceRegistrationError { AllocateIrq => write!(f, "Allocating IRQ number"), CreatePipe(e) => write!(f, "failed to create pipe: {}", e), CreateSerialDevice(e) => write!(f, "failed to create serial device: {}", e), - CreateSocket(e) => write!(f, "failed to create socket: {}", e), + CreateTube(e) => write!(f, "failed to create tube: {}", e), Cmdline(e) => write!(f, "unable to add device to kernel command line: {}", e), EventClone(e) => write!(f, "failed to clone event: {}", e), EventCreate(e) => write!(f, "failed to create event: {}", e), @@ -434,7 +439,7 @@ pub fn add_goldfish_battery( irq_chip: &mut impl IrqChip, irq_num: u32, resources: &mut SystemAllocator, -) -> Result<BatControlRequestSocket, DeviceRegistrationError> { +) -> Result<Tube, DeviceRegistrationError> { let alloc = resources.get_anon_alloc(); let mmio_base = resources .mmio_allocator(MmioType::Low) @@ -453,9 +458,8 @@ pub fn add_goldfish_battery( .register_irq_event(irq_num, &irq_evt, Some(&irq_resample_evt)) .map_err(DeviceRegistrationError::RegisterIrqfd)?; - let (control_socket, response_socket) = - msg_socket::pair::<BatControlCommand, BatControlResult>() - .map_err(DeviceRegistrationError::CreateSocket)?; + let (control_tube, response_tube) = + Tube::pair().map_err(DeviceRegistrationError::CreateTube)?; #[cfg(feature = "power-monitor-powerd")] let create_monitor = Some(Box::new(power_monitor::powerd::DBusMonitor::connect) @@ -469,7 +473,7 @@ pub fn add_goldfish_battery( irq_num, irq_evt, irq_resample_evt, - response_socket, + response_tube, create_monitor, ) 
.map_err(DeviceRegistrationError::RegisterBattery)?; @@ -501,7 +505,7 @@ pub fn add_goldfish_battery( } } - Ok(control_socket) + Ok(control_tube) } /// Errors for image loading. diff --git a/arch/src/pstore.rs b/arch/src/pstore.rs index 3a12aab8c..9335a1d46 100644 --- a/arch/src/pstore.rs +++ b/arch/src/pstore.rs @@ -63,7 +63,7 @@ pub fn create_memory_region( .map_err(Error::ResourcesError)?; let memory_mapping = MemoryMappingBuilder::new(pstore.size as usize) - .from_descriptor(&file) + .from_file(&file) .build() .map_err(Error::MmapError)?; diff --git a/base/Cargo.toml b/base/Cargo.toml index 2520d4256..aecbb3eaa 100644 --- a/base/Cargo.toml +++ b/base/Cargo.toml @@ -11,5 +11,9 @@ chromeos = ["sys_util/chromeos"] cros_async = { path = "../cros_async" } data_model = { path = "../data_model" } libc = "*" +serde = { version = "1", features = [ "derive" ] } +serde_json = "*" +smallvec = "1.6.1" sync = { path = "../sync" } sys_util = { path = "../sys_util" } +thiserror = "1.0.20" diff --git a/base/src/event.rs b/base/src/event.rs index 4a6c0cdf0..4f34914a8 100644 --- a/base/src/event.rs +++ b/base/src/event.rs @@ -8,13 +8,16 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd}; use std::ptr; use std::time::Duration; +use serde::{Deserialize, Serialize}; + use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, Result}; use sys_util::EventFd; pub use sys_util::EventReadResult; /// See [EventFd](sys_util::EventFd) for struct- and method-level /// documentation. 
-#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] pub struct Event(pub EventFd); impl Event { pub fn new() -> Result<Event> { diff --git a/base/src/lib.rs b/base/src/lib.rs index 2e9b2ebfa..b562d48d3 100644 --- a/base/src/lib.rs +++ b/base/src/lib.rs @@ -10,6 +10,7 @@ mod ioctl; mod mmap; mod shm; mod timer; +mod tube; mod wait_context; pub use async_types::*; @@ -18,7 +19,7 @@ pub use ioctl::{ ioctl, ioctl_with_mut_ptr, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref, ioctl_with_val, }; pub use mmap::Unix as MemoryMappingUnix; -pub use mmap::{MemoryMapping, MemoryMappingBuilder}; +pub use mmap::{MemoryMapping, MemoryMappingBuilder, MemoryMappingBuilderUnix}; pub use shm::{SharedMemory, Unix as SharedMemoryUnix}; pub use sys_util::ioctl::*; pub use sys_util::sched::*; @@ -28,6 +29,7 @@ pub use sys_util::{ }; pub use sys_util::{SeekHole, WriteZeroesAt}; pub use timer::{FakeTimer, Timer}; +pub use tube::{AsyncTube, Error as TubeError, Result as TubeResult, Tube}; pub use wait_context::{EventToken, EventType, TriggeredEvent, WaitContext}; /// Wraps an AsRawDescriptor in the simple Descriptor struct, which diff --git a/base/src/mmap.rs b/base/src/mmap.rs index 85b6acd1b..d2bd4cb33 100644 --- a/base/src/mmap.rs +++ b/base/src/mmap.rs @@ -2,9 +2,10 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use crate::{wrap_descriptor, AsRawDescriptor, MappedRegion, MmapError, Protection}; +use crate::{wrap_descriptor, AsRawDescriptor, MappedRegion, MmapError, Protection, SharedMemory}; use data_model::volatile_memory::*; use data_model::DataInit; +use std::fs::File; use sys_util::MemoryMapping as SysUtilMmap; pub type Result<T> = std::result::Result<T, MmapError>; @@ -12,26 +13,33 @@ pub type Result<T> = std::result::Result<T, MmapError>; /// See [MemoryMapping](sys_util::MemoryMapping) for struct- and method-level /// documentation. 
#[derive(Debug)] -pub struct MemoryMapping(SysUtilMmap); +pub struct MemoryMapping { + mapping: SysUtilMmap, +} + impl MemoryMapping { pub fn write_slice(&self, buf: &[u8], offset: usize) -> Result<usize> { - self.0.write_slice(buf, offset) + self.mapping.write_slice(buf, offset) } pub fn read_slice(&self, buf: &mut [u8], offset: usize) -> Result<usize> { - self.0.read_slice(buf, offset) + self.mapping.read_slice(buf, offset) } pub fn write_obj<T: DataInit>(&self, val: T, offset: usize) -> Result<()> { - self.0.write_obj(val, offset) + self.mapping.write_obj(val, offset) } pub fn read_obj<T: DataInit>(&self, offset: usize) -> Result<T> { - self.0.read_obj(offset) + self.mapping.read_obj(offset) } pub fn msync(&self) -> Result<()> { - self.0.msync() + self.mapping.msync() + } + + pub fn use_hugepages(&self) -> Result<()> { + self.mapping.use_hugepages() } pub fn read_to_memory( @@ -40,7 +48,7 @@ impl MemoryMapping { src: &dyn AsRawDescriptor, count: usize, ) -> Result<()> { - self.0 + self.mapping .read_to_memory(mem_offset, &wrap_descriptor(src), count) } @@ -50,7 +58,7 @@ impl MemoryMapping { dst: &dyn AsRawDescriptor, count: usize, ) -> Result<()> { - self.0 + self.mapping .write_from_memory(mem_offset, &wrap_descriptor(dst), count) } } @@ -61,10 +69,14 @@ pub trait Unix { impl Unix for MemoryMapping { fn remove_range(&self, mem_offset: usize, count: usize) -> Result<()> { - self.0.remove_range(mem_offset, count) + self.mapping.remove_range(mem_offset, count) } } +pub trait MemoryMappingBuilderUnix<'a> { + fn from_descriptor(self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder; +} + pub struct MemoryMappingBuilder<'a> { descriptor: Option<&'a dyn AsRawDescriptor>, size: usize, @@ -73,6 +85,16 @@ pub struct MemoryMappingBuilder<'a> { populate: bool, } +impl<'a> MemoryMappingBuilderUnix<'a> for MemoryMappingBuilder<'a> { + /// Build the memory mapping given the specified descriptor to mapped memory + /// + /// Default: Create a new memory mapping. 
+ fn from_descriptor(mut self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder { + self.descriptor = Some(descriptor); + self + } +} + /// Builds a MemoryMapping object from the specified arguments. impl<'a> MemoryMappingBuilder<'a> { /// Creates a new builder specifying size of the memory region in bytes. @@ -86,11 +108,23 @@ impl<'a> MemoryMappingBuilder<'a> { } } - /// Build the memory mapping given the specified descriptor to mapped memory + /// Build the memory mapping given the specified File to mapped memory /// /// Default: Create a new memory mapping. - pub fn from_descriptor(mut self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder { - self.descriptor = Some(descriptor); + /// + /// Note: this is a forward looking interface to accomodate platforms that + /// require special handling for file backed mappings. + #[allow(unused_mut)] + pub fn from_file(mut self, file: &'a File) -> MemoryMappingBuilder { + self.descriptor = Some(file as &dyn AsRawDescriptor); + self + } + + /// Build the memory mapping given the specified SharedMemory to mapped memory + /// + /// Default: Create a new memory mapping. + pub fn from_shared_memory(mut self, shm: &'a SharedMemory) -> MemoryMappingBuilder { + self.descriptor = Some(shm as &dyn AsRawDescriptor); self } @@ -176,23 +210,23 @@ impl<'a> MemoryMappingBuilder<'a> { } fn wrap(result: Result<SysUtilMmap>) -> Result<MemoryMapping> { - result.map(MemoryMapping) + result.map(|mapping| MemoryMapping { mapping }) } } impl VolatileMemory for MemoryMapping { fn get_slice(&self, offset: usize, count: usize) -> VolatileMemoryResult<VolatileSlice> { - self.0.get_slice(offset, count) + self.mapping.get_slice(offset, count) } } // Safe because it exclusively forwards calls to a safe implementation. 
unsafe impl MappedRegion for MemoryMapping { fn as_ptr(&self) -> *mut u8 { - self.0.as_ptr() + self.mapping.as_ptr() } fn size(&self) -> usize { - self.0.size() + self.mapping.size() } } diff --git a/base/src/shm.rs b/base/src/shm.rs index 36c6f7e1a..a445f3d62 100644 --- a/base/src/shm.rs +++ b/base/src/shm.rs @@ -8,11 +8,15 @@ use crate::{ }; use std::ffi::CStr; use std::fs::File; -use std::os::unix::io::AsRawFd; +use std::os::unix::io::{AsRawFd, IntoRawFd}; + +use serde::{Deserialize, Serialize}; use sys_util::SharedMemory as SysUtilSharedMemory; /// See [SharedMemory](sys_util::SharedMemory) for struct- and method-level /// documentation. +#[derive(Serialize, Deserialize)] +#[serde(transparent)] pub struct SharedMemory(SysUtilSharedMemory); impl SharedMemory { pub fn named<T: Into<Vec<u8>>>(name: T, size: u64) -> Result<SharedMemory> { @@ -77,8 +81,15 @@ impl AsRawDescriptor for SharedMemory { } } -impl Into<File> for SharedMemory { - fn into(self) -> File { - self.0.into() +impl IntoRawDescriptor for SharedMemory { + fn into_raw_descriptor(self) -> RawDescriptor { + self.0.into_raw_fd() + } +} + +impl Into<SafeDescriptor> for SharedMemory { + fn into(self) -> SafeDescriptor { + // Safe because we own the SharedMemory at this point. + unsafe { SafeDescriptor::from_raw_descriptor(self.into_raw_descriptor()) } } } diff --git a/base/src/tube.rs b/base/src/tube.rs new file mode 100644 index 000000000..21917c931 --- /dev/null +++ b/base/src/tube.rs @@ -0,0 +1,239 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::io::{self, IoSlice}; +use std::marker::PhantomData; +use std::ops::Deref; +use std::os::unix::prelude::{AsRawFd, RawFd}; +use std::time::Duration; + +use crate::{net::UnixSeqpacket, FromRawDescriptor, SafeDescriptor, ScmSocket, UnsyncMarker}; + +use cros_async::{Executor, IntoAsync, IoSourceExt}; +use serde::{de::DeserializeOwned, Serialize}; +use sys_util::{ + deserialize_with_descriptors, AsRawDescriptor, RawDescriptor, SerializeDescriptors, +}; +use thiserror::Error as ThisError; + +#[derive(ThisError, Debug)] +pub enum Error { + #[error("failed to serialize/deserialize json from packet: {0}")] + Json(serde_json::Error), + #[error("failed to send packet: {0}")] + Send(sys_util::Error), + #[error("failed to receive packet: {0}")] + Recv(io::Error), + #[error("tube was disconnected")] + Disconnected, + #[error("failed to crate tube pair: {0}")] + Pair(io::Error), + #[error("failed to set send timeout: {0}")] + SetSendTimeout(io::Error), + #[error("failed to set recv timeout: {0}")] + SetRecvTimeout(io::Error), + #[error("failed to create async tube: {0}")] + CreateAsync(cros_async::AsyncError), +} + +pub type Result<T> = std::result::Result<T, Error>; + +/// Bidirectional tube that support both send and recv. +pub struct Tube { + socket: UnixSeqpacket, + _unsync_marker: UnsyncMarker, +} + +impl Tube { + /// Create a pair of connected tubes. Request is send in one direction while response is in the + /// other direction. + pub fn pair() -> Result<(Tube, Tube)> { + let (socket1, socket2) = UnixSeqpacket::pair().map_err(Error::Pair)?; + let tube1 = Tube::new(socket1); + let tube2 = Tube::new(socket2); + Ok((tube1, tube2)) + } + + // Create a new `Tube`. 
+ pub fn new(socket: UnixSeqpacket) -> Tube { + Tube { + socket, + _unsync_marker: PhantomData, + } + } + + pub fn into_async_tube(self, ex: &Executor) -> Result<AsyncTube> { + let inner = ex.async_from(self).map_err(Error::CreateAsync)?; + Ok(AsyncTube { inner }) + } + + pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> { + let msg_serialize = SerializeDescriptors::new(&msg); + let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?; + let msg_descriptors = msg_serialize.into_descriptors(); + + self.socket + .send_with_fds(&[IoSlice::new(&msg_json)], &msg_descriptors) + .map_err(Error::Send)?; + Ok(()) + } + + pub fn recv<T: DeserializeOwned>(&self) -> Result<T> { + let (msg_json, msg_descriptors) = + self.socket.recv_as_vec_with_fds().map_err(Error::Recv)?; + + if msg_json.is_empty() { + return Err(Error::Disconnected); + } + + let mut msg_descriptors_safe = msg_descriptors + .into_iter() + .map(|v| { + Some(unsafe { + // Safe because the socket returns new fds that are owned locally by this scope. + SafeDescriptor::from_raw_descriptor(v) + }) + }) + .collect(); + + deserialize_with_descriptors( + || serde_json::from_slice(&msg_json), + &mut msg_descriptors_safe, + ) + .map_err(Error::Json) + } + + /// Returns true if there is a packet ready to `recv` without blocking. + /// + /// If there is an error trying to determine if there is a packet ready, this returns false. 
+ pub fn is_packet_ready(&self) -> bool { + self.socket.get_readable_bytes().unwrap_or(0) > 0 + } + + pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> { + self.socket + .set_write_timeout(timeout) + .map_err(Error::SetSendTimeout) + } + + pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> { + self.socket + .set_read_timeout(timeout) + .map_err(Error::SetRecvTimeout) + } +} + +impl AsRawDescriptor for Tube { + fn as_raw_descriptor(&self) -> RawDescriptor { + self.socket.as_raw_descriptor() + } +} + +impl AsRawFd for Tube { + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + +impl IntoAsync for Tube {} + +pub struct AsyncTube { + inner: Box<dyn IoSourceExt<Tube>>, +} + +impl AsyncTube { + pub async fn next<T: DeserializeOwned>(&self) -> Result<T> { + self.inner.wait_readable().await.unwrap(); + self.inner.as_source().recv() + } +} + +impl Deref for AsyncTube { + type Target = Tube; + + fn deref(&self) -> &Self::Target { + self.inner.as_source() + } +} + +impl Into<Tube> for AsyncTube { + fn into(self) -> Tube { + self.inner.into_source() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Event; + + use std::collections::HashMap; + use std::time::Duration; + + use serde::{Deserialize, Serialize}; + + #[track_caller] + fn test_event_pair(send: Event, mut recv: Event) { + send.write(1).unwrap(); + recv.read_timeout(Duration::from_secs(1)).unwrap(); + } + + #[test] + fn send_recv_no_fd() { + let (s1, s2) = Tube::pair().unwrap(); + + let test_msg = "hello world"; + s1.send(&test_msg).unwrap(); + let recv_msg: String = s2.recv().unwrap(); + + assert_eq!(test_msg, recv_msg); + } + + #[test] + fn send_recv_one_fd() { + #[derive(Serialize, Deserialize)] + struct EventStruct { + x: u32, + b: Event, + } + + let (s1, s2) = Tube::pair().unwrap(); + + let test_msg = EventStruct { + x: 100, + b: Event::new().unwrap(), + }; + s1.send(&test_msg).unwrap(); + let recv_msg: EventStruct = s2.recv().unwrap(); + + 
assert_eq!(test_msg.x, recv_msg.x); + + test_event_pair(test_msg.b, recv_msg.b); + } + + #[test] + fn send_recv_hash_map() { + let (s1, s2) = Tube::pair().unwrap(); + + let mut test_msg = HashMap::new(); + test_msg.insert("Red".to_owned(), Event::new().unwrap()); + test_msg.insert("White".to_owned(), Event::new().unwrap()); + test_msg.insert("Blue".to_owned(), Event::new().unwrap()); + test_msg.insert("Orange".to_owned(), Event::new().unwrap()); + test_msg.insert("Green".to_owned(), Event::new().unwrap()); + s1.send(&test_msg).unwrap(); + let mut recv_msg: HashMap<String, Event> = s2.recv().unwrap(); + + let mut test_msg_keys: Vec<_> = test_msg.keys().collect(); + test_msg_keys.sort(); + let mut recv_msg_keys: Vec<_> = recv_msg.keys().collect(); + recv_msg_keys.sort(); + assert_eq!(test_msg_keys, recv_msg_keys); + + for (key, test_event) in test_msg { + let recv_event = recv_msg.remove(&key).unwrap(); + test_event_pair(test_event, recv_event); + } + } +} diff --git a/base/src/wait_context.rs b/base/src/wait_context.rs index 4e492b7bf..f3afb6ae4 100644 --- a/base/src/wait_context.rs +++ b/base/src/wait_context.rs @@ -6,6 +6,7 @@ use std::os::unix::io::AsRawFd; use std::time::Duration; use crate::{wrap_descriptor, AsRawDescriptor, RawDescriptor, Result}; +use smallvec::SmallVec; use sys_util::{PollContext, PollToken, WatchingEvents}; // Typedef PollToken as EventToken for better adherance to base naming. @@ -124,23 +125,22 @@ impl<T: EventToken> WaitContext<T> { } /// Waits for one or more of the registered triggers to become signaled. - pub fn wait(&self) -> Result<Vec<TriggeredEvent<T>>> { + pub fn wait(&self) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> { self.wait_timeout(Duration::new(i64::MAX as u64, 0)) } /// Waits for one or more of the registered triggers to become signaled, failing if no triggers /// are signaled before the designated timeout has elapsed. 
- pub fn wait_timeout(&self, timeout: Duration) -> Result<Vec<TriggeredEvent<T>>> { + pub fn wait_timeout(&self, timeout: Duration) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> { let events = self.0.wait_timeout(timeout)?; - let mut return_vec: Vec<TriggeredEvent<T>> = vec![]; - for event in events.iter() { - return_vec.push(TriggeredEvent { + Ok(events + .iter() + .map(|event| TriggeredEvent { token: event.token(), is_readable: event.readable(), is_writable: event.writable(), is_hungup: event.hungup(), - }); - } - Ok(return_vec) + }) + .collect()) } } diff --git a/bin/clippy b/bin/clippy index a9b70f675..0e930793c 100755 --- a/bin/clippy +++ b/bin/clippy @@ -54,6 +54,10 @@ SUPPRESS=( useless_format wrong_self_convention + # False positives affecting WlVfd @ `devices/src/virtio/wl.rs`. + # Bug: https://github.com/rust-lang/rust-clippy/issues/6312 + field_reassign_with_default + # We don't care about these lints. Okay to remain suppressed globally. blacklisted_name cast_lossless diff --git a/bit_field/bit_field_derive/bit_field_derive.rs b/bit_field/bit_field_derive/bit_field_derive.rs index 92fea946b..df3e2825e 100644 --- a/bit_field/bit_field_derive/bit_field_derive.rs +++ b/bit_field/bit_field_derive/bit_field_derive.rs @@ -473,7 +473,7 @@ fn get_fields_impl(fields: &[FieldSpec]) -> Vec<TokenStream> { let span = expected_bits.span(); quote_spanned! {span=> #[allow(dead_code)] - const EXPECTED_BITS: [(); #expected_bits as usize] = + const EXPECTED_BITS: [(); #expected_bits] = [(); <#ty as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize]; } }); diff --git a/ci/crosvm_aarch64_builder/entrypoint b/ci/crosvm_aarch64_builder/entrypoint index 21d3087e0..74ff7e510 100755 --- a/ci/crosvm_aarch64_builder/entrypoint +++ b/ci/crosvm_aarch64_builder/entrypoint @@ -27,8 +27,13 @@ if [ "$1" = "--vm" ]; then echo "Starting testing vm..." 
(cd /workspace/vm && screen -Sdm vm ./start_vm) export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUNNER="\ - /workspace/vm/exec_file" - test_target="Virtual Machine (See 'screen -r vm')" + /workspace/src/platform/crosvm/ci/vm_tools/exec_binary_in_vm" + + if [[ $# -eq 0 ]]; then + test_target="Virtual Machine (See 'screen -r vm' or 'ssh vm')" + else + test_target="Virtual Machine" + fi export CROSVM_USE_VM=1 else test_target="User-space emulation" diff --git a/ci/crosvm_base/rust-toolchain b/ci/crosvm_base/rust-toolchain index 9db5ea12f..5a5c7211d 100644 --- a/ci/crosvm_base/rust-toolchain +++ b/ci/crosvm_base/rust-toolchain @@ -1 +1 @@ -1.48.0 +1.50.0 diff --git a/ci/crosvm_builder/entrypoint b/ci/crosvm_builder/entrypoint index 6b51b2a2e..ec4c0765b 100755 --- a/ci/crosvm_builder/entrypoint +++ b/ci/crosvm_builder/entrypoint @@ -27,8 +27,13 @@ if [ "$1" = "--vm" ]; then echo "Starting testing vm..." (cd /workspace/vm && screen -Sdm vm ./start_vm) export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER="\ - /workspace/vm/exec_file" - test_target="Virtual Machine (See 'screen -r vm')" + /workspace/src/platform/crosvm/ci/vm_tools/exec_binary_in_vm" + + if [[ $# -eq 0 ]]; then + test_target="Virtual Machine (See 'screen -r vm' or 'ssh vm')" + else + test_target="Virtual Machine" + fi export CROSVM_USE_VM=1 else test_target="Native execution" diff --git a/ci/crosvm_test_vm/Dockerfile b/ci/crosvm_test_vm/Dockerfile index 997405214..aa19f1748 100644 --- a/ci/crosvm_test_vm/Dockerfile +++ b/ci/crosvm_test_vm/Dockerfile @@ -24,7 +24,7 @@ RUN apt-get update && apt-get install --yes \ WORKDIR /workspace/vm RUN curl -sSfL -o rootfs.qcow2 \ - "https://cdimage.debian.org/cdimage/cloud/buster/20201214-484/debian-10-generic-${VM_ARCH}-20201214-484.qcow2" + "http://cloud.debian.org/images/cloud/bullseye/daily/20210208-542/debian-11-generic-${VM_ARCH}-daily-20210208-542.qcow2" # Package `cloud_init_data.yaml` to be loaded during `first_boot.expect` COPY build/cloud_init_data.yaml ./ @@ 
-73,11 +73,6 @@ RUN chmod 0600 /root/.ssh/id_rsa # Copy utility scripts COPY runtime/start_vm.${VM_ARCH} ./start_vm -COPY runtime/exec \ - runtime/exec_file \ - runtime/sync_so \ - runtime/wait_for_vm \ - ./ # Automatically start the VM. ENTRYPOINT [ "/workspace/vm/start_vm" ] diff --git a/ci/crosvm_test_vm/build/cloud_init_data.yaml b/ci/crosvm_test_vm/build/cloud_init_data.yaml index 5544d5f2e..a9d4dc953 100644 --- a/ci/crosvm_test_vm/build/cloud_init_data.yaml +++ b/ci/crosvm_test_vm/build/cloud_init_data.yaml @@ -3,6 +3,7 @@ users: - name: crosvm sudo: ALL=(ALL) NOPASSWD:ALL lock_passwd: False + shell: /bin/bash # Hashed password is 'crosvm' passwd: $6$rounds=4096$os6Q9Ok4Y9a8hKvG$EwQ1bbS0qd4IJyRP.bnRbyjPbSS8BwxEJh18PfhsyD0w7a4GhTwakrmYZ6KuBoyP.cSjYYSW9wYwko4oCPoJr. # Pubkey for `../vm_key` @@ -12,11 +13,7 @@ users: crosvm@localhost groups: kvm, disk, tty -apt: - sources: - testing: - source: "deb $MIRROR bullseye main" - conf: APT::Default-Release "stable"; +hostname: testvm # Store working data on tmpfs to reduce unnecessary disk IO mounts: @@ -37,17 +34,14 @@ packages: - rsync runcmd: - # Install testing (debian bullseye) versions of some libraries. - - [ - apt-get, - install, - --yes, - -t, - testing, - --no-install-recommends, - libdrm2, - libepoxy0, - ] + # Prevent those annoying "host not found errors". + - echo 127.0.0.1 testvm >> /etc/hosts + + # Make it easier to identify which VM we are in. + - echo "export PS1=\"testvm-$(arch):\\\\w# \"" >> /etc/bash.bashrc + + # Enable core dumps for debugging crashes + - echo "* soft core unlimited" > /etc/security/limits.conf # Trim some fat - [apt-get, remove, --yes, vim-runtime, iso-codes, perl, grub-common] diff --git a/ci/crosvm_test_vm/runtime/sync_so b/ci/crosvm_test_vm/runtime/sync_so deleted file mode 100755 index 7214945bc..000000000 --- a/ci/crosvm_test_vm/runtime/sync_so +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -# Copyright 2021 The Chromium OS Authors. All rights reserved. 
-# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. -# -# Synchronizes shared objects into the virtual machine to allow crosvm binaries -# to run. - -${0%/*}/exec exit || exit 1 # Wait for VM to be available - -rust_toolchain=$(cat /workspace/src/platform/crosvm/rust-toolchain) - -# List of shared objects used by crosvm that need to be synced -shared_objects=( - /workspace/scratch/lib/*.so* - /root/.rustup/toolchains/${rust_toolchain}-*/lib/libstd-*.so - /root/.rustup/toolchains/${rust_toolchain}-*/lib/libtest-*.so -) - -rsync -azPL --rsync-path="sudo rsync" ${shared_objects[@]} vm:/usr/lib diff --git a/ci/image_tag b/ci/image_tag index 75d30fb53..b7be5a127 100644 --- a/ci/image_tag +++ b/ci/image_tag @@ -1 +1 @@ -r0001 +r0004 diff --git a/ci/kokoro/common.sh b/ci/kokoro/common.sh index 54aba174f..c1e1e8b3b 100755 --- a/ci/kokoro/common.sh +++ b/ci/kokoro/common.sh @@ -5,6 +5,21 @@ crosvm_root="${KOKORO_ARTIFACTS_DIR}"/git/crosvm +# Enable SSH access to the kokoro builder. +# Use the fusion2/ UI to trigger a build and set the DEBUG_SSH_KEY environment +# variable to your public key, that will allow you to connect to the builder +# via SSH. +# Note: Access is restricted to the google corporate network. +# Details: https://yaqs.corp.google.com/eng/q/6628551334035456 +if [[ ! -z "${DEBUG_SSH_KEY}" ]]; then + echo "${DEBUG_SSH_KEY}" >>~/.ssh/authorized_keys + external_ip=$( + curl -s -H "Metadata-Flavor: Google" + http://metadata/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip + ) + echo "SSH Debug enabled. 
Connect to: kbuilder@${external_ip}" +fi + setup_source() { if [ -z "${KOKORO_ARTIFACTS_DIR}" ]; then echo "This script must be run in kokoro" @@ -30,7 +45,7 @@ setup_source() { -u https://chromium.googlesource.com/chromiumos/manifest.git \ --repo-url https://chromium.googlesource.com/external/repo.git \ -g crosvm || return 1 - ./repo sync -j8 -c -m "${crosvm_root}/ci/kokoro/manifest.xml" || return 1 + ./repo sync -j8 -c || return 1 # Bind mount source into cros checkout. echo "" @@ -45,6 +60,12 @@ setup_source() { } cleanup() { + # Sleep after the build to allow for SSH debugging to continue. + if [[ ! -z "${DEBUG_SSH_KEY}" ]]; then + echo "Build done. Blocking for SSH debugging." + sleep 1h + fi + if command -v bindfs >/dev/null; then fusermount -uz "${KOKORO_ARTIFACTS_DIR}/cros/src/platform/crosvm" else diff --git a/ci/kokoro/manifest.xml b/ci/kokoro/manifest.xml deleted file mode 100644 index 1a337ad1f..000000000 --- a/ci/kokoro/manifest.xml +++ /dev/null @@ -1,24 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> -<manifest> - <remote fetch="https://android.googlesource.com" name="aosp" review="https://android-review.googlesource.com"/> - <remote alias="cros" fetch="https://chromium.googlesource.com/" name="chromium"/> - <remote fetch="https://chromium.googlesource.com" name="cros" review="https://chromium-review.googlesource.com"/> - - <default remote="cros" revision="refs/heads/main" sync-j="8"/> - - <project dest-branch="refs/heads/master" groups="minilayout,firmware,buildtools,chromeos-admin,labtools,sysmon,devserver,crosvm" name="chromiumos/chromite" path="chromite" revision="2c0017fef941137472b9b58f8acbb41780c0f14f" upstream="refs/heads/master"> - <copyfile dest="AUTHORS" src="AUTHORS"/> - <copyfile dest="LICENSE" src="LICENSE"/> - </project> - <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/docs" path="docs" revision="73161502e320ebea00bf620a54d3a606ca1b9836" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" 
groups="crosvm" name="chromiumos/platform/crosvm" path="src/platform/crosvm" revision="f4d1cdaaeb7b1f8e747e1af4c91e81a3255defe9" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/platform/minigbm" path="src/platform/minigbm" revision="6e27708ff2093a19c01d51cef61507bf8a804bf9" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/platform2" path="src/platform2" revision="c2c0bbfe867d53bc8f7aca0edb2ba65ec5849de1" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" groups="minilayout,firmware,buildtools,labtools,crosvm" name="chromiumos/repohooks" path="src/repohooks" revision="d9ed85f771d9aaf012f850d0b15e69af029023b1" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/third_party/adhd" path="src/third_party/adhd" revision="4cde9efd2bfde57def27da9ee864cce6dc430046" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/main" groups="firmware,crosvm" name="chromiumos/third_party/tpm2" path="src/third_party/tpm2" revision="86e93379322f012d354b9b8a369373ed9b62718c" upstream="refs/heads/main"/> - <project dest-branch="refs/heads/master" groups="crosvm" name="chromiumos/third_party/virglrenderer" path="src/third_party/virglrenderer" revision="51f45f343b77c01897ddddc6bce84117a6278793" upstream="refs/heads/master"/> - <project dest-branch="refs/heads/master" groups="crosvm" name="platform/external/minijail" path="src/aosp/external/minijail" remote="aosp" revision="e119bbb81cb42aaddef61882b3747cf7995465f7" upstream="refs/heads/master"/> - - <repo-hooks enabled-list="pre-upload" in-project="chromiumos/repohooks"/> -</manifest> diff --git a/ci/kokoro/uprev b/ci/kokoro/uprev deleted file mode 100755 index 515f84f5a..000000000 --- a/ci/kokoro/uprev +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# Copyright 2021 The Chromium OS Authors. All rights reserved. 
-# Use of this source code is governed by a BSD-style license that can be -# found in the LICENSE file. -# -# Uprevs manifest.xml to the latest versions. -# -# This is just a wrapper around `repo manifest`. Usually we have unsubmitted -# CLs in our local repo, so it's safer to pull the latest manifest revisions -# from a fresh repo checkout to make sure all commit sha's are available. - -tmp=$(mktemp -d) -manifest_path=$(realpath $(dirname $0)/manifest.xml) - -cleanup() { - rm -rf "${tmp}" -} - -main() { - cd "${tmp}" - repo init --depth 1 \ - -u https://chromium.googlesource.com/chromiumos/manifest.git \ - --repo-url https://chromium.googlesource.com/external/repo.git \ - -g crosvm - repo sync -j8 -c - repo manifest --revision-as-HEAD -o "${manifest_path}" -} - -trap cleanup EXIT -main "$@" diff --git a/ci/test_runner.py b/ci/test_runner.py index eb7ec8712..1207207c1 100644 --- a/ci/test_runner.py +++ b/ci/test_runner.py @@ -28,7 +28,9 @@ VERY_VERBOSE = False # Runs tests using the exec_file wrapper, which will run the test inside the # builders built-in VM. -VM_TEST_RUNNER = "/workspace/vm/exec_file --no-sync" +VM_TEST_RUNNER = ( + os.path.abspath("./ci/vm_tools/exec_binary_in_vm") + " --no-sync" +) # Runs tests using QEMU user-space emulation. QEMU_TEST_RUNNER = ( @@ -52,12 +54,21 @@ class Requirements(enum.Enum): # Test is disabled explicitly. DISABLED = "disabled" - # Test needs to be executed with expanded privileges for device access. + # Test needs to be executed with expanded privileges for device access and + # will be run inside a VM. PRIVILEGED = "privileged" # Test needs to run single-threaded SINGLE_THREADED = "single_threaded" + # Separate workspaces that have dev-dependencies cannot be built from the + # crosvm workspace and need to be built separately. + # Note: Separate workspaces are built with no features enabled. + SEPARATE_WORKSPACE = "separate_workspace" + + # Build, but do not run. 
+ DO_NOT_RUN = "do_not_run" + BUILD_TIME_REQUIREMENTS = [ Requirements.AARCH64, @@ -84,8 +95,13 @@ class CrateInfo(object): build_reqs = requirements.intersection(BUILD_TIME_REQUIREMENTS) self.can_build = all(req in capabilities for req in build_reqs) - self.can_run = self.can_build and ( - not self.needs_privilege or Requirements.PRIVILEGED in capabilities + self.can_run = ( + self.can_build + and ( + not self.needs_privilege + or Requirements.PRIVILEGED in capabilities + ) + and not Requirements.DO_NOT_RUN in self.requirements ) def __repr__(self): @@ -209,7 +225,7 @@ def results_summary(results: Union[RunResults, CrateResults]): num_pass = results.count(TestResult.PASS) num_skip = results.count(TestResult.SKIP) num_fail = results.count(TestResult.FAIL) - msg = [] + msg: List[str] = [] if num_pass: msg.append(f"{num_pass} passed") if num_skip: @@ -219,9 +235,42 @@ def results_summary(results: Union[RunResults, CrateResults]): return ", ".join(msg) +def cargo_build_process( + cwd: str = ".", crates: List[CrateInfo] = [], features: Set[str] = set() +): + """Builds the main crosvm crate.""" + cmd = [ + "cargo", + "build", + "--color=never", + "--no-default-features", + "--features", + ",".join(features), + ] + + for crate in sorted(crate.name for crate in crates): + cmd += ["-p", crate] + + if VERY_VERBOSE: + print("CMD", " ".join(cmd)) + + process = subprocess.run( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + if process.returncode != 0 or VERBOSE: + print() + print(process.stdout) + return process + + def cargo_test_process( - crates: List[CrateInfo], - features: Set[str], + cwd: str, + crates: List[CrateInfo] = [], + features: Set[str] = set(), run: bool = True, single_threaded: bool = False, use_vm: bool = False, @@ -233,6 +282,11 @@ def cargo_test_process( cmd += ["--no-run"] if features: cmd += ["--no-default-features", "--features", ",".join(features)] + + # Skip doc tests as these cannot be run in the VM. 
+ if use_vm: + cmd += ["--bins", "--tests"] + for crate in sorted(crate.name for crate in crates): cmd += ["-p", crate] @@ -247,6 +301,7 @@ def cargo_test_process( process = subprocess.run( cmd, + cwd=cwd, env=env, timeout=timeout, stdout=subprocess.PIPE, @@ -261,9 +316,41 @@ def cargo_test_process( def cargo_build_tests(crates: List[CrateInfo], features: Set[str]): """Runs cargo test --no-run to build all listed `crates`.""" - print("Building: ", ", ".join(crate.name for crate in crates)) - process = cargo_test_process(crates, features, run=False) - return process.returncode == 0 + separate_workspace_crates = [ + crate + for crate in crates + if Requirements.SEPARATE_WORKSPACE in crate.requirements + ] + workspace_crates = [ + crate + for crate in crates + if Requirements.SEPARATE_WORKSPACE not in crate.requirements + ] + + print( + "Building workspace: ", + ", ".join(crate.name for crate in workspace_crates), + ) + build_process = cargo_build_process( + cwd=".", crates=workspace_crates, features=features + ) + if build_process.returncode != 0: + return False + test_process = cargo_test_process( + cwd=".", crates=workspace_crates, features=features, run=False + ) + if test_process.returncode != 0: + return False + + for crate in separate_workspace_crates: + print("Building crate:", crate.name) + build_process = cargo_build_process(cwd=crate.name) + if build_process.returncode != 0: + return False + test_process = cargo_test_process(cwd=crate.name, run=False) + if test_process.returncode != 0: + return False + return True def cargo_test( @@ -279,16 +366,29 @@ def cargo_test( msg.append("in vm") if single_threaded: msg.append("(single-threaded)") + if Requirements.SEPARATE_WORKSPACE in crate.requirements: + msg.append("(separate workspace)") sys.stdout.write(f"{' '.join(msg)}... 
") sys.stdout.flush() - process = cargo_test_process( - [crate], - features, - run=True, - single_threaded=single_threaded, - use_vm=use_vm, - timeout=TEST_TIMEOUT_SECS, - ) + + if Requirements.SEPARATE_WORKSPACE in crate.requirements: + process = cargo_test_process( + cwd=crate.name, + run=True, + single_threaded=single_threaded, + use_vm=use_vm, + timeout=TEST_TIMEOUT_SECS, + ) + else: + process = cargo_test_process( + cwd=".", + crates=[crate], + features=features, + run=True, + single_threaded=single_threaded, + use_vm=use_vm, + timeout=TEST_TIMEOUT_SECS, + ) results = CrateResults( crate.name, process.returncode == 0, process.stdout ) @@ -318,7 +418,6 @@ def execute_batched_by_privilege( Non-privileged tests are run first. Privileged tests are executed in a VM if use_vm is set. """ - build_crates = [crate for crate in crates if crate.can_build] if not cargo_build_tests(build_crates, features): return [] @@ -335,7 +434,7 @@ def execute_batched_by_privilege( ] if privileged_crates: if use_vm: - subprocess.run("/workspace/vm/sync_so", check=True) + subprocess.run("./ci/vm_tools/sync_deps", check=True) yield from execute_batched_by_parallelism( privileged_crates, features, use_vm=True ) @@ -491,8 +590,11 @@ def main( help="Path to file where to store junit xml results", ) args = parser.parse_args() + + global VERBOSE, VERY_VERBOSE VERBOSE = args.verbose or args.very_verbose # type: ignore VERY_VERBOSE = args.very_verbose # type: ignore + use_vm = os.environ.get("CROSVM_USE_VM") != None or args.use_vm cros_build = os.environ.get("CROSVM_CROS_BUILD") != None or args.cros_build diff --git a/ci/vm_tools/README.md b/ci/vm_tools/README.md new file mode 100644 index 000000000..0dcc87f06 --- /dev/null +++ b/ci/vm_tools/README.md @@ -0,0 +1,8 @@ +# VM Tools + +The scripts in this directory are used to make it easier to work with the test +VM. 
+ +The VM is expected to be accessible via `ssh vm` without a password prompt, this +is set up by the builders, so it'll work out-of-the-box when running inside the +builder shell. diff --git a/ci/crosvm_test_vm/runtime/exec_file b/ci/vm_tools/exec_binary_in_vm index c14f64885..7b2c230d1 100755 --- a/ci/crosvm_test_vm/runtime/exec_file +++ b/ci/vm_tools/exec_binary_in_vm @@ -3,15 +3,16 @@ # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. # -# Uploads and executes a file in the VM. +# Uploads and executes a file in the VM. This script can be set as a runner +# for cargo to execute tests inside the VM. -${0%/*}/exec exit || exit 1 # Wait for VM to be available +${0%/*}/wait_for_vm_with_timeout || exit 1 if [ "$1" = "--no-sync" ]; then shift else - echo "Syncing shared objects..." - ${0%/*}/sync_so || exit 1 + echo "Syncing dependencies..." + ${0%/*}/sync_deps || exit 1 fi filepath=$1 @@ -19,8 +20,9 @@ filename=$(basename $filepath) echo "Executing $filename ${@:2}" scp -q $filepath vm:/tmp/$filename +ssh vm -q -t "cd /tmp && sudo ./$filename ${@:2}" + # Make sure to preserve the exit code of $filename after cleaning up the file. -ssh vm -q -t "cd /tmp && ./$filename ${@:2}" ret=$? ssh vm -q -t "rm /tmp/$filename" exit $ret diff --git a/ci/vm_tools/sync_deps b/ci/vm_tools/sync_deps new file mode 100755 index 000000000..a97b019e0 --- /dev/null +++ b/ci/vm_tools/sync_deps @@ -0,0 +1,37 @@ +#!/bin/bash +# Copyright 2021 The Chromium OS Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. +# +# Synchronizes dependencies of crosvm into the virtual machine to allow test +# binaries to execute. 
+ +${0%/*}/wait_for_vm_with_timeout || exit 1 + +crosvm_root="/workspace/src/platform/crosvm" +rust_toolchain=$(cat ${crosvm_root}/rust-toolchain) +target_dir=$( + cargo metadata --no-deps --format-version 1 | + jq -r ".target_directory" +) + +# List of shared objects used by crosvm that need to be synced. +shared_objects=( + /workspace/scratch/lib/*.so* + /root/.rustup/toolchains/${rust_toolchain}-*/lib/libstd-*.so + /root/.rustup/toolchains/${rust_toolchain}-*/lib/libtest-*.so +) +rsync -azPLq --rsync-path="sudo rsync" ${shared_objects[@]} vm:/usr/lib + +# Files needed by binaries at runtime in the working directory. +if [ -z "${CARGO_BUILD_TARGET}" ]; then + runtime_files=( + "${target_dir}/debug/crosvm" + ) +else + runtime_files=( + "${target_dir}/${CARGO_BUILD_TARGET}/debug/crosvm" + ) +fi + +rsync -azPLq --rsync-path="sudo rsync" ${runtime_files} vm:/tmp diff --git a/ci/crosvm_test_vm/runtime/wait_for_vm b/ci/vm_tools/wait_for_vm index dca090f9a..dca090f9a 100755 --- a/ci/crosvm_test_vm/runtime/wait_for_vm +++ b/ci/vm_tools/wait_for_vm diff --git a/ci/crosvm_test_vm/runtime/exec b/ci/vm_tools/wait_for_vm_with_timeout index 1d7043caf..7502177d1 100755 --- a/ci/crosvm_test_vm/runtime/exec +++ b/ci/vm_tools/wait_for_vm_with_timeout @@ -2,8 +2,6 @@ # Copyright 2021 The Chromium OS Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. -# -# Executes a command in the VM once it is available. if ! 
timeout --foreground 180s ${0%/*}/wait_for_vm; then echo "" diff --git a/cros_async/Cargo.toml b/cros_async/Cargo.toml index 5b497c13b..fcaf95bc7 100644 --- a/cros_async/Cargo.toml +++ b/cros_async/Cargo.toml @@ -7,15 +7,15 @@ edition = "2018" [dependencies] async-trait = "0.1.36" async-task = "4" -io_uring = { path = "../io_uring" } +intrusive-collections = "0.9" +io_uring = { path = "../io_uring" } # provided by ebuild libc = "*" paste = "1.0" pin-utils = "0.1.0-alpha.4" slab = "0.4" -sync = { path = "../sync" } -sys_util = { path = "../sys_util" } -data_model = { path = "../data_model" } -syscall_defines = { path = "../syscall_defines" } +sync = { path = "../sync" } # provided by ebuild +sys_util = { path = "../sys_util" } # provided by ebuild +data_model = { path = "../data_model" } # provided by ebuild thiserror = "1.0.20" [dependencies.futures] @@ -25,5 +25,8 @@ features = ["alloc"] [dev-dependencies] futures = { version = "*", features = ["executor"] } -tempfile = { path = "../tempfile" } -vm_memory = { path = "../vm_memory" } +futures-executor = { version = "0.3", features = ["thread-pool"] } +futures-util = "0.3" +tempfile = { path = "../tempfile" } # provided by ebuild + +[workspace]
\ No newline at end of file diff --git a/cros_async/src/fd_executor.rs b/cros_async/src/fd_executor.rs index 4ab5dc578..ec6a72061 100644 --- a/cros_async/src/fd_executor.rs +++ b/cros_async/src/fd_executor.rs @@ -49,6 +49,9 @@ pub enum Error { /// PollContext failure. #[error("PollContext failure: {0}")] PollContextError(sys_util::Error), + /// An error occurred when setting the FD non-blocking. + #[error("An error occurred setting the FD non-blocking: {0}.")] + SettingNonBlocking(sys_util::Error), /// Failed to submit the waker to the polling context. #[error("An error adding to the Aio context: {0}")] SubmittingWaker(sys_util::Error), @@ -70,6 +73,62 @@ enum OpStatus { Completed, } +// An IO source previously registered with an FdExecutor. Used to initiate asynchronous IO with the +// associated executor. +pub struct RegisteredSource<F> { + source: F, + ex: Weak<RawExecutor>, +} + +impl<F: AsRawFd> RegisteredSource<F> { + // Start an asynchronous operation to wait for this source to become readable. The returned + // future will not be ready until the source is readable. + pub fn wait_readable(&self) -> Result<PendingOperation> { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + + let token = + ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_read())?; + + Ok(PendingOperation { + token: Some(token), + ex: self.ex.clone(), + }) + } + + // Start an asynchronous operation to wait for this source to become writable. The returned + // future will not be ready until the source is writable. + pub fn wait_writable(&self) -> Result<PendingOperation> { + let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?; + + let token = + ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_write())?; + + Ok(PendingOperation { + token: Some(token), + ex: self.ex.clone(), + }) + } +} + +impl<F> RegisteredSource<F> { + // Consume this RegisteredSource and return the inner IO source. 
+ pub fn into_source(self) -> F { + self.source + } +} + +impl<F> AsRef<F> for RegisteredSource<F> { + fn as_ref(&self) -> &F { + &self.source + } +} + +impl<F> AsMut<F> for RegisteredSource<F> { + fn as_mut(&mut self) -> &mut F { + &mut self.source + } +} + /// A token returned from `add_operation` that can be used to cancel the waker before it completes. /// Used to manage getting the result from the underlying executor for a completed operation. /// Dropping a `PendingOperation` will get the result from the executor. @@ -227,11 +286,12 @@ impl RawExecutor { let raw = Arc::downgrade(self); let schedule = move |runnable| { if let Some(r) = raw.upgrade() { - r.queue.schedule(runnable); + r.queue.push_back(runnable); + r.wake(); } }; let (runnable, task) = async_task::spawn(f, schedule); - self.queue.schedule(runnable); + runnable.schedule(); task } @@ -243,11 +303,12 @@ impl RawExecutor { let raw = Arc::downgrade(self); let schedule = move |runnable| { if let Some(r) = raw.upgrade() { - r.queue.schedule(runnable); + r.queue.push_back(runnable); + r.wake(); } }; let (runnable, task) = async_task::spawn_local(f, schedule); - self.queue.schedule(runnable); + runnable.schedule(); task } @@ -257,7 +318,6 @@ impl RawExecutor { loop { self.state.store(PROCESSING, Ordering::Release); - self.queue.set_waker(cx.waker().clone()); for runnable in self.queue.iter() { runnable.run(); } @@ -266,10 +326,13 @@ impl RawExecutor { return Ok(val); } - let oldstate = self - .state - .compare_and_swap(PROCESSING, WAITING, Ordering::Acquire); - if oldstate != PROCESSING { + let oldstate = self.state.compare_exchange( + PROCESSING, + WAITING, + Ordering::Acquire, + Ordering::Acquire, + ); + if let Err(oldstate) = oldstate { debug_assert_eq!(oldstate, WOKEN); // One or more futures have become runnable. 
continue; @@ -430,24 +493,10 @@ impl FdExecutor { self.raw.run(&mut ctx, f) } - pub fn wait_readable<F: AsRawFd>(&self, f: &F) -> Result<PendingOperation> { - let token = self - .raw - .add_operation(f.as_raw_fd(), WatchingEvents::empty().set_read())?; - - Ok(PendingOperation { - token: Some(token), - ex: Arc::downgrade(&self.raw), - }) - } - - pub fn wait_writable<F: AsRawFd>(&self, f: &F) -> Result<PendingOperation> { - let token = self - .raw - .add_operation(f.as_raw_fd(), WatchingEvents::empty().set_read())?; - - Ok(PendingOperation { - token: Some(token), + pub(crate) fn register_source<F: AsRawFd>(&self, f: F) -> Result<RegisteredSource<F>> { + add_fd_flags(f.as_raw_fd(), libc::O_NONBLOCK).map_err(Error::SettingNonBlocking)?; + Ok(RegisteredSource { + source: f, ex: Arc::downgrade(&self.raw), }) } @@ -479,7 +528,8 @@ mod test { async fn do_test(ex: &FdExecutor) { let (r, _w) = sys_util::pipe(true).unwrap(); let done = Box::pin(async { 5usize }); - let pending = ex.wait_readable(&r).unwrap(); + let source = ex.register_source(r).unwrap(); + let pending = source.wait_readable().unwrap(); match futures::future::select(pending, done).await { Either::Right((5, pending)) => std::mem::drop(pending), _ => panic!("unexpected select result"), @@ -520,7 +570,8 @@ mod test { let ex = FdExecutor::new().unwrap(); - let op = ex.wait_writable(&tx).unwrap(); + let source = ex.register_source(tx.try_clone().unwrap()).unwrap(); + let op = source.wait_writable().unwrap(); ex.spawn_local(write_value(tx)).detach(); ex.spawn_local(check_op(op)).detach(); diff --git a/cros_async/src/lib.rs b/cros_async/src/lib.rs index 4188779ae..271115bd4 100644 --- a/cros_async/src/lib.rs +++ b/cros_async/src/lib.rs @@ -67,6 +67,7 @@ pub mod mem; mod poll_source; mod queue; mod select; +pub mod sync; mod timer; mod uring_executor; mod uring_source; diff --git a/cros_async/src/poll_source.rs b/cros_async/src/poll_source.rs index 748da78da..7fc674c62 100644 --- a/cros_async/src/poll_source.rs +++ 
b/cros_async/src/poll_source.rs @@ -10,11 +10,9 @@ use std::ops::{Deref, DerefMut}; use std::os::unix::io::AsRawFd; use std::sync::Arc; -use libc::O_NONBLOCK; -use sys_util::{self, add_fd_flags}; use thiserror::Error as ThisError; -use crate::fd_executor::{self, FdExecutor}; +use crate::fd_executor::{self, FdExecutor, RegisteredSource}; use crate::mem::{BackingMemory, MemRegion}; use crate::{AsyncError, AsyncResult}; use crate::{IoSourceExt, ReadAsync, WriteAsync}; @@ -35,15 +33,11 @@ pub enum Error { #[error("An error occurred when executing fsync synchronously: {0}")] Fsync(sys_util::Error), /// An error occurred when reading the FD. - /// #[error("An error occurred when reading the FD: {0}.")] Read(sys_util::Error), /// Can't seek file. #[error("An error occurred when seeking the FD: {0}.")] Seeking(sys_util::Error), - /// An error occurred when setting the FD non-blocking. - #[error("An error occurred setting the FD non-blocking: {0}.")] - SettingNonBlocking(sys_util::Error), /// An error occurred when writing the FD. #[error("An error occurred when writing the FD: {0}.")] Write(sys_util::Error), @@ -52,25 +46,19 @@ pub type Result<T> = std::result::Result<T, Error>; /// Async wrapper for an IO source that uses the FD executor to drive async operations. /// Used by `IoSourceExt::new` when uring isn't available. -pub struct PollSource<F: AsRawFd> { - source: F, - ex: FdExecutor, -} +pub struct PollSource<F>(RegisteredSource<F>); impl<F: AsRawFd> PollSource<F> { /// Create a new `PollSource` from the given IO source. pub fn new(f: F, ex: &FdExecutor) -> Result<Self> { - let fd = f.as_raw_fd(); - add_fd_flags(fd, O_NONBLOCK).map_err(Error::SettingNonBlocking)?; - Ok(Self { - source: f, - ex: ex.clone(), - }) + ex.register_source(f) + .map(PollSource) + .map_err(Error::Executor) } /// Return the inner source. 
pub fn into_source(self) -> F { - self.source + self.0.into_source() } } @@ -78,7 +66,13 @@ impl<F: AsRawFd> Deref for PollSource<F> { type Target = F; fn deref(&self) -> &Self::Target { - &self.source + self.0.as_ref() + } +} + +impl<F: AsRawFd> DerefMut for PollSource<F> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut() } } @@ -94,7 +88,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { // Safe because this will only modify `vec` and we check the return value. let res = unsafe { libc::pread64( - self.source.as_raw_fd(), + self.as_raw_fd(), vec.as_mut_ptr() as *mut libc::c_void, vec.len(), file_offset as libc::off64_t, @@ -107,10 +101,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { match sys_util::Error::last() { e if e.errno() == libc::EWOULDBLOCK => { - let op = self - .ex - .wait_readable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; } e => return Err(Error::Read(e).into()), @@ -135,7 +126,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { // guaranteed to be valid from the pointer by io_slice_mut. let res = unsafe { libc::preadv64( - self.source.as_raw_fd(), + self.as_raw_fd(), iovecs.as_mut_ptr() as *mut _, iovecs.len() as i32, file_offset as libc::off64_t, @@ -148,10 +139,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { match sys_util::Error::last() { e if e.errno() == libc::EWOULDBLOCK => { - let op = self - .ex - .wait_readable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; } e => return Err(Error::Read(e).into()), @@ -161,10 +149,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { /// Wait for the FD of `self` to be readable. 
async fn wait_readable(&self) -> AsyncResult<()> { - let op = self - .ex - .wait_readable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; Ok(()) } @@ -175,7 +160,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { // Safe because this will only modify `buf` and we check the return value. let res = unsafe { libc::read( - self.source.as_raw_fd(), + self.as_raw_fd(), buf.as_mut_ptr() as *mut libc::c_void, buf.len(), ) @@ -187,10 +172,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> { match sys_util::Error::last() { e if e.errno() == libc::EWOULDBLOCK => { - let op = self - .ex - .wait_readable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_readable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; } e => return Err(Error::Read(e).into()), @@ -211,7 +193,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { // Safe because this will not modify any memory and we check the return value. let res = unsafe { libc::pwrite64( - self.source.as_raw_fd(), + self.as_raw_fd(), vec.as_ptr() as *const libc::c_void, vec.len(), file_offset as libc::off64_t, @@ -224,10 +206,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { match sys_util::Error::last() { e if e.errno() == libc::EWOULDBLOCK => { - let op = self - .ex - .wait_writable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_writable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; } e => return Err(Error::Write(e).into()), @@ -253,7 +232,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { // guaranteed to be valid from the pointer by io_slice_mut. 
let res = unsafe { libc::pwritev64( - self.source.as_raw_fd(), + self.as_raw_fd(), iovecs.as_ptr() as *mut _, iovecs.len() as i32, file_offset as libc::off64_t, @@ -266,10 +245,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { match sys_util::Error::last() { e if e.errno() == libc::EWOULDBLOCK => { - let op = self - .ex - .wait_writable(&self.source) - .map_err(Error::AddingWaker)?; + let op = self.0.wait_writable().map_err(Error::AddingWaker)?; op.await.map_err(Error::Executor)?; } e => return Err(Error::Write(e).into()), @@ -281,7 +257,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { async fn fallocate(&self, file_offset: u64, len: u64, mode: u32) -> AsyncResult<()> { let ret = unsafe { libc::fallocate64( - self.source.as_raw_fd(), + self.as_raw_fd(), mode as libc::c_int, file_offset as libc::off64_t, len as libc::off64_t, @@ -296,7 +272,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { /// Sync all completed write operations to the backing storage. async fn fsync(&self) -> AsyncResult<()> { - let ret = unsafe { libc::fsync(self.source.as_raw_fd()) }; + let ret = unsafe { libc::fsync(self.as_raw_fd()) }; if ret == 0 { Ok(()) } else { @@ -309,23 +285,17 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> { impl<F: AsRawFd> IoSourceExt<F> for PollSource<F> { /// Yields the underlying IO source. fn into_source(self: Box<Self>) -> F { - self.source + self.0.into_source() } /// Provides a mutable ref to the underlying IO source. fn as_source_mut(&mut self) -> &mut F { - &mut self.source + self } /// Provides a ref to the underlying IO source. fn as_source(&self) -> &F { - &self.source - } -} - -impl<F: AsRawFd> DerefMut for PollSource<F> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.source + self } } @@ -393,4 +363,23 @@ mod tests { let ex = FdExecutor::new().unwrap(); ex.run_until(go(&ex)).unwrap(); } + + #[test] + fn memory_leak() { + // This test needs to run under ASAN to detect memory leaks. 
+ + async fn owns_poll_source(source: PollSource<File>) { + let _ = source.wait_readable().await; + } + + let (rx, _tx) = sys_util::pipe(true).unwrap(); + let ex = FdExecutor::new().unwrap(); + let source = PollSource::new(rx, &ex).unwrap(); + ex.spawn_local(owns_poll_source(source)).detach(); + + // Drop `ex` without running. This would cause a memory leak if PollSource owned a strong + // reference to the executor because it owns a reference to the future that owns PollSource + // (via its Runnable). The strong reference prevents the drop impl from running, which would + // otherwise poll the future and have it return with an error. + } } diff --git a/cros_async/src/queue.rs b/cros_async/src/queue.rs index dc7bb2d87..3192703a5 100644 --- a/cros_async/src/queue.rs +++ b/cros_async/src/queue.rs @@ -3,50 +3,32 @@ // found in the LICENSE file. use std::collections::VecDeque; -use std::mem; -use std::task::Waker; use async_task::Runnable; use sync::Mutex; -struct Inner { - runnables: VecDeque<Runnable>, - waker: Option<Waker>, -} - /// A queue of `Runnables`. Intended to be used by executors to keep track of futures that have been /// scheduled to run. pub struct RunnableQueue { - inner: Mutex<Inner>, + runnables: Mutex<VecDeque<Runnable>>, } impl RunnableQueue { /// Create a new, empty `RunnableQueue`. pub fn new() -> RunnableQueue { RunnableQueue { - inner: Mutex::new(Inner { - runnables: VecDeque::new(), - waker: None, - }), + runnables: Mutex::new(VecDeque::new()), } } - /// Schedule `runnable` to run in the future by adding it to this `RunnableQueue`. Also wakes up - /// the waker associated with this `RunnableQueue`, if any. - pub fn schedule(&self, runnable: Runnable) { - let mut inner = self.inner.lock(); - inner.runnables.push_back(runnable); - let waker = inner.waker.take(); - mem::drop(inner); - - if let Some(w) = waker { - w.wake(); - } + /// Schedule `runnable` to run in the future by adding it to this `RunnableQueue`. 
+ pub fn push_back(&self, runnable: Runnable) { + self.runnables.lock().push_back(runnable); } /// Remove and return the first `Runnable` in this `RunnableQueue` or `None` if it is empty. pub fn pop_front(&self) -> Option<Runnable> { - self.inner.lock().runnables.pop_front() + self.runnables.lock().pop_front() } /// Create an iterator over this `RunnableQueue` that repeatedly calls `pop_front()` until it is @@ -54,12 +36,6 @@ impl RunnableQueue { pub fn iter(&self) -> RunnableQueueIter { self.into_iter() } - - /// Associate `waker` with this `RunnableQueue`. `waker` will be woken up the next time - /// `schedule()` is called. - pub fn set_waker(&self, waker: Waker) { - self.inner.lock().waker = Some(waker); - } } impl<'q> IntoIterator for &'q RunnableQueue { diff --git a/cros_async/src/sync.rs b/cros_async/src/sync.rs new file mode 100644 index 000000000..80dec3858 --- /dev/null +++ b/cros_async/src/sync.rs @@ -0,0 +1,14 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod blocking; +mod cv; +mod mu; +mod spin; +mod waiter; + +pub use blocking::block_on; +pub use cv::Condvar; +pub use mu::Mutex; +pub use spin::SpinLock; diff --git a/cros_async/src/sync/blocking.rs b/cros_async/src/sync/blocking.rs new file mode 100644 index 000000000..ce02e4dd7 --- /dev/null +++ b/cros_async/src/sync/blocking.rs @@ -0,0 +1,192 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::future::Future; +use std::ptr; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use futures::pin_mut; +use futures::task::{waker_ref, ArcWake}; + +// Randomly generated values to indicate the state of the current thread. 
+const WAITING: i32 = 0x25de_74d1; +const WOKEN: i32 = 0x72d3_2c9f; + +const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG; +const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG; + +thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING)))); + +#[repr(transparent)] +struct Waker(AtomicI32); + +extern "C" { + #[cfg_attr(target_os = "android", link_name = "__errno")] + #[cfg_attr(target_os = "linux", link_name = "__errno_location")] + fn errno_location() -> *mut libc::c_int; +} + +impl ArcWake for Waker { + fn wake_by_ref(arc_self: &Arc<Self>) { + let state = arc_self.0.swap(WOKEN, Ordering::Release); + if state == WAITING { + // The thread hasn't already been woken up so wake it up now. Safe because this doesn't + // modify any memory and we check the return value. + let res = unsafe { + libc::syscall( + libc::SYS_futex, + &arc_self.0, + FUTEX_WAKE_PRIVATE, + libc::INT_MAX, // val + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0_i32, // val3 + ) + }; + if res < 0 { + panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe { + *errno_location() + }); + } + } + } +} + +/// Run a future to completion on the current thread. +/// +/// This method will block the current thread until `f` completes. Useful when you need to call an +/// async fn from a non-async context. +pub fn block_on<F: Future>(f: F) -> F::Output { + pin_mut!(f); + + PER_THREAD_WAKER.with(|thread_waker| { + let waker = waker_ref(thread_waker); + let mut cx = Context::from_waker(&waker); + + loop { + if let Poll::Ready(t) = f.as_mut().poll(&mut cx) { + return t; + } + + let state = thread_waker.0.swap(WAITING, Ordering::Acquire); + if state == WAITING { + // If we weren't already woken up then wait until we are. Safe because this doesn't + // modify any memory and we check the return value. 
+ let res = unsafe { + libc::syscall( + libc::SYS_futex, + &thread_waker.0, + FUTEX_WAIT_PRIVATE, + state, + ptr::null() as *const libc::timespec, // timeout + ptr::null() as *const libc::c_int, // uaddr2 + 0_i32, // val3 + ) + }; + + if res < 0 { + // Safe because libc guarantees that this is a valid pointer. + match unsafe { *errno_location() } { + libc::EAGAIN | libc::EINTR => {} + e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e), + } + } + + // Clear the state to prevent unnecessary extra loop iterations and also to allow + // nested usage of `block_on`. + thread_waker.0.store(WAITING, Ordering::Release); + } + } + }) +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::pin::Pin; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll, Waker}; + use std::thread; + use std::time::Duration; + + use crate::sync::SpinLock; + + struct TimerState { + fired: bool, + waker: Option<Waker>, + } + struct Timer { + state: Arc<SpinLock<TimerState>>, + } + + impl Future for Timer { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + let mut state = self.state.lock(); + if state.fired { + return Poll::Ready(()); + } + + state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + + fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer { + let state = Arc::new(SpinLock::new(TimerState { + fired: false, + waker: None, + })); + + let thread_state = Arc::clone(&state); + thread::spawn(move || { + thread::sleep(dur); + let mut ts = thread_state.lock(); + ts.fired = true; + if let Some(waker) = ts.waker.take() { + waker.wake(); + } + drop(ts); + + if let Some(tx) = notify { + tx.send(()).expect("Failed to send completion notification"); + } + }); + + Timer { state } + } + + #[test] + fn it_works() { + block_on(start_timer(Duration::from_millis(100), None)); + } + + #[test] + fn nested() { + async fn inner() { + 
block_on(start_timer(Duration::from_millis(100), None)); + } + + block_on(inner()); + } + + #[test] + fn ready_before_poll() { + let (tx, rx) = channel(); + + let timer = start_timer(Duration::from_millis(50), Some(tx)); + + rx.recv() + .expect("Failed to receive completion notification"); + + // We know the timer has already fired so the poll should complete immediately. + block_on(timer); + } +} diff --git a/cros_async/src/sync/cv.rs b/cros_async/src/sync/cv.rs new file mode 100644 index 000000000..4da0a12ef --- /dev/null +++ b/cros_async/src/sync/cv.rs @@ -0,0 +1,1159 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::mem; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; +use std::sync::Arc; + +use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex}; +use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; + +const SPINLOCK: usize = 1 << 0; +const HAS_WAITERS: usize = 1 << 1; + +/// A primitive to wait for an event to occur without consuming CPU time. +/// +/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some +/// condition to become true. The condition must always be verified while holding the `Mutex` lock. +/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on +/// the `Condvar`. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use cros_async::sync::{block_on, Condvar, Mutex}; +/// +/// const N: usize = 13; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let all threads waiting on the Condvar know once the increments are done. 
+/// let data = Arc::new(Mutex::new(0)); +/// let cv = Arc::new(Condvar::new()); +/// +/// for _ in 0..N { +/// let (data, cv) = (data.clone(), cv.clone()); +/// thread::spawn(move || { +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// cv.notify_all(); +/// } +/// }); +/// } +/// +/// let mut val = block_on(data.lock()); +/// while *val != N { +/// val = block_on(cv.wait(val)); +/// } +/// ``` +#[repr(align(128))] +pub struct Condvar { + state: AtomicUsize, + waiters: UnsafeCell<WaiterList>, + mu: UnsafeCell<usize>, +} + +impl Condvar { + /// Creates a new condition variable ready to be waited on and notified. + pub fn new() -> Condvar { + Condvar { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + mu: UnsafeCell::new(0), + } + } + + /// Block the current thread until this `Condvar` is notified by another thread. + /// + /// This method will atomically unlock the `Mutex` held by `guard` and then block the current + /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up + /// the thread. + /// + /// To allow for more efficient scheduling, this call may return even when the programmer + /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a + /// loop that checks the predicate before continuing. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Condvar` is notified. + /// + /// # Panics + /// + /// This method will panic if used with more than one `Mutex` at the same time. 
+ /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # use std::thread; + /// + /// # use cros_async::sync::{block_on, Condvar, Mutex}; + /// + /// # let mu = Arc::new(Mutex::new(false)); + /// # let cv = Arc::new(Condvar::new()); + /// # let (mu2, cv2) = (mu.clone(), cv.clone()); + /// + /// # let t = thread::spawn(move || { + /// # *block_on(mu2.lock()) = true; + /// # cv2.notify_all(); + /// # }); + /// + /// let mut ready = block_on(mu.lock()); + /// while !*ready { + /// ready = block_on(cv.wait(ready)); + /// } + /// + /// # t.join().expect("failed to join thread"); + /// ``` + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. + mu.lock_from_cv().await + } + + /// Like `wait()` but takes and returns a `MutexReadGuard` instead. + // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code + // that doesn't compile. + #[allow(clippy::needless_lifetimes)] + pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> { + let waiter = Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + self as *const Condvar as usize, + WaitingFor::Condvar, + )); + + self.add_waiter(waiter.clone(), guard.as_raw_mutex()); + + // Get a reference to the mutex and then drop the lock. + let mu = guard.into_inner(); + + // Wait to be woken up. + waiter.wait().await; + + // Now re-acquire the lock. 
+ mu.read_lock_from_cv().await + } + + fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) { + // Acquire the spin lock. + let mut oldstate = self.state.load(Ordering::Relaxed); + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | HAS_WAITERS, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let mu = unsafe { &mut *self.mu.get() }; + let muptr = raw_mutex as *const RawMutex as usize; + + match *mu { + 0 => *mu = muptr, + p if p == muptr => {} + _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"), + } + + // Safe because the spin lock guarantees exclusive access. + unsafe { (*self.waiters.get()).push_back(waiter) }; + + // Release the spin lock. Use a direct store here because no other thread can modify + // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier + // because we just added a waiter. + self.state.store(HAS_WAITERS, Ordering::Release); + } + + /// Notify at most one thread currently waiting on the `Condvar`. + /// + /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to + /// `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_one(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. 
+ return; + } + + while (oldstate & SPINLOCK) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock guarantees exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let wake_list = get_wake_list(waiters); + + let newstate = if waiters.is_empty() { + // Also clear the mutex associated with this Condvar since there are no longer any + // waiters. Safe because the spin lock guarantees exclusive access. + unsafe { *self.mu.get() = 0 }; + + // We are releasing the spin lock and there are no more waiters so we can clear all bits + // in `self.state`. + 0 + } else { + // There are still waiters so we need to keep the HAS_WAITERS bit in the state. + HAS_WAITERS + }; + + // Release the spin lock. + self.state.store(newstate, Ordering::Release); + + // Now wake any waiters in the wake list. + for w in wake_list { + w.wake(); + } + } + + /// Notify all threads currently waiting on the `Condvar`. + /// + /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`. + /// + /// Unlike more traditional condition variable interfaces, this method requires a reference to + /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call + /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking + /// a reference to the `Mutex` here allows us to make some optimizations that can improve + /// performance by reducing unnecessary wakeups. + pub fn notify_all(&self) { + let mut oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 { + // No waiters. 
+ return;
+ }
+
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access to `self.waiters`.
+ let wake_list = unsafe { (*self.waiters.get()).take() };
+
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ // Mark any waiters left as no longer waiting for the Condvar.
+ for w in &wake_list {
+ w.set_waiting_for(WaitingFor::None);
+ }
+
+ // Release the spin lock. We can clear all bits in the state since we took all the waiters.
+ self.state.store(0, Ordering::Release);
+
+ // Now wake any waiters in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+ }
+
+ fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while oldstate & SPINLOCK != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock provides exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let waiting_for = waiter.is_waiting_for();
+ // Don't drop the old waiter now as we're still holding the spin lock.
+ let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
+ // Safe because we know that the waiter is still linked and is waiting for the Condvar,
+ // which guarantees that it is still in `self.waiters`.
+ let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
+ cursor.remove()
+ } else {
+ None
+ };
+
+ let wake_list = if wake_next || waiting_for == WaitingFor::None {
+ // Either the waiter was already woken or it's been removed from the condvar's waiter
+ // list and is going to be woken. Either way, we need to wake up another thread.
+ get_wake_list(waiters)
+ } else {
+ WaiterList::new(WaiterAdapter::new())
+ };
+
+ let set_on_release = if waiters.is_empty() {
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ 0
+ } else {
+ HAS_WAITERS
+ };
+
+ self.state.store(set_on_release, Ordering::Release);
+
+ // Now wake any waiters still left in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+
+ mem::drop(old_waiter);
+ }
+}
+
+unsafe impl Send for Condvar {}
+unsafe impl Sync for Condvar {}
+
+impl Default for Condvar {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// Scan `waiters` and return all waiters that should be woken up.
+//
+// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
+// waiting for a shared lock are also woken up. In addition one writer is woken up, if possible.
+//
+// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
+// the rest of the list is not scanned.
+fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
+ let mut to_wake = WaiterList::new(WaiterAdapter::new());
+ let mut cursor = waiters.front_mut();
+
+ let mut waking_readers = false;
+ let mut all_readers = true;
+ while let Some(w) = cursor.get() {
+ match w.kind() {
+ WaiterKind::Exclusive if !waking_readers => {
+ // This is the first waiter and it's a writer. No need to check the other waiters.
+ // Also mark the waiter as having been removed from the Condvar's waiter list.
+ let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + debug_assert!(waking_readers); + if all_readers { + // We are waking readers but we need to ensure that at least one writer is woken + // up. Since we haven't yet woken up a writer, wake up this one. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + all_readers = false; + } else { + // We are waking readers and have already woken one writer. Skip this one. + cursor.move_next(); + } + } + } + } + + to_wake +} + +fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) { + let condvar = cv as *const Condvar; + + // Safe because the thread that owns the waiter being canceled must also own a reference to the + // Condvar, which guarantees that this pointer is valid. + unsafe { (*condvar).cancel_waiter(waiter, wake_next) } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::mem; + use std::ptr; + use std::rc::Rc; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll}; + use std::thread::{self, JoinHandle}; + use std::time::Duration; + + use futures::channel::oneshot; + use futures::task::{waker_ref, ArcWake}; + use futures::{select, FutureExt}; + use futures_executor::{LocalPool, LocalSpawner, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use crate::sync::{block_on, Mutex}; + + // Dummy waker used when we want to manually drive futures. 
+ struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + #[test] + fn smoke() { + let cv = Condvar::new(); + cv.notify_one(); + cv.notify_all(); + } + + #[test] + fn notify_one() { + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let mu2 = mu.clone(); + let cv2 = cv.clone(); + + let guard = block_on(mu.lock()); + thread::spawn(move || { + let _g = block_on(mu2.lock()); + cv2.notify_one(); + }); + + let guard = block_on(cv.wait(guard)); + mem::drop(guard); + } + + #[test] + fn multi_mutex() { + const NUM_THREADS: usize = 5; + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + + let mut threads = Vec::with_capacity(NUM_THREADS); + for _ in 0..NUM_THREADS { + let mu = mu.clone(); + let cv = cv.clone(); + + threads.push(thread::spawn(move || { + let mut ready = block_on(mu.lock()); + while !*ready { + ready = block_on(cv.wait(ready)); + } + })); + } + + let mut g = block_on(mu.lock()); + *g = true; + mem::drop(g); + cv.notify_all(); + + threads + .into_iter() + .try_for_each(JoinHandle::join) + .expect("Failed to join threads"); + + // Now use the Condvar with a different mutex. 
+ let alt_mu = Arc::new(Mutex::new(None)); + let alt_mu2 = alt_mu.clone(); + let cv2 = cv.clone(); + let handle = thread::spawn(move || { + let mut g = block_on(alt_mu2.lock()); + while g.is_none() { + g = block_on(cv2.wait(g)); + } + }); + + let mut alt_g = block_on(alt_mu.lock()); + *alt_g = Some(()); + mem::drop(alt_g); + cv.notify_all(); + + handle + .join() + .expect("Failed to join thread alternate mutex"); + } + + #[test] + fn notify_one_single_thread_async() { + async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mu2 = Rc::clone(&mu); + let cv2 = Rc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. + spawner + .spawn_local(notify(mu2, cv2)) + .expect("Failed to spawn `notify` task"); + let _g = cv.wait(g).await; + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(())); + let cv = Rc::new(Condvar::new()); + + spawner + .spawn_local(wait(mu, cv, spawner.clone())) + .expect("Failed to spawn `wait` task"); + + ex.run(); + } + + #[test] + fn notify_one_multi_thread_async() { + async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) { + let _g = mu.lock().await; + cv.notify_one(); + } + + async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) { + let mu2 = Arc::clone(&mu); + let cv2 = Arc::clone(&cv); + + let g = mu.lock().await; + // Has to be spawned _after_ acquiring the lock to prevent a race + // where the notify happens before the waiter has acquired the lock. 
+ pool.spawn_ok(notify(mu2, cv2)); + let _g = cv.wait(g).await; + + tx.send(()).expect("Failed to send completion notification"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(())); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + ex.spawn_ok(wait(mu, cv, tx, ex.clone())); + + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion notification"); + } + + #[test] + fn notify_one_with_cancel() { + const TASKS: usize = 17; + const OBSERVERS: usize = 7; + const ITERATIONS: usize = 103; + + async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + let _ = unsafe { ptr::read_volatile(&*count as *const usize) }; + } + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * OBSERVERS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_one(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn observe_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = observe(&mu, &cv).fuse() => {}, + () = observe(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! 
{ + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + for _ in 0..OBSERVERS { + ex.spawn_ok(observe_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + OBSERVERS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS) + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_all_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + *count -= 1; + } + + async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) { + for _ in 0..TASKS * ITERATIONS { + *mu.lock().await += 1; + cv.notify_all(); + } + + done.send(()).expect("Failed to send completion message"); + } + + async fn decrement_either( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + alt_mu: Arc<Mutex<usize>>, + alt_cv: Arc<Condvar>, + done: 
Sender<()>, + ) { + for _ in 0..ITERATIONS { + select! { + () = decrement(&mu, &cv).fuse() => {}, + () = decrement(&alt_mu, &alt_cv).fuse() => {}, + } + } + + done.send(()).expect("Failed to send completion message"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let cv = Arc::new(Condvar::new()); + let alt_cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(decrement_either( + Arc::clone(&mu), + Arc::clone(&cv), + Arc::clone(&alt_mu), + Arc::clone(&alt_cv), + tx.clone(), + )); + } + + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone())); + ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx)); + + for _ in 0..TASKS + 2 { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0); + } + #[test] + fn notify_all() { + const THREADS: usize = 13; + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + let (tx, rx) = channel(); + + let mut threads = Vec::with_capacity(THREADS); + for _ in 0..THREADS { + let mu2 = mu.clone(); + let cv2 = cv.clone(); + let tx2 = tx.clone(); + + threads.push(thread::spawn(move || { + let mut count = block_on(mu2.lock()); + *count += 1; + if *count == THREADS { + tx2.send(()).unwrap(); + } + + while *count != 0 { + count = block_on(cv2.wait(count)); + } + })); + } + + mem::drop(tx); + + // Wait till all threads have started. 
+ rx.recv_timeout(Duration::from_secs(5)).unwrap(); + + let mut count = block_on(mu.lock()); + *count = 0; + mem::drop(count); + cv.notify_all(); + + for t in threads { + t.join().unwrap(); + } + } + + #[test] + fn notify_all_single_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + spawner + .spawn_local(reset(mu.clone(), cv.clone())) + .expect("Failed to spawn reset task"); + } + + while *count != 0 { + count = cv.wait(count).await; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0)); + let cv = Rc::new(Condvar::new()); + + for _ in 0..TASKS { + spawner + .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone())) + .expect("Failed to spawn watcher task"); + } + + ex.run(); + } + + #[test] + fn notify_all_multi_thread_async() { + const TASKS: usize = 13; + + async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + *count = 0; + cv.notify_all(); + } + + async fn watcher( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + pool: ThreadPool, + tx: Sender<()>, + ) { + let mut count = mu.lock().await; + *count += 1; + if *count == TASKS { + pool.spawn_ok(reset(mu.clone(), cv.clone())); + } + + while *count != 0 { + count = cv.wait(count).await; + } + + tx.send(()).expect("Failed to send completion notification"); + } + + let pool = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion 
notification"); + } + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) { + let mut ready = mu.read_lock().await; + while !*ready { + ready = cv.wait_read(ready).await; + } + } + + let mu = Arc::new(Mutex::new(false)); + let cv = Arc::new(Condvar::new()); + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First have all the readers wait on the Condvar. + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + // Now make the condition true and notify the condvar. Even though we will call notify_one, + // all the readers should be woken up. + *block_on(mu.lock()) = true; + cv.notify_one(); + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + // All readers should now be able to complete. 
+ for r in &mut readers { + if r.as_mut().poll(&mut cx).is_pending() { + panic!("reader unable to complete"); + } + } + } + + #[test] + fn cancel_before_notify() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + // Drop fut1 before notifying the cv. + mem::drop(fut1); + cv.notify_one(); + + // fut2 should now be ready to complete. 
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify_one() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + *block_on(mu.lock()) = 2; + cv.notify_one(); + + // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2. 
+ mem::drop(fut1); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn cancel_after_notify_all() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut fut1 = Box::pin(dec(mu.clone(), cv.clone())); + let mut fut2 = Box::pin(dec(mu.clone(), cv.clone())); + + if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS); + + let mut count = block_on(mu.lock()); + *count = 2; + + // Notify the cv while holding the lock. This should wake up both waiters. + cv.notify_all(); + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + + mem::drop(count); + + mem::drop(fut1); + + if fut2.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 1); + } + + #[test] + fn timed_wait() { + async fn wait_deadline( + mu: Arc<Mutex<usize>>, + cv: Arc<Condvar>, + timeout: oneshot::Receiver<()>, + ) { + let mut count = mu.lock().await; + + if *count == 0 { + let mut rx = timeout.fuse(); + + while *count == 0 { + select! 
{ + res = rx => { + if let Err(e) = res { + panic!("Error while receiving timeout notification: {}", e); + } + + return; + }, + c = cv.wait(count).fuse() => count = c, + } + } + } + + *count += 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let (tx, rx) = oneshot::channel(); + let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx)); + + if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) { + panic!("wait_deadline unexpectedly ready"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS); + + // Signal the channel, which should cancel the wait. + tx.send(()).expect("Failed to send wakeup"); + + // Wait for the timer to run out. + if wait.as_mut().poll(&mut cx).is_pending() { + panic!("wait_deadline unable to complete in time"); + } + + assert_eq!(cv.state.load(Ordering::Relaxed), 0); + assert_eq!(*block_on(mu.lock()), 0); + } +} diff --git a/cros_async/src/sync/mu.rs b/cros_async/src/sync/mu.rs new file mode 100644 index 000000000..4f3443fb5 --- /dev/null +++ b/cros_async/src/sync/mu.rs @@ -0,0 +1,2289 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::UnsafeCell; +use std::mem; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::thread::yield_now; + +use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor}; + +// Set when the mutex is exclusively locked. +const LOCKED: usize = 1 << 0; +// Set when there are one or more threads waiting to acquire the lock. +const HAS_WAITERS: usize = 1 << 1; +// Set when a thread has been woken up from the wait queue. 
Cleared when that thread either acquires
+// the lock or adds itself back into the wait queue. Used to prevent unnecessary wake ups when a
+// thread has been removed from the wait queue but has not gotten CPU time yet.
+const DESIGNATED_WAKER: usize = 1 << 2;
+// Used to provide exclusive access to the `waiters` field in `Mutex`. Should only be held while
+// modifying the waiter list.
+const SPINLOCK: usize = 1 << 3;
+// Set when a thread that wants an exclusive lock adds itself to the wait queue. New threads
+// attempting to acquire a shared lock will be prevented from getting it when this bit is set.
+// However, this bit is ignored once a thread has gone through the wait queue at least once.
+const WRITER_WAITING: usize = 1 << 4;
+// Set when a thread has gone through the wait queue many times but has failed to acquire the lock
+// every time it is woken up. When this bit is set, all other threads are prevented from acquiring
+// the lock until the thread that set the `LONG_WAIT` bit has acquired the lock.
+const LONG_WAIT: usize = 1 << 5;
+// The bit that is added to the mutex state in order to acquire a shared lock. Since more than one
+// thread can acquire a shared lock, we cannot use a single bit. Instead we use all the remaining
+// bits in the state to track the number of threads that have acquired a shared lock.
+const READ_LOCK: usize = 1 << 8;
+// Mask used for checking if any threads currently hold a shared lock.
+const READ_MASK: usize = !0xff;
+
+// The number of times the thread should just spin and attempt to re-acquire the lock.
+const SPIN_THRESHOLD: usize = 7;
+
+// The number of times the thread needs to go through the wait queue before it sets the `LONG_WAIT`
+// bit and forces all other threads to wait for it to acquire the lock. This value is set relatively
+// high so that we don't lose the benefit of having running threads unless it is absolutely
+// necessary.
+const LONG_WAIT_THRESHOLD: usize = 19; + +// Common methods between shared and exclusive locks. +trait Kind { + // The bits that must be zero for the thread to acquire this kind of lock. If any of these bits + // are not zero then the thread will first spin and retry a few times before adding itself to + // the wait queue. + fn zero_to_acquire() -> usize; + + // The bit that must be added in order to acquire this kind of lock. This should either be + // `LOCKED` or `READ_LOCK`. + fn add_to_acquire() -> usize; + + // The bits that should be set when a thread adds itself to the wait queue while waiting to + // acquire this kind of lock. + fn set_when_waiting() -> usize; + + // The bits that should be cleared when a thread acquires this kind of lock. + fn clear_on_acquire() -> usize; + + // The waiter that a thread should use when waiting to acquire this kind of lock. + fn new_waiter(raw: &RawMutex) -> Arc<Waiter>; +} + +// A lock type for shared read-only access to the data. More than one thread may hold this kind of +// lock simultaneously. +struct Shared; + +impl Kind for Shared { + fn zero_to_acquire() -> usize { + LOCKED | WRITER_WAITING | LONG_WAIT + } + + fn add_to_acquire() -> usize { + READ_LOCK + } + + fn set_when_waiting() -> usize { + 0 + } + + fn clear_on_acquire() -> usize { + 0 + } + + fn new_waiter(raw: &RawMutex) -> Arc<Waiter> { + Arc::new(Waiter::new( + WaiterKind::Shared, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// A lock type for mutually exclusive read-write access to the data. Only one thread can hold this +// kind of lock at a time. 
+struct Exclusive; + +impl Kind for Exclusive { + fn zero_to_acquire() -> usize { + LOCKED | READ_MASK | LONG_WAIT + } + + fn add_to_acquire() -> usize { + LOCKED + } + + fn set_when_waiting() -> usize { + WRITER_WAITING + } + + fn clear_on_acquire() -> usize { + WRITER_WAITING + } + + fn new_waiter(raw: &RawMutex) -> Arc<Waiter> { + Arc::new(Waiter::new( + WaiterKind::Exclusive, + cancel_waiter, + raw as *const RawMutex as usize, + WaitingFor::Mutex, + )) + } +} + +// Scan `waiters` and return the ones that should be woken up. Also returns any bits that should be +// set in the mutex state when the current thread releases the spin lock protecting the waiter list. +// +// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are +// waiting for a shared lock are also woken up. If any waiters waiting for an exclusive lock are +// found when iterating through the list, then the returned `usize` contains the `WRITER_WAITING` +// bit, which should be set when the thread releases the spin lock. +// +// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and +// no bits are set in the returned `usize`. +fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, usize) { + let mut to_wake = WaiterList::new(WaiterAdapter::new()); + let mut set_on_release = 0; + let mut cursor = waiters.front_mut(); + + let mut waking_readers = false; + while let Some(w) = cursor.get() { + match w.kind() { + WaiterKind::Exclusive if !waking_readers => { + // This is the first waiter and it's a writer. No need to check the other waiters. + let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + break; + } + + WaiterKind::Shared => { + // This is a reader and the first waiter in the list was not a writer so wake up all + // the readers in the wait list. 
+ let waiter = cursor.remove().unwrap(); + waiter.set_waiting_for(WaitingFor::None); + to_wake.push_back(waiter); + waking_readers = true; + } + + WaiterKind::Exclusive => { + // We found a writer while looking for more readers to wake up. Set the + // `WRITER_WAITING` bit to prevent any new readers from acquiring the lock. All + // readers currently in the wait list will ignore this bit since they already waited + // once. + set_on_release |= WRITER_WAITING; + cursor.move_next(); + } + } + } + + (to_wake, set_on_release) +} + +#[inline] +fn cpu_relax(iterations: usize) { + for _ in 0..iterations { + spin_loop_hint(); + } +} + +pub(crate) struct RawMutex { + state: AtomicUsize, + waiters: UnsafeCell<WaiterList>, +} + +impl RawMutex { + pub fn new() -> RawMutex { + RawMutex { + state: AtomicUsize::new(0), + waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())), + } + } + + #[inline] + pub async fn lock(&self) { + match self + .state + .compare_exchange_weak(0, LOCKED, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + // If any bits that should be zero are not zero or if we fail to acquire the lock + // with a single compare_exchange then go through the slow path. 
+ if (oldstate & Exclusive::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Exclusive::add_to_acquire()) + & !Exclusive::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::<Exclusive>(0, 0).await; + } + } + } + } + + #[inline] + pub async fn read_lock(&self) { + match self + .state + .compare_exchange_weak(0, READ_LOCK, Ordering::Acquire, Ordering::Relaxed) + { + Ok(_) => {} + Err(oldstate) => { + if (oldstate & Shared::zero_to_acquire()) != 0 + || self + .state + .compare_exchange_weak( + oldstate, + (oldstate + Shared::add_to_acquire()) & !Shared::clear_on_acquire(), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + self.lock_slow::<Shared>(0, 0).await; + } + } + } + } + + // Slow path for acquiring the lock. `clear` should contain any bits that need to be cleared + // when the lock is acquired. Any bits set in `zero_mask` are cleared from the bits returned by + // `K::zero_to_acquire()`. + #[cold] + async fn lock_slow<K: Kind>(&self, mut clear: usize, zero_mask: usize) { + let mut zero_to_acquire = K::zero_to_acquire() & !zero_mask; + + let mut spin_count = 0; + let mut wait_count = 0; + let mut waiter = None; + loop { + let oldstate = self.state.load(Ordering::Relaxed); + // If all the bits in `zero_to_acquire` are actually zero then try to acquire the lock + // directly. + if (oldstate & zero_to_acquire) == 0 { + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate + K::add_to_acquire()) & !(clear | K::clear_on_acquire()), + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + return; + } + } else if (oldstate & SPINLOCK) == 0 { + // The mutex is locked and the spin lock is available. Try to add this thread + // to the waiter queue. 
+ let w = waiter.get_or_insert_with(|| K::new_waiter(self)); + w.reset(WaitingFor::Mutex); + + if self + .state + .compare_exchange_weak( + oldstate, + (oldstate | SPINLOCK | HAS_WAITERS | K::set_when_waiting()) & !clear, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + let mut set_on_release = 0; + + // Safe because we have acquired the spin lock and it provides exclusive + // access to the waiter queue. + if wait_count < LONG_WAIT_THRESHOLD { + // Add the waiter to the back of the queue. + unsafe { (*self.waiters.get()).push_back(w.clone()) }; + } else { + // This waiter has gone through the queue too many times. Put it in the + // front of the queue and block all other threads from acquiring the lock + // until this one has acquired it at least once. + unsafe { (*self.waiters.get()).push_front(w.clone()) }; + + // Set the LONG_WAIT bit to prevent all other threads from acquiring the + // lock. + set_on_release |= LONG_WAIT; + + // Make sure we clear the LONG_WAIT bit when we do finally get the lock. + clear |= LONG_WAIT; + + // Since we set the LONG_WAIT bit we shouldn't allow that bit to prevent us + // from acquiring the lock. + zero_to_acquire &= !LONG_WAIT; + } + + // Release the spin lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !SPINLOCK, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wait until we are woken. + w.wait().await; + + // The `DESIGNATED_WAKER` bit gets set when this thread is woken up by the + // thread that originally held the lock. While this bit is set, no other waiters + // will be woken up so it's important to clear it the next time we try to + // acquire the main lock or the spin lock. + clear |= DESIGNATED_WAKER; + + // Now that the thread has waited once, we no longer care if there is a writer + // waiting. Only the limits of mutual exclusion can prevent us from acquiring + // the lock. 
+ zero_to_acquire &= !WRITER_WAITING; + + // Reset the spin count since we just went through the wait queue. + spin_count = 0; + + // Increment the wait count since we went through the wait queue. + wait_count += 1; + + // Skip the `cpu_relax` below. + continue; + } + } + + // Both the lock and the spin lock are held by one or more other threads. First, we'll + // spin a few times in case we can acquire the lock or the spin lock. If that fails then + // we yield because we might be preventing the threads that do hold the 2 locks from + // getting cpu time. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + #[inline] + pub fn unlock(&self) { + // Fast path, if possible. We can directly clear the locked bit since we have exclusive + // access to the mutex. + let oldstate = self.state.fetch_sub(LOCKED, Ordering::Release); + + // Panic if we just tried to unlock a mutex that wasn't held by this thread. This shouldn't + // really be possible since `unlock` is not a public method. + debug_assert_eq!( + oldstate & READ_MASK, + 0, + "`unlock` called on mutex held in read-mode" + ); + debug_assert_ne!( + oldstate & LOCKED, + 0, + "`unlock` called on mutex not held in write-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 && (oldstate & DESIGNATED_WAKER) == 0 { + // The oldstate has waiters but no designated waker has been chosen yet. + self.unlock_slow(); + } + } + + #[inline] + pub fn read_unlock(&self) { + // Fast path, if possible. We can directly subtract the READ_LOCK bit since we had + // previously added it. 
+ let oldstate = self.state.fetch_sub(READ_LOCK, Ordering::Release); + + debug_assert_eq!( + oldstate & LOCKED, + 0, + "`read_unlock` called on mutex held in write-mode" + ); + debug_assert_ne!( + oldstate & READ_MASK, + 0, + "`read_unlock` called on mutex not held in read-mode" + ); + + if (oldstate & HAS_WAITERS) != 0 + && (oldstate & DESIGNATED_WAKER) == 0 + && (oldstate & READ_MASK) == READ_LOCK + { + // There are waiters, no designated waker has been chosen yet, and the last reader is + // unlocking so we have to take the slow path. + self.unlock_slow(); + } + } + + #[cold] + fn unlock_slow(&self) { + let mut spin_count = 0; + + loop { + let oldstate = self.state.load(Ordering::Relaxed); + if (oldstate & HAS_WAITERS) == 0 || (oldstate & DESIGNATED_WAKER) != 0 { + // No more waiters or a designated waker has been chosen. Nothing left for us to do. + return; + } else if (oldstate & SPINLOCK) == 0 { + // The spin lock is not held by another thread. Try to acquire it. Also set the + // `DESIGNATED_WAKER` bit since we are likely going to wake up one or more threads. + if self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK | DESIGNATED_WAKER, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_ok() + { + // Acquired the spinlock. Try to wake a waiter. We may also end up wanting to + // clear the HAS_WAITER and DESIGNATED_WAKER bits so start collecting the bits + // to be cleared. + let mut clear = SPINLOCK; + + // Safe because the spinlock guarantees exclusive access to the waiter list and + // the reference does not escape this function. + let waiters = unsafe { &mut *self.waiters.get() }; + let (wake_list, set_on_release) = get_wake_list(waiters); + + // If the waiter list is now empty, clear the HAS_WAITERS bit. + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // Since we are not going to wake any waiters clear the DESIGNATED_WAKER bit + // that we set when we acquired the spin lock. 
+ clear |= DESIGNATED_WAKER; + } + + // Release the spin lock and clear any other bits as necessary. Also, set any + // bits returned by `get_wake_list`. For now, this is just the `WRITER_WAITING` + // bit, which needs to be set when we are waking up a bunch of readers and there + // are still writers in the wait queue. This will prevent any readers that + // aren't in `wake_list` from acquiring the read lock. + let mut state = oldstate; + loop { + match self.state.compare_exchange_weak( + state, + (state | set_on_release) & !clear, + Ordering::Release, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(w) => state = w, + } + } + + // Now wake the waiters, if any. + for w in wake_list { + w.wake(); + } + + // We're done. + return; + } + } + + // Spin and try again. It's ok to block here as we have already released the lock. + if spin_count < SPIN_THRESHOLD { + cpu_relax(1 << spin_count); + spin_count += 1; + } else { + yield_now(); + } + } + } + + fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) { + let mut oldstate = self.state.load(Ordering::Relaxed); + while oldstate & SPINLOCK != 0 + || self + .state + .compare_exchange_weak( + oldstate, + oldstate | SPINLOCK, + Ordering::Acquire, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + // Safe because the spin lock provides exclusive access and the reference does not escape + // this function. + let waiters = unsafe { &mut *self.waiters.get() }; + + let mut clear = SPINLOCK; + + // If we are about to remove the first waiter in the wait list, then clear the LONG_WAIT + // bit. Also clear the bit if we are going to be waking some other waiters. In this case the + // waiter that set the bit may have already been removed from the waiter list (and could be + // the one that is currently being dropped). 
If it is still in the waiter list then clearing + // this bit may starve it for one more iteration through the lock_slow() loop, whereas not + // clearing this bit could cause a deadlock if the waiter that set it is the one that is + // being dropped. + if wake_next + || waiters + .front() + .get() + .map(|front| std::ptr::eq(front, waiter)) + .unwrap_or(false) + { + clear |= LONG_WAIT; + } + + let waiting_for = waiter.is_waiting_for(); + + // Don't drop the old waiter while holding the spin lock. + let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Mutex { + // We know that the waiter is still linked and is waiting for the mutex, which + // guarantees that it is still linked into `self.waiters`. + let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) }; + cursor.remove() + } else { + None + }; + + let (wake_list, set_on_release) = if wake_next || waiting_for == WaitingFor::None { + // Either the waiter was already woken or it's been removed from the mutex's waiter + // list and is going to be woken. Either way, we need to wake up another thread. + get_wake_list(waiters) + } else { + (WaiterList::new(WaiterAdapter::new()), 0) + }; + + if waiters.is_empty() { + clear |= HAS_WAITERS; + } + + if wake_list.is_empty() { + // We're not waking any other threads so clear the DESIGNATED_WAKER bit. In the worst + // case this leads to an additional thread being woken up but we risk a deadlock if we + // don't clear it. + clear |= DESIGNATED_WAKER; + } + + if let WaiterKind::Exclusive = waiter.kind() { + // The waiter being dropped is a writer so clear the writer waiting bit for now. If we + // found more writers in the list while fetching waiters to wake up then this bit will + // be set again via `set_on_release`. 
+ clear |= WRITER_WAITING; + } + + while self + .state + .compare_exchange_weak( + oldstate, + (oldstate & !clear) | set_on_release, + Ordering::Release, + Ordering::Relaxed, + ) + .is_err() + { + spin_loop_hint(); + oldstate = self.state.load(Ordering::Relaxed); + } + + for w in wake_list { + w.wake(); + } + + mem::drop(old_waiter); + } +} + +unsafe impl Send for RawMutex {} +unsafe impl Sync for RawMutex {} + +fn cancel_waiter(raw: usize, waiter: &Waiter, wake_next: bool) { + let raw_mutex = raw as *const RawMutex; + + // Safe because the thread that owns the waiter that is being canceled must + // also own a reference to the mutex, which ensures that this pointer is + // valid. + unsafe { (*raw_mutex).cancel_waiter(waiter, wake_next) } +} + +/// A high-level primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike more traditional mutexes, `Mutex` can safely provide both shared, immutable access (via +/// `read_lock()`) as well as exclusive, mutable access (via `lock()`) to an underlying resource +/// with no loss of performance. +/// +/// # Poisoning +/// +/// `Mutex` does not support lock poisoning so if a thread panics while holding the lock, the +/// poisoned data will be accessible by other threads in your program. If you need to guarantee that +/// other threads cannot access poisoned data then you may wish to wrap this `Mutex` inside another +/// type that provides the poisoning feature. See the implementation of `std::sync::Mutex` for an +/// example of this. +/// +/// +/// # Fairness +/// +/// This `Mutex` implementation does not guarantee that threads will acquire the lock in the same +/// order that they call `lock()` or `read_lock()`. However it will attempt to prevent long-term +/// starvation: if a thread repeatedly fails to acquire the lock beyond a threshold then all other +/// threads will fail to acquire the lock until the starved thread has acquired it. 
+/// +/// Similarly, this `Mutex` will attempt to balance reader and writer threads: once there is a +/// writer thread waiting to acquire the lock no new reader threads will be allowed to acquire it. +/// However, any reader threads that were already waiting will still be allowed to acquire it. +/// +/// # Examples +/// +/// ```edition2018 +/// use std::sync::Arc; +/// use std::thread; +/// use std::sync::mpsc::channel; +/// +/// use cros_async::sync::{block_on, Mutex}; +/// +/// const N: usize = 10; +/// +/// // Spawn a few threads to increment a shared variable (non-atomically), and +/// // let the main thread know once all increments are done. +/// // +/// // Here we're using an Arc to share memory among threads, and the data inside +/// // the Arc is protected with a mutex. +/// let data = Arc::new(Mutex::new(0)); +/// +/// let (tx, rx) = channel(); +/// for _ in 0..N { +/// let (data, tx) = (Arc::clone(&data), tx.clone()); +/// thread::spawn(move || { +/// // The shared state can only be accessed once the lock is held. +/// // Our non-atomic increment is safe because we're the only thread +/// // which can access the shared state when the lock is held. +/// let mut data = block_on(data.lock()); +/// *data += 1; +/// if *data == N { +/// tx.send(()).unwrap(); +/// } +/// // the lock is unlocked here when `data` goes out of scope. +/// }); +/// } +/// +/// rx.recv().unwrap(); +/// ``` +#[repr(align(128))] +pub struct Mutex<T: ?Sized> { + raw: RawMutex, + value: UnsafeCell<T>, +} + +impl<T> Mutex<T> { + /// Create a new, unlocked `Mutex` ready for use. + pub fn new(v: T) -> Mutex<T> { + Mutex { + raw: RawMutex::new(), + value: UnsafeCell::new(v), + } + } + + /// Consume the `Mutex` and return the contained value. This method does not perform any locking + /// as the compiler will guarantee that there are no other references to `self` and the caller + /// owns the `Mutex`. 
+ pub fn into_inner(self) -> T { + // Don't need to acquire the lock because the compiler guarantees that there are + // no references to `self`. + self.value.into_inner() + } +} + +impl<T: ?Sized> Mutex<T> { + /// Acquires exclusive, mutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning, the current thread will be the + /// only thread with access to the resource. The `Mutex` will be released when the returned + /// `MutexGuard` is dropped. + /// + /// Calling `lock()` while holding a `MutexGuard` or a `MutexReadGuard` will cause a deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn lock(&self) -> MutexGuard<'_, T> { + self.raw.lock().await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + /// Acquires shared, immutable access to the resource protected by the `Mutex`, blocking the + /// current thread until it is able to do so. Upon returning there may be other threads that + /// also have immutable access to the resource but there will not be any threads that have + /// mutable access to the resource. When the returned `MutexReadGuard` is dropped the thread + /// releases its access to the resource. + /// + /// Calling `read_lock()` while holding a `MutexReadGuard` may deadlock. Calling `read_lock()` + /// while holding a `MutexGuard` will deadlock. + /// + /// Callers that are not in an async context may wish to use the `block_on` method to block the + /// thread until the `Mutex` is acquired. + #[inline] + pub async fn read_lock(&self) -> MutexReadGuard<'_, T> { + self.raw.read_lock().await; + + // Safe because we have shared read-only access to `self.value`. 
+ MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + // Called from `Condvar::wait` when the thread wants to reacquire the lock. + #[inline] + pub(crate) async fn lock_from_cv(&self) -> MutexGuard<'_, T> { + self.raw.lock_slow::<Exclusive>(DESIGNATED_WAKER, 0).await; + + // Safe because we have exclusive access to `self.value`. + MutexGuard { + mu: self, + value: unsafe { &mut *self.value.get() }, + } + } + + // Like `lock_from_cv` but for acquiring a shared lock. + #[inline] + pub(crate) async fn read_lock_from_cv(&self) -> MutexReadGuard<'_, T> { + // Threads that have waited in the Condvar's waiter list don't have to care if there is a + // writer waiting since they have already waited once. + self.raw + .lock_slow::<Shared>(DESIGNATED_WAKER, WRITER_WAITING) + .await; + + // Safe because we have exclusive access to `self.value`. + MutexReadGuard { + mu: self, + value: unsafe { &*self.value.get() }, + } + } + + #[inline] + fn unlock(&self) { + self.raw.unlock(); + } + + #[inline] + fn read_unlock(&self) { + self.raw.read_unlock(); + } + + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler statically guarantees that are no other references to `self`. + // This is also why we don't need to acquire the lock first. + unsafe { &mut *self.value.get() } + } +} + +unsafe impl<T: ?Sized + Send> Send for Mutex<T> {} +unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {} + +impl<T: ?Sized + Default> Default for Mutex<T> { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl<T> From<T> for Mutex<T> { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped exclusive lock" for a `Mutex`. When this structure is +/// dropped, the lock will be released. The resource protected by the `Mutex` can be accessed via +/// the `Deref` and `DerefMut` implementations of this structure. 
+pub struct MutexGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex<T>, + value: &'a mut T, +} + +impl<'a, T: ?Sized> MutexGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex<T> { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> { + fn drop(&mut self) { + self.mu.unlock() + } +} + +/// An RAII implementation of a "scoped shared lock" for a `Mutex`. When this structure is dropped, +/// the lock will be released. The resource protected by the `Mutex` can be accessed via the `Deref` +/// implementation of this structure. +pub struct MutexReadGuard<'a, T: ?Sized + 'a> { + mu: &'a Mutex<T>, + value: &'a T, +} + +impl<'a, T: ?Sized> MutexReadGuard<'a, T> { + pub(crate) fn into_inner(self) -> &'a Mutex<T> { + self.mu + } + + pub(crate) fn as_raw_mutex(&self) -> &RawMutex { + &self.mu.raw + } +} + +impl<'a, T: ?Sized> Deref for MutexReadGuard<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.value + } +} + +impl<'a, T: ?Sized> Drop for MutexReadGuard<'a, T> { + fn drop(&mut self) { + self.mu.read_unlock() + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::future::Future; + use std::mem; + use std::pin::Pin; + use std::rc::Rc; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::mpsc::{channel, Sender}; + use std::sync::Arc; + use std::task::{Context, Poll, Waker}; + use std::thread; + use std::time::Duration; + + use futures::channel::oneshot; + use futures::task::{waker_ref, ArcWake}; + use futures::{pending, select, FutureExt}; + use futures_executor::{LocalPool, ThreadPool}; + use futures_util::task::LocalSpawnExt; + + use crate::sync::{block_on, Condvar, SpinLock}; + + 
#[derive(Debug, Eq, PartialEq)] + struct NonCopy(u32); + + // Dummy waker used when we want to manually drive futures. + struct TestWaker; + impl ArcWake for TestWaker { + fn wake_by_ref(_arc_self: &Arc<Self>) {} + } + + #[test] + fn it_works() { + let mu = Mutex::new(NonCopy(13)); + + assert_eq!(*block_on(mu.lock()), NonCopy(13)); + } + + #[test] + fn smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn rw_smoke() { + let mu = Mutex::new(NonCopy(7)); + + mem::drop(block_on(mu.lock())); + mem::drop(block_on(mu.read_lock())); + mem::drop((block_on(mu.read_lock()), block_on(mu.read_lock()))); + mem::drop(block_on(mu.lock())); + } + + #[test] + fn async_smoke() { + async fn lock(mu: Rc<Mutex<NonCopy>>) { + mu.lock().await; + } + + async fn read_lock(mu: Rc<Mutex<NonCopy>>) { + mu.read_lock().await; + } + + async fn double_read_lock(mu: Rc<Mutex<NonCopy>>) { + let first = mu.read_lock().await; + mu.read_lock().await; + + // Make sure first lives past the second read lock. + first.as_raw_mutex(); + } + + let mu = Rc::new(Mutex::new(NonCopy(7))); + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(double_read_lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + spawner + .spawn_local(lock(Rc::clone(&mu))) + .expect("Failed to spawn future"); + + ex.run(); + } + + #[test] + fn send() { + let mu = Mutex::new(NonCopy(19)); + + thread::spawn(move || { + let value = block_on(mu.lock()); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_nested() { + // Tests nested mutexes and access to underlying data. 
+ let mu = Mutex::new(1); + let arc = Arc::new(Mutex::new(mu)); + thread::spawn(move || { + let nested = block_on(arc.lock()); + let lock2 = block_on(nested.lock()); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(Mutex::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc<Mutex<i32>>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *block_on(self.i.lock()) += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = block_on(arc.lock()); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]); + { + let b = &mut *block_on(mutex.lock()); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*block_on(mutex.lock()), expected); + } + #[test] + fn high_contention() { + const THREADS: usize = 17; + const ITERATIONS: usize = 103; + + let mut threads = Vec::with_capacity(THREADS); + + let mu = Arc::new(Mutex::new(0usize)); + for _ in 0..THREADS { + let mu2 = mu.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *block_on(mu2.lock()) += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*block_on(mu.read_lock()), THREADS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn high_contention_with_cancel() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + async fn increment(mu: Arc<Mutex<usize>>, alt_mu: Arc<Mutex<usize>>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + select! 
{ + mut count = mu.lock().fuse() => *count += 1, + mut count = alt_mu.lock().fuse() => *count += 1, + } + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let alt_mu = Arc::new(Mutex::new(0usize)); + + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&alt_mu), tx.clone())); + } + + for _ in 0..TASKS { + if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) { + panic!("Error while waiting for threads to complete: {}", e); + } + } + + assert_eq!( + *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()), + TASKS * ITERATIONS + ); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + assert_eq!(alt_mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn single_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. + async fn increment(mu: Rc<Mutex<usize>>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + } + + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + let mu = Rc::new(Mutex::new(0usize)); + for _ in 0..TASKS { + spawner + .spawn_local(increment(Rc::clone(&mu))) + .expect("Failed to spawn task"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn multi_thread_async() { + const TASKS: usize = 17; + const ITERATIONS: usize = 103; + + // Async closures are unstable. 
+ async fn increment(mu: Arc<Mutex<usize>>, tx: Sender<()>) { + for _ in 0..ITERATIONS { + *mu.lock().await += 1; + } + tx.send(()).expect("Failed to send completion signal"); + } + + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let mu = Arc::new(Mutex::new(0usize)); + let (tx, rx) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(increment(Arc::clone(&mu), tx.clone())); + } + + for _ in 0..TASKS { + rx.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion signal"); + } + assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn get_mut() { + let mut mu = Mutex::new(NonCopy(13)); + *mu.get_mut() = NonCopy(17); + + assert_eq!(mu.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let mu = Mutex::new(NonCopy(29)); + assert_eq!(mu.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc<AtomicUsize>); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = Arc::new(AtomicUsize::new(0)); + let needs_drop = Mutex::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn rw_arc() { + const THREADS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let mu2 = mu.clone(); + + let (tx, rx) = channel(); + thread::spawn(move || { + let mut guard = block_on(mu2.lock()); + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + tx.send(()).unwrap(); + }); + + let mut readers = Vec::with_capacity(10); + for _ in 0..THREADS { + let mu3 = mu.clone(); + let handle = thread::spawn(move || { + let guard = block_on(mu3.read_lock()); + assert!(*guard >= 0); + }); + + 
readers.push(handle); + } + + // Wait for the readers to finish their checks. + for r in readers { + r.join().expect("One or more readers saw a negative value"); + } + + // Wait for the writer to finish. + rx.recv_timeout(Duration::from_secs(5)).unwrap(); + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_single_thread_async() { + // A Future that returns `Poll::pending` the first time it is polled and `Poll::Ready` every + // time after that. + struct TestFuture { + polled: bool, + waker: Arc<SpinLock<Option<Waker>>>, + } + + impl Future for TestFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + if self.polled { + Poll::Ready(()) + } else { + self.polled = true; + *self.waker.lock() = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + fn wake_future(waker: Arc<SpinLock<Option<Waker>>>) { + loop { + if let Some(w) = waker.lock().take() { + w.wake(); + return; + } + + // This sleep cannot be moved into an else branch because we would end up holding + // the lock while sleeping due to rust's drop ordering rules. 
+ thread::sleep(Duration::from_millis(10)); + } + } + + async fn writer(mu: Rc<Mutex<isize>>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + let waker = Arc::new(SpinLock::new(None)); + let waker2 = Arc::clone(&waker); + thread::spawn(move || wake_future(waker2)); + let fut = TestFuture { + polled: false, + waker, + }; + fut.await; + *guard = tmp + 1; + } + } + + async fn reader(mu: Rc<Mutex<isize>>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Rc::new(Mutex::new(0isize)); + let mut ex = LocalPool::new(); + let spawner = ex.spawner(); + + spawner + .spawn_local(writer(Rc::clone(&mu))) + .expect("Failed to spawn writer"); + + for _ in 0..TASKS { + spawner + .spawn_local(reader(Rc::clone(&mu))) + .expect("Failed to spawn reader"); + } + + ex.run(); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn rw_multi_thread_async() { + async fn writer(mu: Arc<Mutex<isize>>, tx: Sender<()>) { + let mut guard = mu.lock().await; + for _ in 0..ITERATIONS { + let tmp = *guard; + *guard = -1; + thread::yield_now(); + *guard = tmp + 1; + } + + mem::drop(guard); + tx.send(()).unwrap(); + } + + async fn reader(mu: Arc<Mutex<isize>>, tx: Sender<()>) { + let guard = mu.read_lock().await; + assert!(*guard >= 0); + + mem::drop(guard); + tx.send(()).expect("Failed to send completion message"); + } + + const TASKS: isize = 7; + const ITERATIONS: isize = 13; + + let mu = Arc::new(Mutex::new(0isize)); + let ex = ThreadPool::new().expect("Failed to create ThreadPool"); + + let (txw, rxw) = channel(); + ex.spawn_ok(writer(Arc::clone(&mu), txw)); + + let (txr, rxr) = channel(); + for _ in 0..TASKS { + ex.spawn_ok(reader(Arc::clone(&mu), txr.clone())); + } + + // Wait for the readers to finish their checks. 
+ for _ in 0..TASKS { + rxr.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from reader"); + } + + // Wait for the writer to finish. + rxw.recv_timeout(Duration::from_secs(5)) + .expect("Failed to receive completion message from writer"); + + assert_eq!(*block_on(mu.read_lock()), ITERATIONS); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn wake_all_readers() { + async fn read(mu: Arc<Mutex<()>>) { + let g = mu.read_lock().await; + pending!(); + mem::drop(g); + } + + async fn write(mu: Arc<Mutex<()>>) { + mu.lock().await; + } + + let mu = Arc::new(Mutex::new(())); + let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [ + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(read(mu.clone())), + Box::pin(write(mu.clone())), + Box::pin(read(mu.clone())), + ]; + const NUM_READERS: usize = 4; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock so that the futures cannot get it. + let g = block_on(mu.lock()); + + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should allow all readers to make progress. Since they already waited + // once they should ignore the WRITER_WAITING bit that is currently set. + mem::drop(g); + for r in &mut futures { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Check that all readers were able to acquire the lock. 
+ assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + let mut needs_poll = None; + + // All the readers can now finish but the writer needs to be polled again. + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + .is_pending() + { + panic!("Writer unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn long_wait() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Which will finally allow the mark_ready function to make progress. 
+ if mark.as_mut().poll(&mut cx).is_pending() { + panic!("mark_ready not able to make progress"); + } + + // Now the tight loop will finish. + if tl.as_mut().poll(&mut cx).is_pending() { + panic!("tight_loop not able to finish"); + } + + assert!(*block_on(mu.lock())); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_before_wake() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. 
+ mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), LOCKED); + + mem::drop(tl); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_long_wait_after_wake() { + async fn tight_loop(mu: Arc<Mutex<bool>>) { + loop { + let ready = mu.lock().await; + if *ready { + break; + } + pending!(); + } + } + + async fn mark_ready(mu: Arc<Mutex<bool>>) { + *mu.lock().await = true; + } + + let mu = Arc::new(Mutex::new(false)); + let mut tl = Box::pin(tight_loop(mu.clone())); + let mut mark = Box::pin(mark_ready(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + for _ in 0..=LONG_WAIT_THRESHOLD { + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) { + panic!("mark_ready unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT + ); + + // This time the tight loop will fail to acquire the lock. + if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) { + panic!("tight_loop unexpectedly ready"); + } + + // Now drop the mark_ready future, which should clear the LONG_WAIT bit. + mem::drop(mark); + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & LONG_WAIT, 0); + + // Since the lock is not held, we should be able to spawn a future to set the ready flag. + block_on(mark_ready(mu.clone())); + + // Now the tight loop will finish. 
+ if tl.as_mut().poll(&mut cx).is_pending() { + panic!("tight_loop not able to finish"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn designated_waker() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut futures = [ + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + Box::pin(inc(mu.clone())), + ]; + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + // Poll 2 futures. Since neither will be able to acquire the lock, they should get added to + // the waiter list. + if let Poll::Ready(()) = futures[0].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + if let Poll::Ready(()) = futures[1].as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + LOCKED | HAS_WAITERS | WRITER_WAITING, + ); + + // Now drop the lock. This should set the DESIGNATED_WAKER bit and wake up the first future + // in the wait list. + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed), + DESIGNATED_WAKER | HAS_WAITERS | WRITER_WAITING, + ); + + // Now poll the third future. It should be able to acquire the lock immediately. + if futures[2].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 1); + + // There should still be a waiter in the wait list and the DESIGNATED_WAKER bit should still + // be set. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Now let the future that was woken up run. 
+ if futures[0].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 2); + + if futures[1].as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + assert_eq!(*block_on(mu.lock()), 3); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_designated_waker() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let count = block_on(mu.lock()); + + if let Poll::Ready(()) = fut.as_mut().poll(&mut cx) { + panic!("Future unexpectedly ready when lock is held"); + } + + // Drop the lock. This will wake up the future. + mem::drop(count); + + // Now drop the future without polling. This should clear all the state in the mutex. + mem::drop(fut); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_before_wake() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop fut1. 
This should remove it from the waiter list but shouldn't wake fut2. + mem::drop(fut1); + + // There should be no designated waker. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, 0); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + // Now drop the lock. This should mark fut2 as ready to make progress. + mem::drop(count); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn cancel_after_wake() { + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + let mu = Arc::new(Mutex::new(0)); + + let mut fut1 = Box::pin(inc(mu.clone())); + + let mut fut2 = Box::pin(inc(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // First acquire the lock. + let count = block_on(mu.lock()); + + // Now poll the futures. Since the lock is acquired they will both get queued in the waiter + // list. + match fut1.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => {} + Poll::Ready(()) => panic!("Future is unexpectedly ready"), + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // Drop the lock. This should mark fut1 as ready to make progress. + mem::drop(count); + + // Now drop fut1. This should make fut2 ready to make progress. 
+ mem::drop(fut1); + + // Since there was still another waiter in the list we shouldn't have cleared the + // DESIGNATED_WAKER bit. + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, + DESIGNATED_WAKER + ); + + // Since the waiter was a writer, we should clear the WRITER_WAITING bit. + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0); + + match fut2.as_mut().poll(&mut cx) { + Poll::Pending => panic!("Future is not ready to make progress"), + Poll::Ready(()) => {} + } + + // Verify that we only incremented the count once. + assert_eq!(*block_on(mu.lock()), 1); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn timeout() { + async fn timed_lock(timer: oneshot::Receiver<()>, mu: Arc<Mutex<()>>) { + select! { + res = timer.fuse() => { + match res { + Ok(()) => {}, + Err(e) => panic!("Timer unexpectedly canceled: {}", e), + } + } + _ = mu.lock().fuse() => panic!("Successfuly acquired lock"), + } + } + + let mu = Arc::new(Mutex::new(())); + let (tx, rx) = oneshot::channel(); + + let mut timeout = Box::pin(timed_lock(rx, mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + // Acquire the lock. + let g = block_on(mu.lock()); + + // Poll the future. + if let Poll::Ready(()) = timeout.as_mut().poll(&mut cx) { + panic!("timed_lock unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + + // Signal the channel, which should cancel the lock. + tx.send(()).expect("Failed to send wakeup"); + + // Now the future should have completed without acquiring the lock. + if timeout.as_mut().poll(&mut cx).is_pending() { + panic!("timed_lock not ready after timeout"); + } + + // The mutex state should not show any waiters. 
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + mem::drop(g); + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn writer_waiting() { + async fn read_zero(mu: Arc<Mutex<usize>>) { + let val = mu.read_lock().await; + pending!(); + + assert_eq!(*val, 0); + } + + async fn inc(mu: Arc<Mutex<usize>>) { + *mu.lock().await += 1; + } + + async fn read_one(mu: Arc<Mutex<usize>>) { + let val = mu.read_lock().await; + + assert_eq!(*val, 1); + } + + let mu = Arc::new(Mutex::new(0)); + + let mut r1 = Box::pin(read_zero(mu.clone())); + let mut r2 = Box::pin(read_zero(mu.clone())); + + let mut w = Box::pin(inc(mu.clone())); + let mut r3 = Box::pin(read_one(mu.clone())); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + if let Poll::Ready(()) = r1.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + if let Poll::Ready(()) = r2.as_mut().poll(&mut cx) { + panic!("read_zero unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if let Poll::Ready(()) = w.as_mut().poll(&mut cx) { + panic!("inc unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + // The WRITER_WAITING bit should prevent the next reader from acquiring the lock. 
+ if let Poll::Ready(()) = r3.as_mut().poll(&mut cx) { + panic!("read_one unexpectedly ready"); + } + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + 2 * READ_LOCK + ); + + if r1.as_mut().poll(&mut cx).is_pending() { + panic!("read_zero unable to complete"); + } + if r2.as_mut().poll(&mut cx).is_pending() { + panic!("read_zero unable to complete"); + } + if w.as_mut().poll(&mut cx).is_pending() { + panic!("inc unable to complete"); + } + if r3.as_mut().poll(&mut cx).is_pending() { + panic!("read_one unable to complete"); + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_one() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + } + + async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut readers = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + let mut writer = Box::pin(write(mu.clone(), cv.clone())); + + for r in &mut readers { + if let Poll::Ready(()) = r.as_mut().poll(&mut cx) { + panic!("reader unexpectedly ready"); + } + } + if let Poll::Ready(()) = writer.as_mut().poll(&mut cx) { + panic!("writer unexpectedly ready"); + } + + let mut count = block_on(mu.lock()); + *count = 1; + + // This should wake all readers + one writer. + cv.notify_one(); + + // Poll the readers and the writer so they add themselves to the mutex's waiter list. 
+ for r in &mut readers { + if r.as_mut().poll(&mut cx).is_ready() { + panic!("reader unexpectedly ready"); + } + } + + if writer.as_mut().poll(&mut cx).is_ready() { + panic!("writer unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, + HAS_WAITERS + ); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, + WRITER_WAITING + ); + + mem::drop(count); + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + for r in &mut readers { + if r.as_mut().poll(&mut cx).is_pending() { + panic!("reader unable to complete"); + } + } + + if writer.as_mut().poll(&mut cx).is_pending() { + panic!("writer unable to complete"); + } + + assert_eq!(*block_on(mu.read_lock()), 0); + } + + #[test] + fn notify_when_unlocked() { + async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + Box::pin(dec(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + *block_on(mu.lock()) = futures.len(); + cv.notify_all(); + + // Since we haven't polled `futures` yet, the mutex should not have any waiters. 
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + + for f in &mut futures { + if f.as_mut().poll(&mut cx).is_pending() { + panic!("future unexpectedly ready"); + } + } + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_reader_writer() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock, which should prevent the writer from waking + // up. + pending!(); + } + + async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.lock().await; + while *count == 0 { + count = cv.wait(count).await; + } + + *count -= 1; + } + + async fn lock(mu: Arc<Mutex<usize>>) { + mem::drop(mu.lock().await); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(write(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + const NUM_READERS: usize = 4; + + let mut l = Box::pin(lock(mu.clone())); + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + + let mut count = block_on(mu.lock()); + *count = 1; + + // Now poll the lock function. Since the lock is held by us, it will get queued on the + // waiter list. + if let Poll::Ready(()) = l.as_mut().poll(&mut cx) { + panic!("lock() unexpectedly ready"); + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING), + HAS_WAITERS | WRITER_WAITING + ); + + // Wake up waiters while holding the lock. 
+ cv.notify_all(); + + // Drop the lock. This should wake up the lock function. + mem::drop(count); + + if l.as_mut().poll(&mut cx).is_pending() { + panic!("lock() unable to complete"); + } + + // Since we haven't polled `futures` yet, the mutex state should now be empty. + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + + // Poll everything again. The readers should be able to make progress (but not complete) but + // the writer should be blocked. + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * NUM_READERS + ); + + // All the readers can now finish but the writer needs to be polled again. + let mut needs_poll = None; + for (i, r) in futures.iter_mut().enumerate() { + match r.as_mut().poll(&mut cx) { + Poll::Ready(()) => {} + Poll::Pending => { + if needs_poll.is_some() { + panic!("More than one future unable to complete"); + } + needs_poll = Some(i); + } + } + } + + if futures[needs_poll.expect("Writer unexpectedly able to complete")] + .as_mut() + .poll(&mut cx) + .is_pending() + { + panic!("Writer unable to complete"); + } + + assert_eq!(*block_on(mu.lock()), 0); + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } + + #[test] + fn notify_readers_with_read_lock() { + async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) { + let mut count = mu.read_lock().await; + while *count == 0 { + count = cv.wait_read(count).await; + } + + // Yield once while holding the read lock. 
+ pending!(); + } + + let mu = Arc::new(Mutex::new(0)); + let cv = Arc::new(Condvar::new()); + + let arc_waker = Arc::new(TestWaker); + let waker = waker_ref(&arc_waker); + let mut cx = Context::from_waker(&waker); + + let mut futures = [ + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + Box::pin(read(mu.clone(), cv.clone())), + ]; + + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + + // Increment the count and then grab a read lock. + *block_on(mu.lock()) = 1; + + let g = block_on(mu.read_lock()); + + // Notify the condvar while holding the read lock. This should wake up all the waiters. + cv.notify_all(); + + // Since the lock is held in shared mode, all the readers should immediately be able to + // acquire the read lock. + for f in &mut futures { + if let Poll::Ready(()) = f.as_mut().poll(&mut cx) { + panic!("future unexpectedly ready"); + } + } + assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0); + assert_eq!( + mu.raw.state.load(Ordering::Relaxed) & READ_MASK, + READ_LOCK * (futures.len() + 1) + ); + + mem::drop(g); + + for f in &mut futures { + if f.as_mut().poll(&mut cx).is_pending() { + panic!("future unable to complete"); + } + } + + assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0); + } +} diff --git a/cros_async/src/sync/spin.rs b/cros_async/src/sync/spin.rs new file mode 100644 index 000000000..8b7b81193 --- /dev/null +++ b/cros_async/src/sync/spin.rs @@ -0,0 +1,277 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::cell::UnsafeCell; +use std::ops::{Deref, DerefMut}; +use std::sync::atomic::{spin_loop_hint, AtomicBool, Ordering}; + +const UNLOCKED: bool = false; +const LOCKED: bool = true; + +/// A primitive that provides safe, mutable access to a shared resource. +/// +/// Unlike `Mutex`, a `SpinLock` will not voluntarily yield its CPU time until the resource is +/// available and will instead keep spinning until the resource is acquired. For the vast majority +/// of cases, `Mutex` is a better choice than `SpinLock`. If a `SpinLock` must be used then users +/// should try to do as little work as possible while holding the `SpinLock` and avoid any sort of +/// blocking at all costs as it can severely penalize performance. +/// +/// # Poisoning +/// +/// This `SpinLock` does not implement lock poisoning so it is possible for threads to access +/// poisoned data if a thread panics while holding the lock. If lock poisoning is needed, it can be +/// implemented by wrapping the `SpinLock` in a new type that implements poisoning. See the +/// implementation of `std::sync::Mutex` for an example of how to do this. +#[repr(align(128))] +pub struct SpinLock<T: ?Sized> { + lock: AtomicBool, + value: UnsafeCell<T>, +} + +impl<T> SpinLock<T> { + /// Creates a new, unlocked `SpinLock` that's ready for use. + pub fn new(value: T) -> SpinLock<T> { + SpinLock { + lock: AtomicBool::new(UNLOCKED), + value: UnsafeCell::new(value), + } + } + + /// Consumes the `SpinLock` and returns the value guarded by it. This method doesn't perform any + /// locking as the compiler guarantees that there are no references to `self`. + pub fn into_inner(self) -> T { + // No need to take the lock because the compiler can statically guarantee + // that there are no references to the SpinLock. + self.value.into_inner() + } +} + +impl<T: ?Sized> SpinLock<T> { + /// Acquires exclusive, mutable access to the resource protected by the `SpinLock`, blocking the + /// current thread until it is able to do so. 
Upon returning, the current thread will be the + /// only thread with access to the resource. The `SpinLock` will be released when the returned + /// `SpinLockGuard` is dropped. Attempting to call `lock` while already holding the `SpinLock` + /// will cause a deadlock. + pub fn lock(&self) -> SpinLockGuard<T> { + loop { + let state = self.lock.load(Ordering::Relaxed); + if state == UNLOCKED + && self + .lock + .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed) + .is_ok() + { + break; + } + spin_loop_hint(); + } + + SpinLockGuard { + lock: self, + value: unsafe { &mut *self.value.get() }, + } + } + + fn unlock(&self) { + // Don't need to compare and swap because we exclusively hold the lock. + self.lock.store(UNLOCKED, Ordering::Release); + } + + /// Returns a mutable reference to the contained value. This method doesn't perform any locking + /// as the compiler will statically guarantee that there are no other references to `self`. + pub fn get_mut(&mut self) -> &mut T { + // Safe because the compiler can statically guarantee that there are no other references to + // `self`. This is also why we don't need to acquire the lock. + unsafe { &mut *self.value.get() } + } +} + +unsafe impl<T: ?Sized + Send> Send for SpinLock<T> {} +unsafe impl<T: ?Sized + Send> Sync for SpinLock<T> {} + +impl<T: ?Sized + Default> Default for SpinLock<T> { + fn default() -> Self { + Self::new(Default::default()) + } +} + +impl<T> From<T> for SpinLock<T> { + fn from(source: T) -> Self { + Self::new(source) + } +} + +/// An RAII implementation of a "scoped lock" for a `SpinLock`. When this structure is dropped, the +/// lock will be released. The resource protected by the `SpinLock` can be accessed via the `Deref` +/// and `DerefMut` implementations of this structure. 
+pub struct SpinLockGuard<'a, T: 'a + ?Sized> { + lock: &'a SpinLock<T>, + value: &'a mut T, +} + +impl<'a, T: ?Sized> Deref for SpinLockGuard<'a, T> { + type Target = T; + fn deref(&self) -> &T { + self.value + } +} + +impl<'a, T: ?Sized> DerefMut for SpinLockGuard<'a, T> { + fn deref_mut(&mut self) -> &mut T { + self.value + } +} + +impl<'a, T: ?Sized> Drop for SpinLockGuard<'a, T> { + fn drop(&mut self) { + self.lock.unlock(); + } +} + +#[cfg(test)] +mod test { + use super::*; + + use std::mem; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + use std::thread; + + #[derive(PartialEq, Eq, Debug)] + struct NonCopy(u32); + + #[test] + fn it_works() { + let sl = SpinLock::new(NonCopy(13)); + + assert_eq!(*sl.lock(), NonCopy(13)); + } + + #[test] + fn smoke() { + let sl = SpinLock::new(NonCopy(7)); + + mem::drop(sl.lock()); + mem::drop(sl.lock()); + } + + #[test] + fn send() { + let sl = SpinLock::new(NonCopy(19)); + + thread::spawn(move || { + let value = sl.lock(); + assert_eq!(*value, NonCopy(19)); + }) + .join() + .unwrap(); + } + + #[test] + fn high_contention() { + const THREADS: usize = 23; + const ITERATIONS: usize = 101; + + let mut threads = Vec::with_capacity(THREADS); + + let sl = Arc::new(SpinLock::new(0usize)); + for _ in 0..THREADS { + let sl2 = sl.clone(); + threads.push(thread::spawn(move || { + for _ in 0..ITERATIONS { + *sl2.lock() += 1; + } + })); + } + + for t in threads.into_iter() { + t.join().unwrap(); + } + + assert_eq!(*sl.lock(), THREADS * ITERATIONS); + } + + #[test] + fn get_mut() { + let mut sl = SpinLock::new(NonCopy(13)); + *sl.get_mut() = NonCopy(17); + + assert_eq!(sl.into_inner(), NonCopy(17)); + } + + #[test] + fn into_inner() { + let sl = SpinLock::new(NonCopy(29)); + assert_eq!(sl.into_inner(), NonCopy(29)); + } + + #[test] + fn into_inner_drop() { + struct NeedsDrop(Arc<AtomicUsize>); + impl Drop for NeedsDrop { + fn drop(&mut self) { + self.0.fetch_add(1, Ordering::AcqRel); + } + } + + let value = 
Arc::new(AtomicUsize::new(0)); + let needs_drop = SpinLock::new(NeedsDrop(value.clone())); + assert_eq!(value.load(Ordering::Acquire), 0); + + { + let inner = needs_drop.into_inner(); + assert_eq!(inner.0.load(Ordering::Acquire), 0); + } + + assert_eq!(value.load(Ordering::Acquire), 1); + } + + #[test] + fn arc_nested() { + // Tests nested sltexes and access to underlying data. + let sl = SpinLock::new(1); + let arc = Arc::new(SpinLock::new(sl)); + thread::spawn(move || { + let nested = arc.lock(); + let lock2 = nested.lock(); + assert_eq!(*lock2, 1); + }) + .join() + .unwrap(); + } + + #[test] + fn arc_access_in_unwind() { + let arc = Arc::new(SpinLock::new(1)); + let arc2 = arc.clone(); + thread::spawn(move || { + struct Unwinder { + i: Arc<SpinLock<i32>>, + } + impl Drop for Unwinder { + fn drop(&mut self) { + *self.i.lock() += 1; + } + } + let _u = Unwinder { i: arc2 }; + panic!(); + }) + .join() + .expect_err("thread did not panic"); + let lock = arc.lock(); + assert_eq!(*lock, 2); + } + + #[test] + fn unsized_value() { + let sltex: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]); + { + let b = &mut *sltex.lock(); + b[0] = 4; + b[2] = 5; + } + let expected: &[i32] = &[4, 2, 5]; + assert_eq!(&*sltex.lock(), expected); + } +} diff --git a/cros_async/src/sync/waiter.rs b/cros_async/src/sync/waiter.rs new file mode 100644 index 000000000..072a0f506 --- /dev/null +++ b/cros_async/src/sync/waiter.rs @@ -0,0 +1,281 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::cell::UnsafeCell; +use std::future::Future; +use std::mem; +use std::pin::Pin; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicBool, AtomicU8, Ordering}; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; + +use intrusive_collections::linked_list::{LinkedList, LinkedListOps}; +use intrusive_collections::{intrusive_adapter, DefaultLinkOps, LinkOps}; + +use crate::sync::SpinLock; + +// An atomic version of a LinkedListLink. See https://github.com/Amanieu/intrusive-rs/issues/47 for +// more details. +#[repr(align(128))] +pub struct AtomicLink { + prev: UnsafeCell<Option<NonNull<AtomicLink>>>, + next: UnsafeCell<Option<NonNull<AtomicLink>>>, + linked: AtomicBool, +} + +impl AtomicLink { + fn new() -> AtomicLink { + AtomicLink { + linked: AtomicBool::new(false), + prev: UnsafeCell::new(None), + next: UnsafeCell::new(None), + } + } + + fn is_linked(&self) -> bool { + self.linked.load(Ordering::Relaxed) + } +} + +impl DefaultLinkOps for AtomicLink { + type Ops = AtomicLinkOps; + + const NEW: Self::Ops = AtomicLinkOps; +} + +// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods +// are all unsafe and require that the caller has first called `acquire_link` (and had it return +// true) to use them safely. 
+unsafe impl Send for AtomicLink {} +unsafe impl Sync for AtomicLink {} + +#[derive(Copy, Clone, Default)] +pub struct AtomicLinkOps; + +unsafe impl LinkOps for AtomicLinkOps { + type LinkPtr = NonNull<AtomicLink>; + + unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool { + !ptr.as_ref().linked.swap(true, Ordering::Acquire) + } + + unsafe fn release_link(&mut self, ptr: Self::LinkPtr) { + ptr.as_ref().linked.store(false, Ordering::Release) + } +} + +unsafe impl LinkedListOps for AtomicLinkOps { + unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> { + *ptr.as_ref().next.get() + } + + unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> { + *ptr.as_ref().prev.get() + } + + unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) { + *ptr.as_ref().next.get() = next; + } + + unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) { + *ptr.as_ref().prev.get() = prev; + } +} + +#[derive(Clone, Copy)] +pub enum Kind { + Shared, + Exclusive, +} + +enum State { + Init, + Waiting(Waker), + Woken, + Finished, + Processing, +} + +// Indicates the queue to which the waiter belongs. It is the responsibility of the Mutex and +// Condvar implementations to update this value when adding/removing a Waiter from their respective +// waiter lists. +#[repr(u8)] +#[derive(Debug, Eq, PartialEq)] +pub enum WaitingFor { + // The waiter is either not linked into a waiter list or it is linked into a temporary list. + None = 0, + // The waiter is linked into the Mutex's waiter list. + Mutex = 1, + // The waiter is linked into the Condvar's waiter list. + Condvar = 2, +} + +// Represents a thread currently blocked on a Condvar or on acquiring a Mutex. +pub struct Waiter { + link: AtomicLink, + state: SpinLock<State>, + cancel: fn(usize, &Waiter, bool), + cancel_data: usize, + kind: Kind, + waiting_for: AtomicU8, +} + +impl Waiter { + // Create a new, initialized Waiter. 
+ // + // `kind` should indicate whether this waiter represent a thread that is waiting for a shared + // lock or an exclusive lock. + // + // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()` + // function) is dropped before it can complete. `cancel_data` is used as the first parameter of + // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third + // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it + // was polled to completion). A value of `false` for the third parameter may already be stale + // by the time the cancel function runs and so does not guarantee that the waiter was not woken. + // In this case, implementations should still check if the Waiter was woken. However, a value of + // `true` guarantees that the waiter was already woken up so no additional checks are necessary. + // In this case, the cancel implementation should wake up the next waiter in its wait list, if + // any. + // + // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the + // documentation of the `WaitingFor` enum for the meaning of the different values. + pub fn new( + kind: Kind, + cancel: fn(usize, &Waiter, bool), + cancel_data: usize, + waiting_for: WaitingFor, + ) -> Waiter { + Waiter { + link: AtomicLink::new(), + state: SpinLock::new(State::Init), + cancel, + cancel_data, + kind, + waiting_for: AtomicU8::new(waiting_for as u8), + } + } + + // The kind of lock that this `Waiter` is waiting to acquire. + pub fn kind(&self) -> Kind { + self.kind + } + + // Returns true if this `Waiter` is currently linked into a waiter list. + pub fn is_linked(&self) -> bool { + self.link.is_linked() + } + + // Indicates the waiter list to which this `Waiter` belongs. 
+ pub fn is_waiting_for(&self) -> WaitingFor { + match self.waiting_for.load(Ordering::Acquire) { + 0 => WaitingFor::None, + 1 => WaitingFor::Mutex, + 2 => WaitingFor::Condvar, + v => panic!("Unknown value for `WaitingFor`: {}", v), + } + } + + // Change the waiter list to which this `Waiter` belongs. This will panic if called when the + // `Waiter` is still linked into a waiter list. + pub fn set_waiting_for(&self, waiting_for: WaitingFor) { + self.waiting_for.store(waiting_for as u8, Ordering::Release); + } + + // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a + // waiter list. + pub fn reset(&self, waiting_for: WaitingFor) { + debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked"); + self.set_waiting_for(waiting_for); + + let mut state = self.state.lock(); + if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) { + mem::drop(state); + mem::drop(waker); + } + } + + // Wait until woken up by another thread. + pub fn wait(&self) -> WaitFuture<'_> { + WaitFuture { waiter: self } + } + + // Wake up the thread associated with this `Waiter`. Panics if `waiting_for()` does not return + // `WaitingFor::None` or if `is_linked()` returns true. 
+ pub fn wake(&self) { + debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked"); + debug_assert_eq!(self.is_waiting_for(), WaitingFor::None); + + let mut state = self.state.lock(); + + if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) { + mem::drop(state); + waker.wake(); + } + } +} + +pub struct WaitFuture<'w> { + waiter: &'w Waiter, +} + +impl<'w> Future for WaitFuture<'w> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> { + let mut state = self.waiter.state.lock(); + + match mem::replace(&mut *state, State::Processing) { + State::Init => { + *state = State::Waiting(cx.waker().clone()); + + Poll::Pending + } + State::Waiting(old_waker) => { + *state = State::Waiting(cx.waker().clone()); + mem::drop(state); + mem::drop(old_waker); + + Poll::Pending + } + State::Woken => { + *state = State::Finished; + Poll::Ready(()) + } + State::Finished => { + panic!("Future polled after returning Poll::Ready"); + } + State::Processing => { + panic!("Unexpected waker state"); + } + } + } +} + +impl<'w> Drop for WaitFuture<'w> { + fn drop(&mut self) { + let state = self.waiter.state.lock(); + + match *state { + State::Finished => {} + State::Processing => panic!("Unexpected waker state"), + State::Woken => { + mem::drop(state); + + // We were woken but not polled. Wake up the next waiter. + (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, true); + } + _ => { + mem::drop(state); + + // Not woken. No need to wake up any waiters. 
+ (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, false); + } + } + } +} + +intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink }); + +pub type WaiterList = LinkedList<WaiterAdapter>; diff --git a/cros_async/src/uring_executor.rs b/cros_async/src/uring_executor.rs index c6af38c27..39bce9581 100644 --- a/cros_async/src/uring_executor.rs +++ b/cros_async/src/uring_executor.rs @@ -307,11 +307,12 @@ impl RawExecutor { let raw = Arc::downgrade(self); let schedule = move |runnable| { if let Some(r) = raw.upgrade() { - r.queue.schedule(runnable); + r.queue.push_back(runnable); + r.wake(); } }; let (runnable, task) = async_task::spawn(f, schedule); - self.queue.schedule(runnable); + runnable.schedule(); task } @@ -323,11 +324,12 @@ impl RawExecutor { let raw = Arc::downgrade(self); let schedule = move |runnable| { if let Some(r) = raw.upgrade() { - r.queue.schedule(runnable); + r.queue.push_back(runnable); + r.wake(); } }; let (runnable, task) = async_task::spawn_local(f, schedule); - self.queue.schedule(runnable); + runnable.schedule(); task } @@ -351,7 +353,6 @@ impl RawExecutor { pin_mut!(done); loop { self.state.store(PROCESSING, Ordering::Release); - self.queue.set_waker(cx.waker().clone()); for runnable in self.queue.iter() { runnable.run(); } @@ -360,10 +361,13 @@ impl RawExecutor { return Ok(val); } - let oldstate = self - .state - .compare_and_swap(PROCESSING, WAITING, Ordering::Acquire); - if oldstate != PROCESSING { + let oldstate = self.state.compare_exchange( + PROCESSING, + WAITING, + Ordering::Acquire, + Ordering::Acquire, + ); + if let Err(oldstate) = oldstate { debug_assert_eq!(oldstate, WOKEN); // One or more futures have become runnable. 
continue; diff --git a/crosvm_plugin/src/lib.rs b/crosvm_plugin/src/lib.rs index 2f8f975f5..86827a681 100644 --- a/crosvm_plugin/src/lib.rs +++ b/crosvm_plugin/src/lib.rs @@ -152,8 +152,12 @@ impl IdAllocator { } fn free(&self, id: u32) { - self.0 - .compare_and_swap(id as usize + 1, id as usize, Ordering::Relaxed); + let _ = self.0.compare_exchange( + id as usize + 1, + id as usize, + Ordering::Relaxed, + Ordering::Relaxed, + ); } } diff --git a/data_model/src/endian.rs b/data_model/src/endian.rs index 686b8a182..6d3645cce 100644 --- a/data_model/src/endian.rs +++ b/data_model/src/endian.rs @@ -39,7 +39,7 @@ use crate::DataInit; macro_rules! endian_type { ($old_type:ident, $new_type:ident, $to_new:ident, $from_new:ident) => { - /// An unsigned integer type of with an explicit endianness. + /// An integer type of with an explicit endianness. /// /// See module level documentation for examples. #[derive(Copy, Clone, Eq, PartialEq, Debug, Default)] @@ -71,9 +71,9 @@ macro_rules! endian_type { } } - impl Into<$old_type> for $new_type { - fn into(self) -> $old_type { - $old_type::$from_new(self.0) + impl From<$new_type> for $old_type { + fn from(v: $new_type) -> $old_type { + $old_type::$from_new(v.0) } } @@ -104,13 +104,21 @@ macro_rules! 
endian_type { } endian_type!(u16, Le16, to_le, from_le); +endian_type!(i16, SLe16, to_le, from_le); endian_type!(u32, Le32, to_le, from_le); +endian_type!(i32, SLe32, to_le, from_le); endian_type!(u64, Le64, to_le, from_le); +endian_type!(i64, SLe64, to_le, from_le); endian_type!(usize, LeSize, to_le, from_le); +endian_type!(isize, SLeSize, to_le, from_le); endian_type!(u16, Be16, to_be, from_be); +endian_type!(i16, SBe16, to_be, from_be); endian_type!(u32, Be32, to_be, from_be); +endian_type!(i32, SBe32, to_be, from_be); endian_type!(u64, Be64, to_be, from_be); +endian_type!(i64, SBe64, to_be, from_be); endian_type!(usize, BeSize, to_be, from_be); +endian_type!(isize, SBeSize, to_be, from_be); #[cfg(test)] mod tests { @@ -153,11 +161,19 @@ mod tests { } endian_test!(u16, Le16, test_le16, NATIVE_LITTLE); + endian_test!(i16, SLe16, test_sle16, NATIVE_LITTLE); endian_test!(u32, Le32, test_le32, NATIVE_LITTLE); + endian_test!(i32, SLe32, test_sle32, NATIVE_LITTLE); endian_test!(u64, Le64, test_le64, NATIVE_LITTLE); + endian_test!(i64, SLe64, test_sle64, NATIVE_LITTLE); endian_test!(usize, LeSize, test_le_size, NATIVE_LITTLE); + endian_test!(isize, SLeSize, test_sle_size, NATIVE_LITTLE); endian_test!(u16, Be16, test_be16, NATIVE_BIG); + endian_test!(i16, SBe16, test_sbe16, NATIVE_BIG); endian_test!(u32, Be32, test_be32, NATIVE_BIG); + endian_test!(i32, SBe32, test_sbe32, NATIVE_BIG); endian_test!(u64, Be64, test_be64, NATIVE_BIG); + endian_test!(i64, SBe64, test_sbe64, NATIVE_BIG); endian_test!(usize, BeSize, test_be_size, NATIVE_BIG); + endian_test!(isize, SBeSize, test_sbe_size, NATIVE_BIG); } diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 3438f0bdb..f77035c57 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -6,6 +6,7 @@ edition = "2018" [features] audio = [] +direct = [] gpu = ["gpu_display","rutabaga_gfx"] tpm = ["protos/trunks", "tpm2"] video-decoder = [] @@ -18,6 +19,7 @@ gfxstream = ["gpu", "rutabaga_gfx/gfxstream"] [dependencies] 
acpi_tables = {path = "../acpi_tables" } audio_streams = { path = "../../adhd/audio_streams" } # ignored by ebuild +base = { path = "../base" } bit_field = { path = "../bit_field" } cros_async = { path = "../cros_async" } data_model = { path = "../data_model" } @@ -33,8 +35,6 @@ libchromeos = { path = "../../libchromeos-rs" } # ignored by ebuild libcras = { path = "../../adhd/cras/client/libcras" } # ignored by ebuild linux_input_sys = { path = "../linux_input_sys" } minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild -msg_on_socket_derive = { path = "../msg_socket/msg_on_socket_derive" } -msg_socket = { path = "../msg_socket" } net_sys = { path = "../net_sys" } net_util = { path = "../net_util" } p9 = { path = "../../vm_tools/p9" } @@ -43,21 +43,23 @@ protos = { path = "../protos", optional = true } rand_ish = { path = "../rand_ish" } remain = "*" resources = { path = "../resources" } +serde = { version = "1", features = [ "derive" ] } +smallvec = "1.6.1" sync = { path = "../sync" } sys_util = { path = "../sys_util" } -base = { path = "../base" } -syscall_defines = { path = "../syscall_defines" } thiserror = "1.0.20" tpm2 = { path = "../tpm2", optional = true } usb_util = { path = "../usb_util" } vfio_sys = { path = "../vfio_sys" } vhost = { path = "../vhost" } +vmm_vhost = { version = "*", features = ["vhost-user-master"] } virtio_sys = { path = "../virtio_sys" } vm_control = { path = "../vm_control" } vm_memory = { path = "../vm_memory" } [dependencies.futures] version = "*" +features = ["std"] default-features = false [dev-dependencies] diff --git a/devices/src/bat.rs b/devices/src/bat.rs index 68c125a94..95f3cd70a 100644 --- a/devices/src/bat.rs +++ b/devices/src/bat.rs @@ -5,15 +5,14 @@ use crate::{BusAccessInfo, BusDevice}; use acpi_tables::{aml, aml::Aml}; use base::{ - error, warn, AsRawDescriptor, Descriptor, Event, PollToken, RawDescriptor, WaitContext, + error, warn, AsRawDescriptor, Descriptor, Event, PollToken, RawDescriptor, 
Tube, WaitContext, }; -use msg_socket::{MsgReceiver, MsgSender}; use power_monitor::{BatteryStatus, CreatePowerMonitorFn}; use std::fmt::{self, Display}; use std::sync::Arc; use std::thread; use sync::Mutex; -use vm_control::{BatControlCommand, BatControlResponseSocket, BatControlResult}; +use vm_control::{BatControlCommand, BatControlResult}; /// Errors for battery devices. #[derive(Debug)] @@ -106,7 +105,7 @@ pub struct GoldfishBattery { activated: bool, monitor_thread: Option<thread::JoinHandle<()>>, kill_evt: Option<Event>, - socket: Option<BatControlResponseSocket>, + tube: Option<Tube>, create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>, } @@ -143,7 +142,7 @@ const BATTERY_STATUS_VAL_NOT_CHARGING: u32 = 3; const BATTERY_HEALTH_VAL_UNKNOWN: u32 = 0; fn command_monitor( - socket: BatControlResponseSocket, + tube: Tube, irq_evt: Event, irq_resample_evt: Event, kill_evt: Event, @@ -151,7 +150,7 @@ fn command_monitor( create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>, ) { let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[ - (&Descriptor(socket.as_raw_descriptor()), Token::Commands), + (&Descriptor(tube.as_raw_descriptor()), Token::Commands), ( &Descriptor(irq_resample_evt.as_raw_descriptor()), Token::Resample, @@ -202,7 +201,7 @@ fn command_monitor( for event in events.iter().filter(|e| e.is_readable) { match event.token { Token::Commands => { - let req = match socket.recv() { + let req = match tube.recv() { Ok(req) => req, Err(e) => { error!("failed to receive request: {}", e); @@ -232,7 +231,7 @@ fn command_monitor( let _ = irq_evt.write(1); } - if let Err(e) = socket.send(&BatControlResult::Ok) { + if let Err(e) = tube.send(&BatControlResult::Ok) { error!("failed to send response: {}", e); } } @@ -309,7 +308,7 @@ impl GoldfishBattery { irq_num: u32, irq_evt: Event, irq_resample_evt: Event, - socket: BatControlResponseSocket, + tube: Tube, create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>, ) -> Result<Self> { if 
mmio_base + GOLDFISHBAT_MMIO_LEN - 1 > u32::MAX as u64 { @@ -338,7 +337,7 @@ impl GoldfishBattery { activated: false, monitor_thread: None, kill_evt: None, - socket: Some(socket), + tube: Some(tube), create_power_monitor, }) } @@ -350,8 +349,8 @@ impl GoldfishBattery { self.irq_resample_evt.as_raw_descriptor(), ]; - if let Some(socket) = &self.socket { - rds.push(socket.as_raw_descriptor()); + if let Some(tube) = &self.tube { + rds.push(tube.as_raw_descriptor()); } rds @@ -375,7 +374,7 @@ impl GoldfishBattery { } }; - if let Some(socket) = self.socket.take() { + if let Some(tube) = self.tube.take() { let irq_evt = self.irq_evt.try_clone().unwrap(); let irq_resample_evt = self.irq_resample_evt.try_clone().unwrap(); let bat_state = self.state.clone(); @@ -385,7 +384,7 @@ impl GoldfishBattery { .name(self.debug_label()) .spawn(move || { command_monitor( - socket, + tube, irq_evt, irq_resample_evt, kill_evt, diff --git a/devices/src/bus.rs b/devices/src/bus.rs index 73093665b..90360cd9d 100644 --- a/devices/src/bus.rs +++ b/devices/src/bus.rs @@ -10,12 +10,11 @@ use std::fmt::{self, Display}; use std::result; use std::sync::Arc; -use base::RawDescriptor; -use msg_socket::MsgOnSocket; +use serde::{Deserialize, Serialize}; use sync::Mutex; /// Information about how a device was accessed. -#[derive(Copy, Clone, Eq, PartialEq, Debug, MsgOnSocket)] +#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)] pub struct BusAccessInfo { /// Offset from base address that the device was accessed at. pub offset: u64, diff --git a/devices/src/direct_io.rs b/devices/src/direct_io.rs new file mode 100644 index 000000000..2f80d7eee --- /dev/null +++ b/devices/src/direct_io.rs @@ -0,0 +1,68 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use crate::{BusAccessInfo, BusDevice, BusDeviceSync}; +use std::fs::{File, OpenOptions}; +use std::io; +use std::os::unix::prelude::FileExt; +use std::path::Path; +use std::sync::Mutex; + +pub struct DirectIo { + dev: Mutex<File>, + read_only: bool, +} + +impl DirectIo { + /// Create simple direct I/O access device. + pub fn new(path: &Path, read_only: bool) -> Result<Self, io::Error> { + let dev = OpenOptions::new().read(true).write(!read_only).open(path)?; + Ok(DirectIo { + dev: Mutex::new(dev), + read_only, + }) + } + + fn iowr(&self, port: u64, data: &[u8]) { + if !self.read_only { + if let Ok(ref mut dev) = self.dev.lock() { + let _ = dev.write_all_at(data, port); + } + } + } + + fn iord(&self, port: u64, data: &mut [u8]) { + if let Ok(ref mut dev) = self.dev.lock() { + let _ = dev.read_exact_at(data, port); + } + } +} + +impl BusDevice for DirectIo { + fn debug_label(&self) -> String { + "direct-io".to_string() + } + + /// Reads at `offset` from this device + fn read(&mut self, ai: BusAccessInfo, data: &mut [u8]) { + self.iord(ai.address, data); + } + + /// Writes at `offset` into this device + fn write(&mut self, ai: BusAccessInfo, data: &[u8]) { + self.iowr(ai.address, data); + } +} + +impl BusDeviceSync for DirectIo { + /// Reads at `offset` from this device + fn read(&self, ai: BusAccessInfo, data: &mut [u8]) { + self.iord(ai.address, data); + } + + /// Writes at `offset` into this device + fn write(&self, ai: BusAccessInfo, data: &[u8]) { + self.iowr(ai.address, data); + } +} diff --git a/devices/src/direct_irq.rs b/devices/src/direct_irq.rs new file mode 100644 index 000000000..fe44c4c4d --- /dev/null +++ b/devices/src/direct_irq.rs @@ -0,0 +1,115 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use base::{ioctl_with_ref, AsRawDescriptor, Event, RawDescriptor}; +use data_model::vec_with_array_field; +use std::fmt; +use std::fs::{File, OpenOptions}; +use std::io; +use std::mem::size_of; + +use vfio_sys::*; + +#[derive(Debug)] +pub enum DirectIrqError { + Open(io::Error), + Enable, +} + +impl fmt::Display for DirectIrqError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DirectIrqError::Open(e) => write!(f, "failed to open /dev/plat-irq-forward: {}", e), + DirectIrqError::Enable => write!(f, "failed to enable direct irq"), + } + } +} + +pub struct DirectIrq { + dev: File, + trigger: Event, + resample: Option<Event>, +} + +impl DirectIrq { + /// Create DirectIrq object to access hardware triggered interrupts. + pub fn new(trigger: Event, resample: Option<Event>) -> Result<Self, DirectIrqError> { + let dev = OpenOptions::new() + .read(true) + .write(true) + .open("/dev/plat-irq-forward") + .map_err(DirectIrqError::Open)?; + Ok(DirectIrq { + dev, + trigger, + resample, + }) + } + + /// Enable hardware triggered interrupt handling. + /// + /// Note: this feature is not part of VFIO, but provides + /// missing IRQ forwarding functionality. + /// + /// # Arguments + /// + /// * `irq_num` - host interrupt number (GSI). 
+ /// + pub fn irq_enable(&self, irq_num: u32) -> Result<(), DirectIrqError> { + if let Some(resample) = &self.resample { + self.plat_irq_ioctl( + irq_num, + PLAT_IRQ_FORWARD_SET_LEVEL_TRIGGER_EVENTFD, + self.trigger.as_raw_descriptor(), + )?; + self.plat_irq_ioctl( + irq_num, + PLAT_IRQ_FORWARD_SET_LEVEL_UNMASK_EVENTFD, + resample.as_raw_descriptor(), + )?; + } else { + self.plat_irq_ioctl( + irq_num, + PLAT_IRQ_FORWARD_SET_EDGE_TRIGGER, + self.trigger.as_raw_descriptor(), + )?; + }; + + Ok(()) + } + + fn plat_irq_ioctl( + &self, + irq_num: u32, + action: u32, + fd: RawDescriptor, + ) -> Result<(), DirectIrqError> { + let count = 1; + let u32_size = size_of::<u32>(); + let mut irq_set = vec_with_array_field::<plat_irq_forward_set, u32>(count); + irq_set[0].argsz = (size_of::<plat_irq_forward_set>() + count * u32_size) as u32; + irq_set[0].action_flags = action; + irq_set[0].count = count as u32; + irq_set[0].irq_number_host = irq_num; + // Safe as we are the owner of irq_set and allocation provides enough space for + // eventfd array. + let data = unsafe { irq_set[0].eventfd.as_mut_slice(count * u32_size) }; + let (left, _right) = data.split_at_mut(u32_size); + left.copy_from_slice(&fd.to_ne_bytes()[..]); + + // Safe as we are the owner of plat_irq_forward and irq_set which are valid value + let ret = unsafe { ioctl_with_ref(self, PLAT_IRQ_FORWARD_SET(), &irq_set[0]) }; + if ret < 0 { + Err(DirectIrqError::Enable) + } else { + Ok(()) + } + } +} + +impl AsRawDescriptor for DirectIrq { + fn as_raw_descriptor(&self) -> i32 { + self.dev.as_raw_descriptor() + } +} diff --git a/devices/src/irqchip/ioapic.rs b/devices/src/irqchip/ioapic.rs index 2062bac10..a2117489f 100644 --- a/devices/src/irqchip/ioapic.rs +++ b/devices/src/irqchip/ioapic.rs @@ -2,20 +2,24 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. 
-// Implementation of an intel 82093AA Input/Output Advanced Programmable Interrupt Controller -// See https://pdos.csail.mit.edu/6.828/2016/readings/ia32/ioapic.pdf for a specification. +// Implementation of an Intel ICH10 Input/Output Advanced Programmable Interrupt Controller +// See https://www.intel.com/content/dam/doc/datasheet/io-controller-hub-10-family-datasheet.pdf +// for a specification. use std::fmt::{self, Display}; use super::IrqEvent; use crate::bus::BusAccessInfo; use crate::BusDevice; -use base::{error, warn, AsRawDescriptor, Error, Event, Result}; -use hypervisor::{IoapicState, MsiAddressMessage, MsiDataMessage, TriggerMode, NUM_IOAPIC_PINS}; -use msg_socket::{MsgError, MsgReceiver, MsgSender}; -use vm_control::{MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse}; - -const IOAPIC_VERSION_ID: u32 = 0x00170011; +use base::{error, warn, Error, Event, Result, Tube, TubeError}; +use hypervisor::{ + IoapicRedirectionTableEntry, IoapicState, MsiAddressMessage, MsiDataMessage, TriggerMode, + NUM_IOAPIC_PINS, +}; +use vm_control::{VmIrqRequest, VmIrqResponse}; + +// ICH10 I/O APIC version: 0x20 +const IOAPIC_VERSION_ID: u32 = 0x00000020; pub const IOAPIC_BASE_ADDRESS: u64 = 0xfec00000; // The Intel manual does not specify this size, but KVM uses it. pub const IOAPIC_MEM_LENGTH_BYTES: u64 = 0x100; @@ -29,6 +33,8 @@ const IOAPIC_REG_ARBITRATION_ID: u8 = 0x02; const IOREGSEL_OFF: u8 = 0x0; const IOREGSEL_DUMMY_UPPER_32_BITS_OFF: u8 = 0x4; const IOWIN_OFF: u8 = 0x10; +const IOEOIR_OFF: u8 = 0x40; + const IOWIN_SCALE: u8 = 0x2; /// Given an IRQ and whether or not the selector should refer to the high bits, return a selector @@ -57,8 +63,12 @@ fn decode_irq_from_selector(selector: u8) -> (usize, bool) { const RTC_IRQ: usize = 0x8; pub struct Ioapic { - /// State of the ioapic registers - state: IoapicState, + /// Number of supported IO-APIC inputs / redirection entries. + num_pins: usize, + /// ioregsel register. 
Used for selecting which entry of the redirect table to read/write. + ioregsel: u8, + /// ioapicid register. Bits 24 - 27 contain the APIC ID for this device. + ioapicid: u32, /// Remote IRR for Edge Triggered Real Time Clock interrupts, which allows the CMOS to know when /// one of its interrupts is being coalesced. rtc_remote_irr: bool, @@ -67,8 +77,12 @@ pub struct Ioapic { /// Events that should be triggered on an EOI. The outer Vec is indexed by GSI, and the inner /// Vec is an unordered list of registered resample events for the GSI. resample_events: Vec<Vec<Event>>, - /// Socket used to route MSI irqs - irq_socket: VmIrqRequestSocket, + /// Redirection settings for each irq line. + redirect_table: Vec<IoapicRedirectionTableEntry>, + /// Interrupt activation state. + interrupt_level: Vec<bool>, + /// Tube used to route MSI irqs. + irq_tube: Tube, } impl BusDevice for Ioapic { @@ -85,9 +99,10 @@ impl BusDevice for Ioapic { warn!("IOAPIC: Bad read from {}", info); } let out = match info.offset as u8 { - IOREGSEL_OFF => self.state.ioregsel.into(), + IOREGSEL_OFF => self.ioregsel.into(), IOREGSEL_DUMMY_UPPER_32_BITS_OFF => 0, IOWIN_OFF => self.ioapic_read(), + IOEOIR_OFF => 0, _ => { warn!("IOAPIC: Bad read from {}", info); return; @@ -110,7 +125,7 @@ impl BusDevice for Ioapic { warn!("IOAPIC: Bad write to {}", info); } match info.offset as u8 { - IOREGSEL_OFF => self.state.ioregsel = data[0], + IOREGSEL_OFF => self.ioregsel = data[0], IOREGSEL_DUMMY_UPPER_32_BITS_OFF => {} // Ignored. 
IOWIN_OFF => { if data.len() != 4 { @@ -121,6 +136,7 @@ impl BusDevice for Ioapic { let val = u32::from_ne_bytes(data_arr); self.ioapic_write(val); } + IOEOIR_OFF => self.end_of_interrupt(data[0]), _ => { warn!("IOAPIC: Bad write to {}", info); } @@ -129,28 +145,66 @@ impl BusDevice for Ioapic { } impl Ioapic { - pub fn new(irq_socket: VmIrqRequestSocket) -> Result<Ioapic> { - let mut state = IoapicState::default(); - - for i in 0..NUM_IOAPIC_PINS { - state.redirect_table[i].set_interrupt_mask(true); - } - + pub fn new(irq_tube: Tube, num_pins: usize) -> Result<Ioapic> { + let num_pins = num_pins.max(NUM_IOAPIC_PINS as usize); + let mut entry = IoapicRedirectionTableEntry::new(); + entry.set_interrupt_mask(true); Ok(Ioapic { - state, + num_pins, + ioregsel: 0, + ioapicid: 0, rtc_remote_irr: false, - out_events: (0..NUM_IOAPIC_PINS).map(|_| None).collect(), + out_events: (0..num_pins).map(|_| None).collect(), resample_events: Vec::new(), - irq_socket, + redirect_table: (0..num_pins).map(|_| entry.clone()).collect(), + interrupt_level: (0..num_pins).map(|_| false).collect(), + irq_tube, }) } pub fn get_ioapic_state(&self) -> IoapicState { - self.state + // Convert vector of first NUM_IOAPIC_PINS active interrupts into an u32 value. 
+ let level_bitmap = self + .interrupt_level + .iter() + .take(NUM_IOAPIC_PINS) + .rev() + .fold(0, |acc, &l| acc * 2 + l as u32); + let mut state = IoapicState { + base_address: IOAPIC_BASE_ADDRESS, + ioregsel: self.ioregsel, + ioapicid: self.ioapicid, + current_interrupt_level_bitmap: level_bitmap, + ..Default::default() + }; + for (dst, src) in state + .redirect_table + .iter_mut() + .zip(self.redirect_table.iter()) + { + *dst = *src; + } + state } pub fn set_ioapic_state(&mut self, state: &IoapicState) { - self.state = *state + self.ioregsel = state.ioregsel; + self.ioapicid = state.ioapicid & 0x0f00_0000; + for (src, dst) in state + .redirect_table + .iter() + .zip(self.redirect_table.iter_mut()) + { + *dst = *src; + } + for (i, level) in self + .interrupt_level + .iter_mut() + .take(NUM_IOAPIC_PINS) + .enumerate() + { + *level = state.current_interrupt_level_bitmap & (1 << i) != 0; + } } pub fn register_resample_events(&mut self, resample_events: Vec<Vec<Event>>) { @@ -160,14 +214,14 @@ impl Ioapic { // The ioapic must be informed about EOIs in order to avoid sending multiple interrupts of the // same type at the same time. 
pub fn end_of_interrupt(&mut self, vector: u8) { - if self.state.redirect_table[RTC_IRQ].get_vector() == vector && self.rtc_remote_irr { + if self.redirect_table[RTC_IRQ].get_vector() == vector && self.rtc_remote_irr { // Specifically clear RTC IRQ field self.rtc_remote_irr = false; } - for i in 0..NUM_IOAPIC_PINS { - if self.state.redirect_table[i].get_vector() == vector - && self.state.redirect_table[i].get_trigger_mode() == TriggerMode::Level + for i in 0..self.num_pins { + if self.redirect_table[i].get_vector() == vector + && self.redirect_table[i].get_trigger_mode() == TriggerMode::Level { if self .resample_events @@ -182,34 +236,32 @@ impl Ioapic { resample_evt.write(1).unwrap(); } } - self.state.redirect_table[i].set_remote_irr(false); + self.redirect_table[i].set_remote_irr(false); } // There is an inherent race condition in hardware if the OS is finished processing an // interrupt and a new interrupt is delivered between issuing an EOI and the EOI being // completed. When that happens the ioapic is supposed to re-inject the interrupt. - if self.state.current_interrupt_level_bitmap & (1 << i) != 0 { + if self.interrupt_level[i] { self.service_irq(i, true); } } } pub fn service_irq(&mut self, irq: usize, level: bool) -> bool { - let entry = &mut self.state.redirect_table[irq]; + let entry = &mut self.redirect_table[irq]; // De-assert the interrupt. if !level { - self.state.current_interrupt_level_bitmap &= !(1 << irq); + self.interrupt_level[irq] = false; return true; } // If it's an edge-triggered interrupt that's already high we ignore it. - if entry.get_trigger_mode() == TriggerMode::Edge - && self.state.current_interrupt_level_bitmap & (1 << irq) != 0 - { + if entry.get_trigger_mode() == TriggerMode::Edge && self.interrupt_level[irq] { return false; } - self.state.current_interrupt_level_bitmap |= 1 << irq; + self.interrupt_level[irq] = true; // Interrupts are masked, so don't inject. 
if entry.get_interrupt_mask() { @@ -242,22 +294,22 @@ impl Ioapic { } fn ioapic_write(&mut self, val: u32) { - match self.state.ioregsel { + match self.ioregsel { IOAPIC_REG_VERSION => { /* read-only register */ } - IOAPIC_REG_ID => self.state.ioapicid = (val >> 24) & 0xf, + IOAPIC_REG_ID => self.ioapicid = val & 0x0f00_0000, IOAPIC_REG_ARBITRATION_ID => { /* read-only register */ } _ => { - if self.state.ioregsel < IOWIN_OFF { + if self.ioregsel < IOWIN_OFF { // Invalid write; ignore. return; } - let (index, is_high_bits) = decode_irq_from_selector(self.state.ioregsel); - if index >= hypervisor::NUM_IOAPIC_PINS { + let (index, is_high_bits) = decode_irq_from_selector(self.ioregsel); + if index >= self.num_pins { // Invalid write; ignore. return; } - let entry = &mut self.state.redirect_table[index]; + let entry = &mut self.redirect_table[index]; if is_high_bits { entry.set(32, 32, val.into()); } else { @@ -278,16 +330,16 @@ impl Ioapic { // is the fix for this. } - if self.state.redirect_table[index].get_trigger_mode() == TriggerMode::Level - && self.state.current_interrupt_level_bitmap & (1 << index) != 0 - && !self.state.redirect_table[index].get_interrupt_mask() + if self.redirect_table[index].get_trigger_mode() == TriggerMode::Level + && self.interrupt_level[index] + && !self.redirect_table[index].get_interrupt_mask() { self.service_irq(index, true); } let mut address = MsiAddressMessage::new(); let mut data = MsiDataMessage::new(); - let entry = &self.state.redirect_table[index]; + let entry = &self.redirect_table[index]; address.set_destination_mode(entry.get_dest_mode()); address.set_destination_id(entry.get_dest_id()); address.set_always_0xfee(0xfee); @@ -329,21 +381,22 @@ impl Ioapic { evt.gsi } else { let event = Event::new().map_err(IoapicError::CreateEvent)?; - let request = VmIrqRequest::AllocateOneMsi { - irqfd: MaybeOwnedDescriptor::Borrowed(event.as_raw_descriptor()), - }; - self.irq_socket + let request = VmIrqRequest::AllocateOneMsi { irqfd: 
event }; + self.irq_tube .send(&request) .map_err(IoapicError::AllocateOneMsiSend)?; match self - .irq_socket + .irq_tube .recv() .map_err(IoapicError::AllocateOneMsiRecv)? { VmIrqResponse::AllocateOneMsi { gsi, .. } => { self.out_events[index] = Some(IrqEvent { gsi, - event, + event: match request { + VmIrqRequest::AllocateOneMsi { irqfd } => irqfd, + _ => unreachable!(), + }, resample_event: None, }); gsi @@ -361,32 +414,28 @@ impl Ioapic { msi_address, msi_data, }; - self.irq_socket + self.irq_tube .send(&request) .map_err(IoapicError::AddMsiRouteSend)?; - if let VmIrqResponse::Err(e) = self - .irq_socket - .recv() - .map_err(IoapicError::AddMsiRouteRecv)? - { + if let VmIrqResponse::Err(e) = self.irq_tube.recv().map_err(IoapicError::AddMsiRouteRecv)? { return Err(IoapicError::AddMsiRoute(e)); } Ok(()) } fn ioapic_read(&mut self) -> u32 { - match self.state.ioregsel { - IOAPIC_REG_VERSION => IOAPIC_VERSION_ID, - IOAPIC_REG_ID | IOAPIC_REG_ARBITRATION_ID => (self.state.ioapicid & 0xf) << 24, + match self.ioregsel { + IOAPIC_REG_VERSION => ((self.num_pins - 1) as u32) << 16 | IOAPIC_VERSION_ID, + IOAPIC_REG_ID | IOAPIC_REG_ARBITRATION_ID => self.ioapicid, _ => { - if self.state.ioregsel < IOWIN_OFF { + if self.ioregsel < IOWIN_OFF { // Invalid read; ignore and return 0. 
0 } else { - let (index, is_high_bits) = decode_irq_from_selector(self.state.ioregsel); - if index < NUM_IOAPIC_PINS { + let (index, is_high_bits) = decode_irq_from_selector(self.ioregsel); + if index < self.num_pins { let offset = if is_high_bits { 32 } else { 0 }; - self.state.redirect_table[index].get(offset, 32) as u32 + self.redirect_table[index].get(offset, 32) as u32 } else { !0 // Invalid index - return all 1s } @@ -399,11 +448,11 @@ impl Ioapic { #[derive(Debug)] enum IoapicError { AddMsiRoute(Error), - AddMsiRouteRecv(MsgError), - AddMsiRouteSend(MsgError), + AddMsiRouteRecv(TubeError), + AddMsiRouteSend(TubeError), AllocateOneMsi(Error), - AllocateOneMsiRecv(MsgError), - AllocateOneMsiSend(MsgError), + AllocateOneMsiRecv(TubeError), + AllocateOneMsiSend(TubeError), CreateEvent(Error), } @@ -428,14 +477,14 @@ impl Display for IoapicError { #[cfg(test)] mod tests { use super::*; - use hypervisor::{DeliveryMode, DeliveryStatus, DestinationMode, IoapicRedirectionTableEntry}; + use hypervisor::{DeliveryMode, DeliveryStatus, DestinationMode}; const DEFAULT_VECTOR: u8 = 0x3a; const DEFAULT_DESTINATION_ID: u8 = 0x5f; fn new() -> Ioapic { - let (_, device_socket) = msg_socket::pair::<VmIrqResponse, VmIrqRequest>().unwrap(); - Ioapic::new(device_socket).unwrap() + let (_, irq_tube) = Tube::pair().unwrap(); + Ioapic::new(irq_tube, NUM_IOAPIC_PINS).unwrap() } fn ioapic_bus_address(offset: u8) -> BusAccessInfo { @@ -739,7 +788,7 @@ mod tests { fn remote_irr_read_only() { let (mut ioapic, irq) = set_up(TriggerMode::Level); - ioapic.state.redirect_table[irq].set_remote_irr(true); + ioapic.redirect_table[irq].set_remote_irr(true); let mut entry = read_entry(&mut ioapic, irq); entry.set_remote_irr(false); @@ -752,7 +801,7 @@ mod tests { fn delivery_status_read_only() { let (mut ioapic, irq) = set_up(TriggerMode::Level); - ioapic.state.redirect_table[irq].set_delivery_status(DeliveryStatus::Pending); + 
ioapic.redirect_table[irq].set_delivery_status(DeliveryStatus::Pending); let mut entry = read_entry(&mut ioapic, irq); entry.set_delivery_status(DeliveryStatus::Idle); @@ -768,7 +817,7 @@ mod tests { fn level_to_edge_transition_clears_remote_irr() { let (mut ioapic, irq) = set_up(TriggerMode::Level); - ioapic.state.redirect_table[irq].set_remote_irr(true); + ioapic.redirect_table[irq].set_remote_irr(true); let mut entry = read_entry(&mut ioapic, irq); entry.set_trigger_mode(TriggerMode::Edge); @@ -781,7 +830,7 @@ mod tests { fn masking_preserves_remote_irr() { let (mut ioapic, irq) = set_up(TriggerMode::Level); - ioapic.state.redirect_table[irq].set_remote_irr(true); + ioapic.redirect_table[irq].set_remote_irr(true); set_mask(&mut ioapic, irq, true); set_mask(&mut ioapic, irq, false); diff --git a/devices/src/irqchip/kvm/x86_64.rs b/devices/src/irqchip/kvm/x86_64.rs index 8b6d54627..613e1c4fd 100644 --- a/devices/src/irqchip/kvm/x86_64.rs +++ b/devices/src/irqchip/kvm/x86_64.rs @@ -18,8 +18,7 @@ use hypervisor::{ use kvm_sys::*; use resources::SystemAllocator; -use base::{error, Error, Event, Result}; -use vm_control::VmIrqRequestSocket; +use base::{error, Error, Event, Result, Tube}; use crate::irqchip::{ Ioapic, IrqEvent, IrqEventIndex, Pic, VcpuRunState, IOAPIC_BASE_ADDRESS, @@ -27,12 +26,12 @@ use crate::irqchip::{ }; use crate::{Bus, IrqChip, IrqChipCap, IrqChipX86_64, Pit, PitError}; -/// PIT channel 0 timer is connected to IRQ 0 +/// PIT tube 0 timer is connected to IRQ 0 const PIT_CHANNEL0_IRQ: u32 = 0; /// Default x86 routing table. Pins 0-7 go to primary pic and ioapic, pins 8-15 go to secondary /// pic and ioapic, and pins 16-23 go only to the ioapic. 
-fn kvm_default_irq_routing_table() -> Vec<IrqRoute> { +fn kvm_default_irq_routing_table(ioapic_pins: usize) -> Vec<IrqRoute> { let mut routes: Vec<IrqRoute> = Vec::new(); for i in 0..8 { @@ -43,7 +42,7 @@ fn kvm_default_irq_routing_table() -> Vec<IrqRoute> { routes.push(IrqRoute::pic_irq_route(IrqSourceChip::PicSecondary, i)); routes.push(IrqRoute::ioapic_irq_route(i)); } - for i in 16..NUM_IOAPIC_PINS as u32 { + for i in 16..ioapic_pins as u32 { routes.push(IrqRoute::ioapic_irq_route(i)); } @@ -68,7 +67,7 @@ impl KvmKernelIrqChip { Ok(KvmKernelIrqChip { vm, vcpus: Arc::new(Mutex::new((0..num_vcpus).map(|_| None).collect())), - routes: Arc::new(Mutex::new(kvm_default_irq_routing_table())), + routes: Arc::new(Mutex::new(kvm_default_irq_routing_table(NUM_IOAPIC_PINS))), }) } /// Attempt to create a shallow clone of this x86_64 KvmKernelIrqChip instance. @@ -146,6 +145,7 @@ pub struct KvmSplitIrqChip { pit: Arc<Mutex<Pit>>, pic: Arc<Mutex<Pic>>, ioapic: Arc<Mutex<Ioapic>>, + ioapic_pins: usize, /// Vec of ioapic irq events that have been delayed because the ioapic was locked when /// service_irq was called on the irqchip. This prevents deadlocks when a Vcpu thread has /// locked the ioapic and the ioapic sends a AddMsiRoute signal to the main thread (which @@ -155,9 +155,9 @@ pub struct KvmSplitIrqChip { irq_events: Arc<Mutex<Vec<Option<IrqEvent>>>>, } -fn kvm_dummy_msi_routes() -> Vec<IrqRoute> { +fn kvm_dummy_msi_routes(ioapic_pins: usize) -> Vec<IrqRoute> { let mut routes: Vec<IrqRoute> = Vec::new(); - for i in 0..NUM_IOAPIC_PINS { + for i in 0..ioapic_pins { routes.push( // Add dummy MSI routes to replace the default IRQChip routes. IrqRoute { @@ -174,9 +174,14 @@ fn kvm_dummy_msi_routes() -> Vec<IrqRoute> { impl KvmSplitIrqChip { /// Construct a new KvmSplitIrqChip. 
- pub fn new(vm: KvmVm, num_vcpus: usize, irq_socket: VmIrqRequestSocket) -> Result<Self> { - vm.enable_split_irqchip()?; - + pub fn new( + vm: KvmVm, + num_vcpus: usize, + irq_tube: Tube, + ioapic_pins: Option<usize>, + ) -> Result<Self> { + let ioapic_pins = ioapic_pins.unwrap_or(hypervisor::NUM_IOAPIC_PINS); + vm.enable_split_irqchip(ioapic_pins)?; let pit_evt = Event::new()?; let pit = Arc::new(Mutex::new( Pit::new(pit_evt.try_clone()?, Arc::new(Mutex::new(Clock::new()))).map_err( @@ -197,15 +202,16 @@ impl KvmSplitIrqChip { routes: Arc::new(Mutex::new(Vec::new())), pit, pic: Arc::new(Mutex::new(Pic::new())), - ioapic: Arc::new(Mutex::new(Ioapic::new(irq_socket)?)), + ioapic: Arc::new(Mutex::new(Ioapic::new(irq_tube, ioapic_pins)?)), + ioapic_pins, delayed_ioapic_irq_events: Arc::new(Mutex::new(Vec::new())), irq_events: Arc::new(Mutex::new(Default::default())), }; // Setup standard x86 irq routes - let mut routes = kvm_default_irq_routing_table(); - // Add dummy MSI routes for the first 24 GSIs - routes.append(&mut kvm_dummy_msi_routes()); + let mut routes = kvm_default_irq_routing_table(ioapic_pins); + // Add dummy MSI routes for the first ioapic_pins GSIs + routes.append(&mut kvm_dummy_msi_routes(ioapic_pins)); // Set the routes so they get sent to KVM chip.set_irq_routes(&routes)?; @@ -252,16 +258,15 @@ impl KvmSplitIrqChip { /// Check if the specified vcpu has any pending interrupts. Returns None for no interrupts, /// otherwise Some(u32) should be the injected interrupt vector. For KvmSplitIrqChip /// this calls get_external_interrupt on the pic. 
- fn get_external_interrupt(&self, vcpu_id: usize) -> Result<Option<u32>> { + fn get_external_interrupt(&self, vcpu_id: usize) -> Option<u32> { // Pic interrupts for the split irqchip only go to vcpu 0 if vcpu_id != 0 { - return Ok(None); - } - if let Some(vector) = self.pic.lock().get_external_interrupt() { - Ok(Some(vector as u32)) - } else { - Ok(None) + return None; } + self.pic + .lock() + .get_external_interrupt() + .map(|vector| vector as u32) } } @@ -313,7 +318,7 @@ impl IrqChip for KvmSplitIrqChip { irq_event: &Event, resample_event: Option<&Event>, ) -> Result<Option<IrqEventIndex>> { - if irq < NUM_IOAPIC_PINS as u32 { + if irq < self.ioapic_pins as u32 { let mut evt = IrqEvent { gsi: irq, event: irq_event.try_clone()?, @@ -336,7 +341,7 @@ impl IrqChip for KvmSplitIrqChip { /// Unregister an event for a particular GSI. fn unregister_irq_event(&mut self, irq: u32, irq_event: &Event) -> Result<()> { - if irq < NUM_IOAPIC_PINS as u32 { + if irq < self.ioapic_pins as u32 { let mut irq_events = self.irq_events.lock(); for (index, evt) in irq_events.iter().enumerate() { if let Some(evt) = evt { @@ -470,8 +475,8 @@ impl IrqChip for KvmSplitIrqChip { return Ok(()); } - if let Some(vector) = self.get_external_interrupt(vcpu_id)? 
{ - vcpu.interrupt(vector as u32)?; + if let Some(vector) = self.get_external_interrupt(vcpu_id) { + vcpu.interrupt(vector)?; } // The second interrupt request should be handled immediately, so ask vCPU to exit as soon as @@ -524,6 +529,7 @@ impl IrqChip for KvmSplitIrqChip { pit: self.pit.clone(), pic: self.pic.clone(), ioapic: self.ioapic.clone(), + ioapic_pins: self.ioapic_pins, delayed_ioapic_irq_events: self.delayed_ioapic_irq_events.clone(), irq_events: self.irq_events.clone(), }) @@ -558,13 +564,13 @@ impl IrqChip for KvmSplitIrqChip { // At this point, all of our devices have been created and they have registered their // irq events, so we can clone our resample events let mut ioapic_resample_events: Vec<Vec<Event>> = - (0..NUM_IOAPIC_PINS).map(|_| Vec::new()).collect(); + (0..self.ioapic_pins).map(|_| Vec::new()).collect(); let mut pic_resample_events: Vec<Vec<Event>> = - (0..NUM_IOAPIC_PINS).map(|_| Vec::new()).collect(); + (0..self.ioapic_pins).map(|_| Vec::new()).collect(); for evt in self.irq_events.lock().iter() { if let Some(evt) = evt { - if (evt.gsi as usize) >= NUM_IOAPIC_PINS { + if (evt.gsi as usize) >= self.ioapic_pins { continue; } if let Some(resample_evt) = &evt.resample_event { @@ -583,9 +589,9 @@ impl IrqChip for KvmSplitIrqChip { .lock() .register_resample_events(pic_resample_events); - // Make sure all future irq numbers are >= NUM_IOAPIC_PINS + // Make sure all future irq numbers are beyond IO-APIC range. 
let mut irq_num = resources.allocate_irq().unwrap(); - while irq_num < NUM_IOAPIC_PINS as u32 { + while irq_num < self.ioapic_pins as u32 { irq_num = resources.allocate_irq().unwrap(); } @@ -701,7 +707,6 @@ mod tests { use vm_memory::GuestMemory; use hypervisor::{IoapicRedirectionTableEntry, PitRWMode, TriggerMode, Vm, VmX86_64}; - use vm_control::{VmIrqRequest, VmIrqResponse}; use super::super::super::tests::*; use crate::IrqChip; @@ -728,13 +733,13 @@ mod tests { let mem = GuestMemory::new(&[]).unwrap(); let vm = KvmVm::new(&kvm, mem).expect("failed tso instantiate vm"); - let (_, device_socket) = - msg_socket::pair::<VmIrqResponse, VmIrqRequest>().expect("failed to create irq socket"); + let (_, device_tube) = Tube::pair().expect("failed to create irq tube"); let mut chip = KvmSplitIrqChip::new( vm.try_clone().expect("failed to clone vm"), 1, - device_socket, + device_tube, + None, ) .expect("failed to instantiate KvmKernelIrqChip"); @@ -940,7 +945,7 @@ mod tests { .expect("failed to get external interrupt"), // Vector is 9 because the interrupt vector base address is 0x08 and this is irq // line 1 and 8+1 = 9 - Some(0x9) + 0x9 ); assert_eq!( @@ -984,7 +989,7 @@ mod tests { assert_eq!( chip.get_external_interrupt(0) .expect("failed to get external interrupt"), - Some(0) + 0, ); // interrupt is not requested twice diff --git a/devices/src/lib.rs b/devices/src/lib.rs index 5299ef5a6..60b46934f 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -6,6 +6,10 @@ mod bus; mod cmos; +#[cfg(feature = "direct")] +pub mod direct_io; +#[cfg(feature = "direct")] +pub mod direct_irq; mod i8042; pub mod irqchip; mod pci; @@ -27,8 +31,12 @@ pub mod virtio; pub use self::acpi::ACPIPMResource; pub use self::bat::{BatteryError, GoldfishBattery}; pub use self::bus::Error as BusError; -pub use self::bus::{Bus, BusAccessInfo, BusDevice, BusRange, BusResumeDevice}; +pub use self::bus::{Bus, BusAccessInfo, BusDevice, BusDeviceSync, BusRange, BusResumeDevice}; pub use 
self::cmos::Cmos; +#[cfg(feature = "direct")] +pub use self::direct_io::DirectIo; +#[cfg(feature = "direct")] +pub use self::direct_irq::{DirectIrq, DirectIrqError}; pub use self::i8042::I8042Device; pub use self::irqchip::*; #[cfg(feature = "audio")] diff --git a/devices/src/pci/ac97.rs b/devices/src/pci/ac97.rs index b6517621d..c618b14d6 100644 --- a/devices/src/pci/ac97.rs +++ b/devices/src/pci/ac97.rs @@ -10,7 +10,7 @@ use std::str::FromStr; use audio_streams::shm_streams::{NullShmStreamSource, ShmStreamSource}; use base::{error, Event, RawDescriptor}; -use libcras::{CrasClient, CrasClientType, CrasSocketType}; +use libcras::{CrasClient, CrasClientType, CrasSocketType, CrasSysError}; use resources::{Alloc, MmioType, SystemAllocator}; use vm_memory::GuestMemory; @@ -22,9 +22,9 @@ use crate::pci::pci_configuration::{ }; use crate::pci::pci_device::{self, PciDevice, Result}; use crate::pci::{PciAddress, PciDeviceError, PciInterruptPin}; -#[cfg(not(target_os = "linux"))] +#[cfg(not(any(target_os = "linux", target_os = "android")))] use crate::virtio::snd::vios_backend::Error as VioSError; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "android"))] use crate::virtio::snd::vios_backend::VioSShmStreamSource; // Use 82801AA because it's what qemu does. @@ -84,6 +84,17 @@ pub struct Ac97Parameters { pub backend: Ac97Backend, pub capture: bool, pub vios_server_path: Option<PathBuf>, + client_type: Option<CrasClientType>, +} + +impl Ac97Parameters { + /// Set CRAS client type by given client type string. + /// + /// `client_type` - The client type string. + pub fn set_client_type(&mut self, client_type: &str) -> std::result::Result<(), CrasSysError> { + self.client_type = Some(client_type.parse()?); + Ok(()) + } } pub struct Ac97Dev { @@ -115,6 +126,7 @@ impl Ac97Dev { PciHeaderType::Device, 0x8086, // Subsystem Vendor ID 0x1, // Subsystem ID. + 0, // Revision ID. 
); Self { @@ -137,10 +149,10 @@ impl Ac97Dev { "Ac97Dev: create_cras_audio_device: {}. Fallback to null audio device", e ); - Self::create_null_audio_device(mem) + Ok(Self::create_null_audio_device(mem)) }), Ac97Backend::VIOS => Self::create_vios_audio_device(mem, param), - Ac97Backend::NULL => Self::create_null_audio_device(mem), + Ac97Backend::NULL => Ok(Self::create_null_audio_device(mem)), } } @@ -158,7 +170,11 @@ impl Ac97Dev { CrasClient::with_type(CrasSocketType::Unified) .map_err(pci_device::Error::CreateCrasClientFailed)?, ); - server.set_client_type(CrasClientType::CRAS_CLIENT_TYPE_CROSVM); + server.set_client_type( + params + .client_type + .unwrap_or(CrasClientType::CRAS_CLIENT_TYPE_CROSVM), + ); if params.capture { server.enable_cras_capture(); } @@ -168,7 +184,7 @@ impl Ac97Dev { } fn create_vios_audio_device(mem: GuestMemory, param: Ac97Parameters) -> Result<Self> { - #[cfg(target_os = "linux")] + #[cfg(any(target_os = "linux", target_os = "android"))] { let server = Box::new( // The presence of vios_server_path is checked during argument parsing @@ -178,16 +194,15 @@ impl Ac97Dev { let vios_audio = Self::new(mem, Ac97Backend::VIOS, server); return Ok(vios_audio); } - #[cfg(not(target_os = "linux"))] + #[cfg(not(any(target_os = "linux", target_os = "android")))] Err(pci_device::Error::CreateViosClientFailed( VioSError::PlatformNotSupported, )) } - fn create_null_audio_device(mem: GuestMemory) -> Result<Self> { + fn create_null_audio_device(mem: GuestMemory) -> Self { let server = Box::new(NullShmStreamSource::new()); - let null_audio = Self::new(mem, Ac97Backend::NULL, server); - Ok(null_audio) + Self::new(mem, Ac97Backend::NULL, server) } fn read_mixer(&mut self, offset: u64, data: &mut [u8]) { diff --git a/devices/src/pci/ac97_bus_master.rs b/devices/src/pci/ac97_bus_master.rs index 321bf5cac..e121a6f7b 100644 --- a/devices/src/pci/ac97_bus_master.rs +++ b/devices/src/pci/ac97_bus_master.rs @@ -3,7 +3,6 @@ // found in the LICENSE file. 
use std::collections::VecDeque; -use std::convert::AsRef; use std::convert::TryInto; use std::fmt::{self, Display}; use std::sync::atomic::{AtomicBool, Ordering}; @@ -16,7 +15,8 @@ use audio_streams::{ BoxError, NoopStreamControl, SampleFormat, StreamControl, StreamDirection, StreamEffect, }; use base::{ - self, error, set_rt_prio_limit, set_rt_round_robin, warn, AsRawDescriptor, Event, RawDescriptor, + self, error, set_rt_prio_limit, set_rt_round_robin, warn, AsRawDescriptors, Event, + RawDescriptor, }; use sync::{Condvar, Mutex}; use vm_memory::{GuestAddress, GuestMemory}; @@ -70,22 +70,22 @@ impl Ac97BusMasterRegs { } } - fn channel_count(&self, func: Ac97Function) -> usize { - fn output_channel_count(glob_cnt: u32) -> usize { + fn tube_count(&self, func: Ac97Function) -> usize { + fn output_tube_count(glob_cnt: u32) -> usize { let val = (glob_cnt & GLOB_CNT_PCM_246_MASK) >> 20; match val { 0 => 2, 1 => 4, 2 => 6, _ => { - warn!("unknown channel_count: 0x{:x}", val); + warn!("unknown tube_count: 0x{:x}", val); 2 } } } match func { - Ac97Function::Output => output_channel_count(self.glob_cnt), + Ac97Function::Output => output_tube_count(self.glob_cnt), _ => DEVICE_INPUT_CHANNEL_COUNT, } } @@ -130,6 +130,8 @@ type GuestMemoryResult<T> = std::result::Result<T, GuestMemoryError>; enum AudioError { // Failed to create a new stream. CreateStream(BoxError), + // Failure to get regions from guest memory. + GuestRegion(GuestMemoryError), // Invalid buffer offset received from the audio server. InvalidBufferOffset, // Guest did not provide a buffer when needed. 
@@ -150,6 +152,7 @@ impl Display for AudioError { match self { CreateStream(e) => write!(f, "Failed to create audio stream: {}.", e), + GuestRegion(e) => write!(f, "Failed to get guest memory region: {}.", e), InvalidBufferOffset => write!(f, "Offset > max usize"), NoBufferAvailable => write!(f, "No buffer was available from the Guest"), ReadingGuestError(e) => write!(f, "Failed to read guest memory: {}.", e), @@ -255,7 +258,7 @@ impl Ac97BusMaster { /// Returns any file descriptors that need to be kept open when entering a jail. pub fn keep_rds(&self) -> Option<Vec<RawDescriptor>> { let mut rds = self.audio_server.keep_fds(); - rds.push(self.mem.as_raw_descriptor()); + rds.append(&mut self.mem.as_raw_descriptors()); Some(rds) } @@ -341,7 +344,7 @@ impl Ac97BusMaster { regs.po_regs.picb } else { // Estimate how many samples have been played since the last audio callback. - let num_channels = regs.channel_count(Ac97Function::Output) as u64; + let num_channels = regs.tube_count(Ac97Function::Output) as u64; let micros = regs.po_pointer_update_time.elapsed().subsec_micros(); // Round down to the next 10 millisecond boundary. The linux driver often // assumes that two rapid reads from picb will return the same value. @@ -562,7 +565,7 @@ impl Ac97BusMaster { let locked_regs = self.regs.lock(); let sample_rate = self.current_sample_rate(func, mixer); let buffer_samples = current_buffer_size(locked_regs.func_regs(func), &self.mem)?; - let num_channels = locked_regs.channel_count(func); + let num_channels = locked_regs.tube_count(func); let buffer_frames = buffer_samples / num_channels; let mut pending_buffers = VecDeque::with_capacity(2); @@ -588,7 +591,12 @@ impl Ac97BusMaster { sample_rate, buffer_frames, &Self::stream_effects(func), - self.mem.as_ref().inner(), + self.mem + .offset_region(starting_offsets[0]) + .map_err(|e| { + AudioError::GuestRegion(GuestMemoryError::ReadingGuestBufferAddress(e)) + })? 
+ .inner(), starting_offsets, ) .map_err(AudioError::CreateStream)?; @@ -712,7 +720,7 @@ fn next_guest_buffer( // 0 h l n // +++++++++......++++ (low > high && (low <= value || value <= high)) - }; + } // Check if // * we're halted @@ -729,7 +737,7 @@ fn next_guest_buffer( let offset = get_buffer_offset(func_regs, mem, index)? .try_into() .map_err(|_| AudioError::InvalidBufferOffset)?; - let frames = get_buffer_samples(func_regs, mem, index)? / regs.channel_count(func); + let frames = get_buffer_samples(func_regs, mem, index)? / regs.tube_count(func); Ok(Some(GuestBuffer { index, @@ -1037,7 +1045,7 @@ mod test { } #[test] - fn run_multi_channel_playback() { + fn run_multi_tube_playback() { start_playback(2, 48000); start_playback(4, 48000); start_playback(6, 48000); @@ -1087,7 +1095,7 @@ mod test { bm.writeb(PO_LVI_15, LVI_MASK, &mixer); assert_eq!(bm.readb(PO_CIV_14), 0); - // Set channel count and sample rate. + // Set tube count and sample rate. let mut cnt = bm.readl(GLOB_CNT_2C); cnt &= !GLOB_CNT_PCM_246_MASK; mixer.writew(MIXER_PCM_FRONT_DAC_RATE_2C, rate); diff --git a/devices/src/pci/ac97_mixer.rs b/devices/src/pci/ac97_mixer.rs index 7ab918d5e..bf1fcda54 100644 --- a/devices/src/pci/ac97_mixer.rs +++ b/devices/src/pci/ac97_mixer.rs @@ -119,7 +119,7 @@ impl Ac97Mixer { /// Returns the front sample rate (reg 0x2c). pub fn get_sample_rate(&self) -> u16 { // MIXER_PCM_FRONT_DAC_RATE_2C, MIXER_PCM_SURR_DAC_RATE_2E, and MIXER_PCM_LFE_DAC_RATE_30 - // are updated to the same rate when playback with 2,4 and 6 channels. + // are updated to the same rate when playback with 2,4 and 6 tubes. self.pcm_front_dac_rate } diff --git a/devices/src/pci/ac97_regs.rs b/devices/src/pci/ac97_regs.rs index 26ac45a5c..afe06998f 100644 --- a/devices/src/pci/ac97_regs.rs +++ b/devices/src/pci/ac97_regs.rs @@ -58,7 +58,7 @@ pub const MIXER_EI_SDAC: u16 = 0x0080; // PCM Surround DAC is available. pub const MIXER_EI_LDAC: u16 = 0x0100; // PCM LFE DAC is available. 
// Basic capabilities for MIXER_RESET_00 -pub const BC_DEDICATED_MIC: u16 = 0x0001; /* Dedicated Mic PCM In Channel */ +pub const BC_DEDICATED_MIC: u16 = 0x0001; /* Dedicated Mic PCM In Tube */ // Bus Master regs from ICH spec: // 00h PI_BDBAR PCM In Buffer Descriptor list Base Address Register @@ -93,14 +93,14 @@ pub const GLOB_CNT_WARM_RESET: u32 = 0x0000_0004; pub const GLOB_CNT_STABLE_BITS: u32 = 0x0000_007f; // Bits not affected by reset. // PCM 4/6 Enable bits -pub const GLOB_CNT_PCM_2: u32 = 0x0000_0000; // 2 channels -pub const GLOB_CNT_PCM_4: u32 = 0x0010_0000; // 4 channels -pub const GLOB_CNT_PCM_6: u32 = 0x0020_0000; // 6 channels -pub const GLOB_CNT_PCM_246_MASK: u32 = GLOB_CNT_PCM_4 | GLOB_CNT_PCM_6; // channel mask +pub const GLOB_CNT_PCM_2: u32 = 0x0000_0000; // 2 tubes +pub const GLOB_CNT_PCM_4: u32 = 0x0010_0000; // 4 tubes +pub const GLOB_CNT_PCM_6: u32 = 0x0020_0000; // 6 tubes +pub const GLOB_CNT_PCM_246_MASK: u32 = GLOB_CNT_PCM_4 | GLOB_CNT_PCM_6; // tube mask // Global status pub const GLOB_STA_30: u64 = 0x30; -// Primary codec ready set and turn on D20:21 to support 4 and 6 channels on PCM out. +// Primary codec ready set and turn on D20:21 to support 4 and 6 tubes on PCM out. pub const GLOB_STA_RESET_VAL: u32 = 0x0030_0100; // glob_sta bits diff --git a/devices/src/pci/msix.rs b/devices/src/pci/msix.rs index 7240793c3..2852aefd6 100644 --- a/devices/src/pci/msix.rs +++ b/devices/src/pci/msix.rs @@ -3,11 +3,11 @@ // found in the LICENSE file. 
use crate::pci::{PciCapability, PciCapabilityID}; -use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor}; -use msg_socket::{MsgError, MsgReceiver, MsgSender}; +use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor, Tube, TubeError}; + use std::convert::TryInto; use std::fmt::{self, Display}; -use vm_control::{MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse}; +use vm_control::{VmIrqRequest, VmIrqResponse}; use data_model::DataInit; @@ -55,17 +55,17 @@ pub struct MsixConfig { irq_vec: Vec<IrqfdGsi>, masked: bool, enabled: bool, - msi_device_socket: VmIrqRequestSocket, + msi_device_socket: Tube, msix_num: u16, } enum MsixError { AddMsiRoute(SysError), - AddMsiRouteRecv(MsgError), - AddMsiRouteSend(MsgError), + AddMsiRouteRecv(TubeError), + AddMsiRouteSend(TubeError), AllocateOneMsi(SysError), - AllocateOneMsiRecv(MsgError), - AllocateOneMsiSend(MsgError), + AllocateOneMsiRecv(TubeError), + AllocateOneMsiSend(TubeError), } impl Display for MsixError { @@ -94,7 +94,7 @@ pub enum MsixStatus { } impl MsixConfig { - pub fn new(msix_vectors: u16, vm_socket: VmIrqRequestSocket) -> Self { + pub fn new(msix_vectors: u16, vm_socket: Tube) -> Self { assert!(msix_vectors <= MAX_MSIX_VECTORS_PER_DEVICE); let mut table_entries: Vec<MsixTableEntry> = Vec::new(); @@ -235,10 +235,9 @@ impl MsixConfig { self.irq_vec.clear(); for i in 0..self.msix_num { let irqfd = Event::new().unwrap(); + let request = VmIrqRequest::AllocateOneMsi { irqfd }; self.msi_device_socket - .send(&VmIrqRequest::AllocateOneMsi { - irqfd: MaybeOwnedDescriptor::Borrowed(irqfd.as_raw_descriptor()), - }) + .send(&request) .map_err(MsixError::AllocateOneMsiSend)?; let irq_num: u32; match self @@ -251,7 +250,10 @@ impl MsixConfig { _ => unreachable!(), } self.irq_vec.push(IrqfdGsi { - irqfd, + irqfd: match request { + VmIrqRequest::AllocateOneMsi { irqfd } => irqfd, + _ => unreachable!(), + }, gsi: irq_num, }); @@ -498,7 +500,7 @@ impl MsixConfig { /// 
Return the raw fd of the MSI device socket pub fn get_msi_socket(&self) -> RawDescriptor { - self.msi_device_socket.as_ref().as_raw_descriptor() + self.msi_device_socket.as_raw_descriptor() } /// Return irqfd of MSI-X Table entry diff --git a/devices/src/pci/pci_configuration.rs b/devices/src/pci/pci_configuration.rs index 0ae8f585d..8654b0171 100644 --- a/devices/src/pci/pci_configuration.rs +++ b/devices/src/pci/pci_configuration.rs @@ -24,7 +24,7 @@ const BAR_MEM_MIN_SIZE: u64 = 16; const NUM_BAR_REGS: usize = 6; const CAPABILITY_LIST_HEAD_OFFSET: usize = 0x34; const FIRST_CAPABILITY_OFFSET: usize = 0x40; -const CAPABILITY_MAX_OFFSET: usize = 192; +const CAPABILITY_MAX_OFFSET: usize = 255; const INTERRUPT_LINE_PIN_REG: usize = 15; @@ -292,6 +292,7 @@ impl PciConfiguration { header_type: PciHeaderType, subsystem_vendor_id: u16, subsystem_id: u16, + revision_id: u8, ) -> Self { let mut registers = [0u32; NUM_CONFIGURATION_REGISTERS]; let mut writable_bits = [0u32; NUM_CONFIGURATION_REGISTERS]; @@ -305,7 +306,8 @@ impl PciConfiguration { }; registers[2] = u32::from(class_code.get_register_value()) << 24 | u32::from(subclass.get_register_value()) << 16 - | u32::from(pi) << 8; + | u32::from(pi) << 8 + | u32::from(revision_id); writable_bits[3] = 0x0000_00ff; // Cacheline size (r/w) match header_type { PciHeaderType::Device => { @@ -471,11 +473,17 @@ impl PciConfiguration { } let (mask, lower_bits) = match config.region_type { - PciBarRegionType::Memory32BitRegion | PciBarRegionType::Memory64BitRegion => ( - BAR_MEM_ADDR_MASK, - config.prefetchable as u32 | config.region_type as u32, - ), - PciBarRegionType::IORegion => (BAR_IO_ADDR_MASK, config.region_type as u32), + PciBarRegionType::Memory32BitRegion | PciBarRegionType::Memory64BitRegion => { + self.registers[COMMAND_REG] |= COMMAND_REG_MEMORY_SPACE_MASK; + ( + BAR_MEM_ADDR_MASK, + config.prefetchable as u32 | config.region_type as u32, + ) + } + PciBarRegionType::IORegion => { + self.registers[COMMAND_REG] |= 
COMMAND_REG_IO_SPACE_MASK; + (BAR_IO_ADDR_MASK, config.region_type as u32) + } }; self.registers[bar_idx] = ((config.addr as u32) & mask) | lower_bits; @@ -673,6 +681,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); // Add two capabilities with different contents. @@ -734,6 +743,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); let class_reg = cfg.read_reg(2); @@ -756,6 +766,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); // Attempt to overwrite vendor ID and device ID, which are read-only @@ -775,6 +786,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); // No BAR 0 has been configured, so these should return None or 0 as appropriate. @@ -796,6 +808,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); cfg.add_pci_bar( @@ -842,6 +855,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); cfg.add_pci_bar( @@ -887,6 +901,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); cfg.add_pci_bar( @@ -929,6 +944,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); // bar_num 0-1: 64-bit memory @@ -1050,6 +1066,7 @@ mod tests { PciHeaderType::Device, 0xABCD, 0x2468, + 0, ); // I/O BAR with size 2 (too small) diff --git a/devices/src/pci/pci_device.rs b/devices/src/pci/pci_device.rs index 25e37cdc3..e2a1b67e2 100644 --- a/devices/src/pci/pci_device.rs +++ b/devices/src/pci/pci_device.rs @@ -282,6 +282,7 @@ mod tests { PciHeaderType::Device, 0x5678, 0xEF01, + 0, ), }; diff --git a/devices/src/pci/pci_root.rs b/devices/src/pci/pci_root.rs index d2f42cb10..4eb1b3407 100644 --- a/devices/src/pci/pci_root.rs +++ b/devices/src/pci/pci_root.rs @@ -145,6 +145,7 @@ impl PciRoot { PciHeaderType::Device, 0, 0, + 0, ), }, devices: BTreeMap::new(), diff --git a/devices/src/pci/vfio_pci.rs b/devices/src/pci/vfio_pci.rs index ae043d1bd..46ee5372b 100644 --- a/devices/src/pci/vfio_pci.rs +++ b/devices/src/pci/vfio_pci.rs @@ -6,17 +6,15 @@ use std::sync::Arc; use std::u32; use base::{ - error, 
AsRawDescriptor, Event, MappedRegion, MemoryMapping, MemoryMappingBuilder, RawDescriptor, + error, pagesize, AsRawDescriptor, Event, MappedRegion, MemoryMapping, MemoryMappingBuilder, + RawDescriptor, Tube, }; use hypervisor::Datamatch; -use msg_socket::{MsgReceiver, MsgSender}; + use resources::{Alloc, MmioType, SystemAllocator}; use vfio_sys::*; -use vm_control::{ - MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse, - VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse, -}; +use vm_control::{VmIrqRequest, VmIrqResponse, VmMemoryRequest, VmMemoryResponse}; use crate::pci::msix::{ MsixConfig, BITS_PER_PBA_ENTRY, MSIX_PBA_ENTRIES_MODULO, MSIX_TABLE_ENTRIES_MODULO, @@ -128,13 +126,13 @@ struct VfioMsiCap { ctl: u16, address: u64, data: u16, - vm_socket_irq: VmIrqRequestSocket, + vm_socket_irq: Tube, irqfd: Option<Event>, gsi: Option<u32>, } impl VfioMsiCap { - fn new(config: &VfioPciConfig, msi_cap_start: u32, vm_socket_irq: VmIrqRequestSocket) -> Self { + fn new(config: &VfioPciConfig, msi_cap_start: u32, vm_socket_irq: Tube) -> Self { let msi_ctl = config.read_config_word(msi_cap_start + PCI_MSI_FLAGS); VfioMsiCap { @@ -253,19 +251,27 @@ impl VfioMsiCap { } fn allocate_one_msi(&mut self) { - if self.irqfd.is_none() { - match Event::new() { - Ok(fd) => self.irqfd = Some(fd), + let irqfd = match self.irqfd.take() { + Some(e) => e, + None => match Event::new() { + Ok(e) => e, Err(e) => { error!("failed to create event: {:?}", e); return; } - }; - } + }, + }; - if let Err(e) = self.vm_socket_irq.send(&VmIrqRequest::AllocateOneMsi { - irqfd: MaybeOwnedDescriptor::Borrowed(self.irqfd.as_ref().unwrap().as_raw_descriptor()), - }) { + let request = VmIrqRequest::AllocateOneMsi { irqfd }; + let request_result = self.vm_socket_irq.send(&request); + + // Stash the irqfd in self immediately because we used take above. 
+ self.irqfd = match request { + VmIrqRequest::AllocateOneMsi { irqfd } => Some(irqfd), + _ => unreachable!(), + }; + + if let Err(e) = request_result { error!("failed to send AllocateOneMsi request: {:?}", e); return; } @@ -310,7 +316,7 @@ struct VfioMsixCap { } impl VfioMsixCap { - fn new(config: &VfioPciConfig, msix_cap_start: u32, vm_socket_irq: VmIrqRequestSocket) -> Self { + fn new(config: &VfioPciConfig, msix_cap_start: u32, vm_socket_irq: Tube) -> Self { let msix_ctl = config.read_config_word(msix_cap_start + PCI_MSIX_FLAGS); let table_size = (msix_ctl & PCI_MSIX_FLAGS_QSIZE) + 1; let table = config.read_config_dword(msix_cap_start + PCI_MSIX_TABLE); @@ -441,7 +447,7 @@ pub struct VfioPciDevice { msi_cap: Option<VfioMsiCap>, msix_cap: Option<VfioMsixCap>, irq_type: Option<VfioIrqType>, - vm_socket_mem: VmMemoryControlRequestSocket, + vm_socket_mem: Tube, device_data: Option<DeviceData>, // scratch MemoryMapping to avoid unmap beform vm exit @@ -452,9 +458,9 @@ impl VfioPciDevice { /// Constructs a new Vfio Pci device for the give Vfio device pub fn new( device: VfioDevice, - vfio_device_socket_msi: VmIrqRequestSocket, - vfio_device_socket_msix: VmIrqRequestSocket, - vfio_device_socket_mem: VmMemoryControlRequestSocket, + vfio_device_socket_msi: Tube, + vfio_device_socket_msix: Tube, + vfio_device_socket_mem: Tube, ) -> Self { let dev = Arc::new(device); let config = VfioPciConfig::new(Arc::clone(&dev)); @@ -543,22 +549,25 @@ impl VfioPciDevice { if let Some(ref interrupt_evt) = self.interrupt_evt { let mut fds = Vec::new(); fds.push(interrupt_evt); - if let Err(e) = self.device.irq_enable(fds, VfioIrqType::Intx) { + if let Err(e) = self.device.irq_enable(fds, VFIO_PCI_INTX_IRQ_INDEX) { error!("Intx enable failed: {}", e); return; } if let Some(ref irq_resample_evt) = self.interrupt_resample_evt { - if let Err(e) = self.device.irq_mask(VfioIrqType::Intx) { + if let Err(e) = self.device.irq_mask(VFIO_PCI_INTX_IRQ_INDEX) { error!("Intx mask failed: {}", e); 
self.disable_intx(); return; } - if let Err(e) = self.device.resample_virq_enable(irq_resample_evt) { + if let Err(e) = self + .device + .resample_virq_enable(irq_resample_evt, VFIO_PCI_INTX_IRQ_INDEX) + { error!("resample enable failed: {}", e); self.disable_intx(); return; } - if let Err(e) = self.device.irq_unmask(VfioIrqType::Intx) { + if let Err(e) = self.device.irq_unmask(VFIO_PCI_INTX_IRQ_INDEX) { error!("Intx unmask failed: {}", e); self.disable_intx(); return; @@ -570,7 +579,7 @@ impl VfioPciDevice { } fn disable_intx(&mut self) { - if let Err(e) = self.device.irq_disable(VfioIrqType::Intx) { + if let Err(e) = self.device.irq_disable(VFIO_PCI_INTX_IRQ_INDEX) { error!("Intx disable failed: {}", e); } self.irq_type = None; @@ -610,7 +619,7 @@ impl VfioPciDevice { let mut fds = Vec::new(); fds.push(irqfd); - if let Err(e) = self.device.irq_enable(fds, VfioIrqType::Msi) { + if let Err(e) = self.device.irq_enable(fds, VFIO_PCI_MSI_IRQ_INDEX) { error!("failed to enable msi: {}", e); self.enable_intx(); return; @@ -620,7 +629,7 @@ impl VfioPciDevice { } fn disable_msi(&mut self) { - if let Err(e) = self.device.irq_disable(VfioIrqType::Msi) { + if let Err(e) = self.device.irq_disable(VFIO_PCI_MSI_IRQ_INDEX) { error!("failed to disable msi: {}", e); return; } @@ -637,7 +646,7 @@ impl VfioPciDevice { }; if let Some(descriptors) = irqfds { - if let Err(e) = self.device.irq_enable(descriptors, VfioIrqType::Msix) { + if let Err(e) = self.device.irq_enable(descriptors, VFIO_PCI_MSIX_IRQ_INDEX) { error!("failed to enable msix: {}", e); self.enable_intx(); return; @@ -651,7 +660,7 @@ impl VfioPciDevice { } fn disable_msix(&mut self) { - if let Err(e) = self.device.irq_disable(VfioIrqType::Msix) { + if let Err(e) = self.device.irq_disable(VFIO_PCI_MSIX_IRQ_INDEX) { error!("failed to disable msix: {}", e); return; } @@ -681,10 +690,14 @@ impl VfioPciDevice { let guest_map_start = bar_addr + mmap_offset; let region_offset = self.device.get_region_offset(index); let offset = 
region_offset + mmap_offset; + let descriptor = match self.device.device_file().try_clone() { + Ok(device_file) => device_file.into(), + Err(_) => break, + }; if self .vm_socket_mem .send(&VmMemoryRequest::RegisterMmapMemory { - descriptor: MaybeOwnedDescriptor::Borrowed(self.device.as_raw_descriptor()), + descriptor, size: mmap_size as usize, offset, gpa: guest_map_start, @@ -694,7 +707,7 @@ impl VfioPciDevice { break; } - let response = match self.vm_socket_mem.recv() { + let response: VmMemoryResponse = match self.vm_socket_mem.recv() { Ok(res) => res, Err(_) => break, }; @@ -704,7 +717,7 @@ impl VfioPciDevice { // device process doesn't has this mapping, but vfio_dma_map() need it // in device process, so here map it again. let mmap = match MemoryMappingBuilder::new(mmap_size as usize) - .from_descriptor(self.device.as_ref()) + .from_file(self.device.device_file()) .offset(offset) .build() { @@ -712,10 +725,13 @@ impl VfioPciDevice { Err(_e) => break, }; let host = (&mmap).as_ptr() as u64; + let pgsz = pagesize() as u64; + let size = (mmap_size + pgsz - 1) / pgsz * pgsz; // Safe because the given guest_map_start is valid guest bar address. and // the host pointer is correct and valid guaranteed by MemoryMapping interface. - match unsafe { self.device.vfio_dma_map(guest_map_start, mmap_size, host) } - { + // The size will be extened to page size aligned if it is not which is also + // safe because VFIO actually maps the BAR with page size aligned size. 
+ match unsafe { self.device.vfio_dma_map(guest_map_start, size, host) } { Ok(_) => mem_map.push(mmap), Err(e) => { error!( @@ -954,7 +970,7 @@ impl PciDevice for VfioPciDevice { let mut config = self.config.read_config_dword(reg); // Ignore IO bar - if reg >= 0x10 && reg <= 0x24 { + if (0x10..=0x24).contains(®) { for io_info in self.io_regions.iter() { if io_info.bar_index * 4 + 0x10 == reg { config = 0; diff --git a/devices/src/proxy.rs b/devices/src/proxy.rs index e044feb9b..5e22786d6 100644 --- a/devices/src/proxy.rs +++ b/devices/src/proxy.rs @@ -4,14 +4,14 @@ //! Runs hardware devices in child processes. +use std::ffi::CString; use std::fmt::{self, Display}; use std::time::Duration; -use std::{self, io}; -use base::{error, net::UnixSeqpacket, AsRawDescriptor, RawDescriptor}; +use base::{error, AsRawDescriptor, RawDescriptor, Tube, TubeError}; use libc::{self, pid_t}; use minijail::{self, Minijail}; -use msg_socket::{MsgOnSocket, MsgReceiver, MsgSender, MsgSocket}; +use serde::{Deserialize, Serialize}; use crate::bus::ConfigWriteResult; use crate::{BusAccessInfo, BusDevice}; @@ -20,7 +20,7 @@ use crate::{BusAccessInfo, BusDevice}; #[derive(Debug)] pub enum Error { ForkingJail(minijail::Error), - Io(io::Error), + Tube(TubeError), } pub type Result<T> = std::result::Result<T, Error>; @@ -30,14 +30,14 @@ impl Display for Error { match self { ForkingJail(e) => write!(f, "Failed to fork jail process: {}", e), - Io(e) => write!(f, "IO error configuring proxy device {}.", e), + Tube(e) => write!(f, "Failed to configure tube: {}.", e), } } } const SOCKET_TIMEOUT_MS: u64 = 2000; -#[derive(Debug, MsgOnSocket)] +#[derive(Debug, Serialize, Deserialize)] enum Command { Read { len: u32, @@ -57,8 +57,7 @@ enum Command { }, Shutdown, } - -#[derive(MsgOnSocket)] +#[derive(Debug, Serialize, Deserialize)] enum CommandResult { Ok, ReadResult([u8; 8]), @@ -69,12 +68,11 @@ enum CommandResult { }, } -fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) { +fn 
child_proc<D: BusDevice>(tube: Tube, device: &mut D) { let mut running = true; - let sock = MsgSocket::<CommandResult, Command>::new(sock); while running { - let cmd = match sock.recv() { + let cmd = match tube.recv() { Ok(cmd) => cmd, Err(err) => { error!("child device process failed recv: {}", err); @@ -86,7 +84,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) { Command::Read { len, info } => { let mut buffer = [0u8; 8]; device.read(info, &mut buffer[0..len as usize]); - sock.send(&CommandResult::ReadResult(buffer)) + tube.send(&CommandResult::ReadResult(buffer)) } Command::Write { len, info, data } => { let len = len as usize; @@ -96,7 +94,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) { } Command::ReadConfig(idx) => { let val = device.config_register_read(idx as usize); - sock.send(&CommandResult::ReadConfigResult(val)) + tube.send(&CommandResult::ReadConfigResult(val)) } Command::WriteConfig { reg_idx, @@ -107,14 +105,14 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) { let len = len as usize; let res = device.config_register_write(reg_idx as usize, offset as u64, &data[0..len]); - sock.send(&CommandResult::WriteConfigResult { + tube.send(&CommandResult::WriteConfigResult { mem_bus_new_state: res.mem_bus_new_state, io_bus_new_state: res.io_bus_new_state, }) } Command::Shutdown => { running = false; - sock.send(&CommandResult::Ok) + tube.send(&CommandResult::Ok) } }; if let Err(e) = res { @@ -128,7 +126,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) { /// Because forks are very unfriendly to destructors and all memory mappings and file descriptors /// are inherited, this should be used as early as possible in the main process. 
pub struct ProxyDevice { - sock: MsgSocket<Command, CommandResult>, + tube: Tube, pid: pid_t, debug_label: String, } @@ -149,15 +147,23 @@ impl ProxyDevice { mut keep_rds: Vec<RawDescriptor>, ) -> Result<ProxyDevice> { let debug_label = device.debug_label(); - let (child_sock, parent_sock) = UnixSeqpacket::pair().map_err(Error::Io)?; + let (child_tube, parent_tube) = Tube::pair().map_err(Error::Tube)?; - keep_rds.push(child_sock.as_raw_descriptor()); + keep_rds.push(child_tube.as_raw_descriptor()); // Forking here is safe as long as the program is still single threaded. let pid = unsafe { match jail.fork(Some(&keep_rds)).map_err(Error::ForkingJail)? { 0 => { + let max_len = 15; // pthread_setname_np() limit on Linux + let debug_label_trimmed = + &debug_label.as_bytes()[..std::cmp::min(max_len, debug_label.len())]; + let thread_name = CString::new(debug_label_trimmed).unwrap(); + // TODO(crbug.com/1199487): remove this once libc provides the wrapper for all + // targets + #[cfg(all(target_os = "linux", target_env = "gnu"))] + let _ = libc::pthread_setname_np(libc::pthread_self(), thread_name.as_ptr()); device.on_sandboxed(); - child_proc(child_sock, &mut device); + child_proc(child_tube, &mut device); // We're explicitly not using std::process::exit here to avoid the cleanup of // stdout/stderr globals. 
This can cause cascading panics and SIGILL if a worker @@ -173,14 +179,14 @@ impl ProxyDevice { } }; - parent_sock - .set_write_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) - .map_err(Error::Io)?; - parent_sock - .set_read_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) - .map_err(Error::Io)?; + parent_tube + .set_send_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) + .map_err(Error::Tube)?; + parent_tube + .set_recv_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) + .map_err(Error::Tube)?; Ok(ProxyDevice { - sock: MsgSocket::<Command, CommandResult>::new(parent_sock), + tube: parent_tube, pid, debug_label, }) @@ -192,7 +198,7 @@ impl ProxyDevice { /// Send a command that does not expect a response from the child device process. fn send_no_result(&self, cmd: &Command) { - let res = self.sock.send(cmd); + let res = self.tube.send(cmd); if let Err(e) = res { error!( "failed write to child device process {}: {}", @@ -204,7 +210,7 @@ impl ProxyDevice { /// Send a command and read its response from the child device process. 
fn sync_send(&self, cmd: &Command) -> Option<CommandResult> { self.send_no_result(cmd); - match self.sock.recv() { + match self.tube.recv() { Err(e) => { error!( "failed to read result of {:?} from child device process {}: {}", diff --git a/devices/src/usb/host_backend/error.rs b/devices/src/usb/host_backend/error.rs index 3f1fed4e1..8d089233a 100644 --- a/devices/src/usb/host_backend/error.rs +++ b/devices/src/usb/host_backend/error.rs @@ -5,7 +5,8 @@ use crate::usb::xhci::scatter_gather_buffer::Error as BufferError; use crate::usb::xhci::xhci_transfer::Error as XhciTransferError; use crate::utils::Error as UtilsError; -use msg_socket::MsgError; + +use base::TubeError; use std::fmt::{self, Display}; use usb_util::Error as UsbUtilError; @@ -20,11 +21,12 @@ pub enum Error { SetInterfaceAltSetting(UsbUtilError), ClearHalt(UsbUtilError), CreateTransfer(UsbUtilError), + Reset(UsbUtilError), GetEndpointType, - CreateControlSock(std::io::Error), - SetupControlSock(std::io::Error), - ReadControlSock(MsgError), - WriteControlSock(MsgError), + CreateControlTube(TubeError), + SetupControlTube(TubeError), + ReadControlTube(TubeError), + WriteControlTube(TubeError), GetXhciTransferType(XhciTransferError), TransferComplete(XhciTransferError), ReadBuffer(BufferError), @@ -51,11 +53,12 @@ impl Display for Error { SetInterfaceAltSetting(e) => write!(f, "failed to set interface alt setting: {:?}", e), ClearHalt(e) => write!(f, "failed to clear halt: {:?}", e), CreateTransfer(e) => write!(f, "failed to create transfer: {:?}", e), + Reset(e) => write!(f, "failed to reset: {:?}", e), GetEndpointType => write!(f, "failed to get endpoint type"), - CreateControlSock(e) => write!(f, "failed to create contro sock: {}", e), - SetupControlSock(e) => write!(f, "failed to setup control sock: {}", e), - ReadControlSock(e) => write!(f, "failed to read control sock: {}", e), - WriteControlSock(e) => write!(f, "failed to write control sock: {}", e), + CreateControlTube(e) => write!(f, "failed to 
create contro tube: {}", e), + SetupControlTube(e) => write!(f, "failed to setup control tube: {}", e), + ReadControlTube(e) => write!(f, "failed to read control tube: {}", e), + WriteControlTube(e) => write!(f, "failed to write control tube: {}", e), GetXhciTransferType(e) => write!(f, "failed to get xhci transfer type: {}", e), TransferComplete(e) => write!(f, "xhci transfer completed: {}", e), ReadBuffer(e) => write!(f, "failed to read buffer: {}", e), diff --git a/devices/src/usb/host_backend/host_backend_device_provider.rs b/devices/src/usb/host_backend/host_backend_device_provider.rs index c33092a15..787bb8fe0 100644 --- a/devices/src/usb/host_backend/host_backend_device_provider.rs +++ b/devices/src/usb/host_backend/host_backend_device_provider.rs @@ -11,19 +11,14 @@ use crate::usb::xhci::usb_hub::UsbHub; use crate::usb::xhci::xhci_backend_device_provider::XhciBackendDeviceProvider; use crate::utils::AsyncJobQueue; use crate::utils::{EventHandler, EventLoop, FailHandle}; -use base::net::UnixSeqpacket; -use base::{ - error, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, WatchingEvents, -}; -use msg_socket::{MsgReceiver, MsgSender, MsgSocket}; +use base::{error, AsRawDescriptor, Descriptor, RawDescriptor, Tube, WatchingEvents}; use std::collections::HashMap; use std::mem; use std::time::Duration; use sync::Mutex; use usb_util::Device; use vm_control::{ - MaybeOwnedDescriptor, UsbControlAttachedDevice, UsbControlCommand, UsbControlResult, - UsbControlSocket, USB_CONTROL_MAX_PORTS, + UsbControlAttachedDevice, UsbControlCommand, UsbControlResult, USB_CONTROL_MAX_PORTS, }; const SOCKET_TIMEOUT_MS: u64 = 2000; @@ -32,31 +27,27 @@ const SOCKET_TIMEOUT_MS: u64 = 2000; /// devices. pub enum HostBackendDeviceProvider { // The provider is created but not yet started. - Created { - sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>, - }, + Created { control_tube: Mutex<Tube> }, // The provider is started on an event loop. 
- Started { - inner: Arc<ProviderInner>, - }, + Started { inner: Arc<ProviderInner> }, // The provider has failed. Failed, } impl HostBackendDeviceProvider { - pub fn new() -> Result<(UsbControlSocket, HostBackendDeviceProvider)> { - let (child_sock, control_sock) = UnixSeqpacket::pair().map_err(Error::CreateControlSock)?; - control_sock - .set_write_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) - .map_err(Error::SetupControlSock)?; - control_sock - .set_read_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) - .map_err(Error::SetupControlSock)?; + pub fn new() -> Result<(Tube, HostBackendDeviceProvider)> { + let (child_tube, control_tube) = Tube::pair().map_err(Error::CreateControlTube)?; + control_tube + .set_send_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) + .map_err(Error::SetupControlTube)?; + control_tube + .set_recv_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS))) + .map_err(Error::SetupControlTube)?; let provider = HostBackendDeviceProvider::Created { - sock: Mutex::new(MsgSocket::new(child_sock)), + control_tube: Mutex::new(child_tube), }; - Ok((MsgSocket::new(control_sock), provider)) + Ok((control_tube, provider)) } fn start_helper( @@ -66,20 +57,20 @@ impl HostBackendDeviceProvider { hub: Arc<UsbHub>, ) -> Result<()> { match mem::replace(self, HostBackendDeviceProvider::Failed) { - HostBackendDeviceProvider::Created { sock } => { + HostBackendDeviceProvider::Created { control_tube } => { let job_queue = AsyncJobQueue::init(&event_loop).map_err(Error::StartAsyncJobQueue)?; let inner = Arc::new(ProviderInner::new( fail_handle, job_queue, event_loop.clone(), - sock, + control_tube, hub, )); let handler: Arc<dyn EventHandler> = inner.clone(); event_loop .add_event( - &*inner.sock.lock(), + &*inner.control_tube.lock(), WatchingEvents::empty().set_read(), Arc::downgrade(&handler), ) @@ -105,16 +96,15 @@ impl XhciBackendDeviceProvider for HostBackendDeviceProvider { fail_handle: Arc<dyn FailHandle>, event_loop: Arc<EventLoop>, 
hub: Arc<UsbHub>, - ) -> std::result::Result<(), ()> { + ) -> Result<()> { self.start_helper(fail_handle, event_loop, hub) - .map_err(|e| { - error!("failed to start host backend device provider: {}", e); - }) } fn keep_rds(&self) -> Vec<RawDescriptor> { match self { - HostBackendDeviceProvider::Created { sock } => vec![sock.lock().as_raw_descriptor()], + HostBackendDeviceProvider::Created { control_tube } => { + vec![control_tube.lock().as_raw_descriptor()] + } _ => { error!( "Trying to get keepfds when HostBackendDeviceProvider is not in created state" @@ -130,7 +120,7 @@ pub struct ProviderInner { fail_handle: Arc<dyn FailHandle>, job_queue: Arc<AsyncJobQueue>, event_loop: Arc<EventLoop>, - sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>, + control_tube: Mutex<Tube>, usb_hub: Arc<UsbHub>, // Map of USB hub port number to per-device context. @@ -147,14 +137,14 @@ impl ProviderInner { fail_handle: Arc<dyn FailHandle>, job_queue: Arc<AsyncJobQueue>, event_loop: Arc<EventLoop>, - sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>, + control_tube: Mutex<Tube>, usb_hub: Arc<UsbHub>, ) -> ProviderInner { ProviderInner { fail_handle, job_queue, event_loop, - sock, + control_tube, usb_hub, devices: Mutex::new(HashMap::new()), } @@ -162,18 +152,8 @@ impl ProviderInner { /// Open a usbdevfs file to create a host USB device object. /// `fd` should be an open file descriptor for a file in `/dev/bus/usb`. - fn handle_attach_device(&self, fd: Option<MaybeOwnedDescriptor>) -> UsbControlResult { - let usb_file = match fd { - Some(MaybeOwnedDescriptor::Owned(file)) => file, - _ => { - error!("missing fd in UsbControlCommand::AttachDevice message"); - return UsbControlResult::FailedToOpenDevice; - } - }; - - let raw_descriptor = usb_file.into_raw_descriptor(); - // Safe as it is valid to have multiple variables accessing the same fd. 
- let device = match Device::new(unsafe { File::from_raw_descriptor(raw_descriptor) }) { + fn handle_attach_device(&self, usb_file: File) -> UsbControlResult { + let device = match Device::new(usb_file) { Ok(d) => d, Err(e) => { error!("could not construct USB device from fd: {}", e); @@ -181,6 +161,8 @@ impl ProviderInner { } }; + let device_descriptor = Descriptor(device.as_raw_descriptor()); + let arc_mutex_device = Arc::new(Mutex::new(device)); let event_handler: Arc<dyn EventHandler> = Arc::new(UsbUtilEventHandler { @@ -188,7 +170,7 @@ impl ProviderInner { }); if let Err(e) = self.event_loop.add_event( - &MaybeOwnedDescriptor::Borrowed(raw_descriptor), + &device_descriptor, WatchingEvents::empty().set_read().set_write(), Arc::downgrade(&event_handler), ) { @@ -233,12 +215,7 @@ impl ProviderInner { let device = device_ctx.device.lock(); let fd = device.fd(); - if let Err(e) = - self.event_loop - .remove_event_for_fd(&MaybeOwnedDescriptor::Borrowed( - fd.as_raw_descriptor(), - )) - { + if let Err(e) = self.event_loop.remove_event_for_fd(&*fd) { error!( "failed to remove poll change handler from event loop: {}", e @@ -276,16 +253,14 @@ impl ProviderInner { } fn on_event_helper(&self) -> Result<()> { - let sock = self.sock.lock(); - let cmd = sock.recv().map_err(Error::ReadControlSock)?; + let tube = self.control_tube.lock(); + let cmd = tube.recv().map_err(Error::ReadControlTube)?; let result = match cmd { - UsbControlCommand::AttachDevice { descriptor, .. } => { - self.handle_attach_device(descriptor) - } + UsbControlCommand::AttachDevice { file, .. 
} => self.handle_attach_device(file), UsbControlCommand::DetachDevice { port } => self.handle_detach_device(port), UsbControlCommand::ListDevice { ports } => self.handle_list_devices(ports), }; - sock.send(&result).map_err(Error::WriteControlSock)?; + tube.send(&result).map_err(Error::WriteControlTube)?; Ok(()) } } diff --git a/devices/src/usb/host_backend/host_device.rs b/devices/src/usb/host_backend/host_device.rs index b3ff14b78..6a49ee5ea 100644 --- a/devices/src/usb/host_backend/host_device.rs +++ b/devices/src/usb/host_backend/host_device.rs @@ -468,10 +468,8 @@ impl XhciBackendDevice for HostDevice { } } - fn submit_transfer(&mut self, transfer: XhciTransfer) -> std::result::Result<(), ()> { - self.submit_transfer_helper(transfer).map_err(|e| { - error!("failed to submit transfer: {}", e); - }) + fn submit_transfer(&mut self, transfer: XhciTransfer) -> Result<()> { + self.submit_transfer_helper(transfer) } fn set_address(&mut self, _address: UsbDeviceAddress) { @@ -483,11 +481,8 @@ impl XhciBackendDevice for HostDevice { ); } - fn reset(&mut self) -> std::result::Result<(), ()> { + fn reset(&mut self) -> Result<()> { usb_debug!("resetting host device"); - self.device - .lock() - .reset() - .map_err(|e| error!("failed to reset device: {:?}", e)) + self.device.lock().reset().map_err(Error::Reset) } } diff --git a/devices/src/usb/xhci/device_slot.rs b/devices/src/usb/xhci/device_slot.rs index 5e12327d1..e24776339 100644 --- a/devices/src/usb/xhci/device_slot.rs +++ b/devices/src/usb/xhci/device_slot.rs @@ -12,6 +12,7 @@ use super::xhci_abi::{ }; use super::xhci_regs::{valid_slot_id, MAX_PORTS, MAX_SLOTS}; use crate::register_space::Register; +use crate::usb::host_backend::error::Error as HostBackendProviderError; use crate::usb::xhci::ring_buffer_stop_cb::{fallible_closure, RingBufferStopCallback}; use crate::utils::{EventLoop, FailHandle}; use base::error; @@ -37,6 +38,7 @@ pub enum Error { BadInputContextAddr(GuestAddress), BadDeviceContextAddr(GuestAddress), 
CreateTransferController(TransferRingControllerError), + ResetPort(HostBackendProviderError), } type Result<T> = std::result::Result<T, Error>; @@ -58,6 +60,7 @@ impl Display for Error { BadInputContextAddr(addr) => write!(f, "bad input context address: {}", addr), BadDeviceContextAddr(addr) => write!(f, "bad device context: {}", addr), CreateTransferController(e) => write!(f, "failed to create transfer controller: {}", e), + ResetPort(e) => write!(f, "failed to reset port: {}", e), } } } @@ -127,10 +130,10 @@ impl DeviceSlots { } /// Reset the device connected to a specific port. - pub fn reset_port(&self, port_id: u8) -> std::result::Result<(), ()> { + pub fn reset_port(&self, port_id: u8) -> Result<()> { if let Some(port) = self.hub.get_port(port_id) { if let Some(backend_device) = port.get_backend_device().as_mut() { - backend_device.reset()?; + backend_device.reset().map_err(Error::ResetPort)?; } } diff --git a/devices/src/usb/xhci/xhci.rs b/devices/src/usb/xhci/xhci.rs index e2035980f..85f6de8f0 100644 --- a/devices/src/usb/xhci/xhci.rs +++ b/devices/src/usb/xhci/xhci.rs @@ -10,7 +10,10 @@ use super::ring_buffer_stop_cb::RingBufferStopCallback; use super::usb_hub::UsbHub; use super::xhci_backend_device_provider::XhciBackendDeviceProvider; use super::xhci_regs::*; -use crate::usb::host_backend::host_backend_device_provider::HostBackendDeviceProvider; +use crate::usb::host_backend::{ + error::Error as HostBackendProviderError, + host_backend_device_provider::HostBackendDeviceProvider, +}; use crate::utils::{Error as UtilsError, EventLoop, FailHandle}; use base::{error, Event}; use std::fmt::{self, Display}; @@ -29,7 +32,7 @@ pub enum Error { SetModeration(InterrupterError), SetupEventRing(InterrupterError), SetEventHandlerBusy(InterrupterError), - StartProvider, + StartProvider(HostBackendProviderError), RingDoorbell(DeviceSlotError), CreateCommandRingController(CommandRingControllerError), ResetPort, @@ -50,7 +53,7 @@ impl Display for Error { SetModeration(e) 
=> write!(f, "failed to set interrupter moderation: {}", e), SetupEventRing(e) => write!(f, "failed to setup event ring: {}", e), SetEventHandlerBusy(e) => write!(f, "failed to set event handler busy: {}", e), - StartProvider => write!(f, "failed to start backend provider"), + StartProvider(e) => write!(f, "failed to start backend provider: {}", e), RingDoorbell(e) => write!(f, "failed to ring doorbell: {}", e), CreateCommandRingController(e) => { write!(f, "failed to create command ring controller: {}", e) @@ -101,7 +104,7 @@ impl Xhci { let mut device_provider = device_provider; device_provider .start(fail_handle.clone(), event_loop.clone(), hub.clone()) - .map_err(|_| Error::StartProvider)?; + .map_err(Error::StartProvider)?; let device_slots = DeviceSlots::new( fail_handle.clone(), @@ -148,8 +151,7 @@ impl Xhci { let xhci_weak = Arc::downgrade(xhci); xhci.regs.crcr.set_write_cb(move |val: u64| { let xhci = xhci_weak.upgrade().unwrap(); - let r = xhci.crcr_callback(val); - xhci.handle_register_callback_result(r, 0) + xhci.crcr_callback(val) }); for i in 0..xhci.regs.portsc.len() { @@ -227,7 +229,7 @@ impl Xhci { fn usbcmd_callback(&self, value: u32) -> Result<u32> { if (value & USB_CMD_RESET) > 0 { usb_debug!("xhci_controller: reset controller"); - self.reset()?; + self.reset(); return Ok(value & (!USB_CMD_RESET)); } @@ -236,7 +238,7 @@ impl Xhci { self.regs.usbsts.clear_bits(USB_STS_HALTED); } else { usb_debug!("xhci_controller: halt device"); - self.halt()?; + self.halt(); self.regs.crcr.clear_bits(CRCR_COMMAND_RING_RUNNING); } @@ -252,9 +254,9 @@ impl Xhci { } // Callback for crcr register write. 
- fn crcr_callback(&self, value: u64) -> Result<u64> { + fn crcr_callback(&self, value: u64) -> u64 { usb_debug!("xhci_controller: write to crcr {:x}", value); - let value = if (self.regs.crcr.get_value() & CRCR_COMMAND_RING_RUNNING) == 0 { + if (self.regs.crcr.get_value() & CRCR_COMMAND_RING_RUNNING) == 0 { self.command_ring_controller .set_dequeue_pointer(GuestAddress(value & CRCR_COMMAND_RING_POINTER)); self.command_ring_controller @@ -263,8 +265,7 @@ impl Xhci { } else { error!("Write to crcr while command ring is running"); self.regs.crcr.get_value() - }; - Ok(value) + } } // Callback for portsc register write. @@ -378,22 +379,20 @@ impl Xhci { .map_err(Error::SetEventHandlerBusy) } - fn reset(&self) -> Result<()> { + fn reset(&self) { self.regs.usbsts.set_bits(USB_STS_CONTROLLER_NOT_READY); let usbsts = self.regs.usbsts.clone(); self.device_slots.stop_all_and_reset(move || { usbsts.clear_bits(USB_STS_CONTROLLER_NOT_READY); }); - Ok(()) } - fn halt(&self) -> Result<()> { + fn halt(&self) { let usbsts = self.regs.usbsts.clone(); self.device_slots .stop_all(RingBufferStopCallback::new(move || { usbsts.set_bits(USB_STS_HALTED); })); - Ok(()) } } diff --git a/devices/src/usb/xhci/xhci_backend_device.rs b/devices/src/usb/xhci/xhci_backend_device.rs index a3d9e66c5..2b0dedb0f 100644 --- a/devices/src/usb/xhci/xhci_backend_device.rs +++ b/devices/src/usb/xhci/xhci_backend_device.rs @@ -3,6 +3,7 @@ // found in the LICENSE file. use super::xhci_transfer::XhciTransfer; +use crate::usb::host_backend::error::Result; /// Address of this usb device, as in Set Address standard usb device request. pub type UsbDeviceAddress = u32; @@ -23,9 +24,9 @@ pub trait XhciBackendDevice: Send { /// Get product id of this device. fn get_pid(&self) -> u16; /// Submit a xhci transfer to backend. - fn submit_transfer(&mut self, transfer: XhciTransfer) -> std::result::Result<(), ()>; + fn submit_transfer(&mut self, transfer: XhciTransfer) -> Result<()>; /// Set address of this backend. 
fn set_address(&mut self, address: UsbDeviceAddress); /// Reset the backend device. - fn reset(&mut self) -> std::result::Result<(), ()>; + fn reset(&mut self) -> Result<()>; } diff --git a/devices/src/usb/xhci/xhci_backend_device_provider.rs b/devices/src/usb/xhci/xhci_backend_device_provider.rs index aefcf5393..ca416cc02 100644 --- a/devices/src/usb/xhci/xhci_backend_device_provider.rs +++ b/devices/src/usb/xhci/xhci_backend_device_provider.rs @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use super::super::host_backend::error::Result; use super::usb_hub::UsbHub; use crate::utils::{EventLoop, FailHandle}; use base::RawDescriptor; @@ -15,7 +16,7 @@ pub trait XhciBackendDeviceProvider: Send { fail_handle: Arc<dyn FailHandle>, event_loop: Arc<EventLoop>, hub: Arc<UsbHub>, - ) -> std::result::Result<(), ()>; + ) -> Result<()>; /// Keep raw descriptors that should be kept open. fn keep_rds(&self) -> Vec<RawDescriptor>; diff --git a/devices/src/usb/xhci/xhci_controller.rs b/devices/src/usb/xhci/xhci_controller.rs index 85c9d3240..0889b356a 100644 --- a/devices/src/usb/xhci/xhci_controller.rs +++ b/devices/src/usb/xhci/xhci_controller.rs @@ -111,6 +111,7 @@ impl XhciController { PciHeaderType::Device, 0, 0, + 0, ); XhciController { config_regs, diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs index fcc306532..9fda0d65d 100644 --- a/devices/src/vfio.rs +++ b/devices/src/vfio.rs @@ -174,7 +174,7 @@ impl VfioContainer { // Add all guest memory regions into vfio container's iommu table, // then vfio kernel driver could access guest memory from gfn - guest_mem.with_regions(|_index, guest_addr, size, host_addr, _fd_offset| { + guest_mem.with_regions(|_index, guest_addr, size, host_addr, _mmap, _fd_offset| { // Safe because the guest regions are guaranteed not to overlap unsafe { self.vfio_dma_map(guest_addr.0, size as u64, host_addr as u64) } })?; @@ -391,21 +391,13 @@ impl VfioDevice { /// Enable 
vfio device's irq and associate Irqfd Event with device. /// When MSIx is enabled, multi vectors will be supported, so descriptors is vector and the vector /// length is the num of MSIx vectors - pub fn irq_enable( - &self, - descriptors: Vec<&Event>, - irq_type: VfioIrqType, - ) -> Result<(), VfioError> { + pub fn irq_enable(&self, descriptors: Vec<&Event>, index: u32) -> Result<(), VfioError> { let count = descriptors.len(); let u32_size = mem::size_of::<u32>(); let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(count); irq_set[0].argsz = (mem::size_of::<vfio_irq_set>() + count * u32_size) as u32; irq_set[0].flags = VFIO_IRQ_SET_DATA_EVENTFD | VFIO_IRQ_SET_ACTION_TRIGGER; - match irq_type { - VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX, - VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX, - VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX, - } + irq_set[0].index = index; irq_set[0].start = 0; irq_set[0].count = count as u32; @@ -421,7 +413,7 @@ impl VfioDevice { } // Safe as we are the owner of self and irq_set which are valid value - let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; + let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; if ret < 0 { Err(VfioError::VfioIrqEnable(get_error())) } else { @@ -438,17 +430,17 @@ impl VfioDevice { /// This function enable resample irqfd and let vfio kernel could get EOI notification. /// /// descriptor: should be resample IrqFd. 
- pub fn resample_virq_enable(&self, descriptor: &Event) -> Result<(), VfioError> { + pub fn resample_virq_enable(&self, descriptor: &Event, index: u32) -> Result<(), VfioError> { let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(1); irq_set[0].argsz = (mem::size_of::<vfio_irq_set>() + mem::size_of::<u32>()) as u32; irq_set[0].flags = VFIO_IRQ_SET_DATA_EVENTFD | VFIO_IRQ_SET_ACTION_UNMASK; - irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX; + irq_set[0].index = index; irq_set[0].start = 0; irq_set[0].count = 1; { - // irq_set.data could be none, bool or descriptor according to flags, so irq_set.data - // is u8 default, here irq_set.data is descriptor as u32, so 4 default u8 are combined + // irq_set.data could be none, bool or descriptor according to flags, so irq_set.data is + // u8 default, here irq_set.data is descriptor as u32, so 4 default u8 are combined // together as u32. It is safe as enough space is reserved through // vec_with_array_field(u32)<1>. let descriptors = unsafe { irq_set[0].data.as_mut_slice(4) }; @@ -456,7 +448,7 @@ impl VfioDevice { } // Safe as we are the owner of self and irq_set which are valid value - let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; + let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; if ret < 0 { Err(VfioError::VfioIrqEnable(get_error())) } else { @@ -465,20 +457,16 @@ impl VfioDevice { } /// disable vfio device's irq and disconnect Irqfd Event with device - pub fn irq_disable(&self, irq_type: VfioIrqType) -> Result<(), VfioError> { + pub fn irq_disable(&self, index: u32) -> Result<(), VfioError> { let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0); irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32; irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_TRIGGER; - match irq_type { - VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX, - VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX, - 
VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX, - } + irq_set[0].index = index; irq_set[0].start = 0; irq_set[0].count = 0; // Safe as we are the owner of self and irq_set which are valid value - let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; + let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; if ret < 0 { Err(VfioError::VfioIrqDisable(get_error())) } else { @@ -487,20 +475,16 @@ impl VfioDevice { } /// Unmask vfio device irq - pub fn irq_unmask(&self, irq_type: VfioIrqType) -> Result<(), VfioError> { + pub fn irq_unmask(&self, index: u32) -> Result<(), VfioError> { let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0); irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32; irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_UNMASK; - match irq_type { - VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX, - VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX, - VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX, - } + irq_set[0].index = index; irq_set[0].start = 0; irq_set[0].count = 1; // Safe as we are the owner of self and irq_set which are valid value - let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; + let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; if ret < 0 { Err(VfioError::VfioIrqUnmask(get_error())) } else { @@ -509,20 +493,16 @@ impl VfioDevice { } /// Mask vfio device irq - pub fn irq_mask(&self, irq_type: VfioIrqType) -> Result<(), VfioError> { + pub fn irq_mask(&self, index: u32) -> Result<(), VfioError> { let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0); irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32; irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_MASK; - match irq_type { - VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX, - VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX, 
- VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX, - } + irq_set[0].index = index; irq_set[0].start = 0; irq_set[0].count = 1; // Safe as we are the owner of self and irq_set which are valid value - let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; + let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) }; if ret < 0 { Err(VfioError::VfioIrqMask(get_error())) } else { @@ -802,7 +782,7 @@ impl VfioDevice { /// get vfio device's descriptors which are passed into minijail process pub fn keep_rds(&self) -> Vec<RawDescriptor> { let mut rds = Vec::new(); - rds.push(self.as_raw_descriptor()); + rds.push(self.dev.as_raw_descriptor()); rds.push(self.group_descriptor); rds.push(self.container.lock().as_raw_descriptor()); rds @@ -822,10 +802,9 @@ impl VfioDevice { pub fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<(), VfioError> { self.container.lock().vfio_dma_unmap(iova, size) } -} -impl AsRawDescriptor for VfioDevice { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.dev.as_raw_descriptor() + /// Gets the vfio device backing `File`. 
+ pub fn device_file(&self) -> &File { + &self.dev } } diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs index 05962e92f..a76765d47 100644 --- a/devices/src/virtio/balloon.rs +++ b/devices/src/virtio/balloon.rs @@ -12,18 +12,15 @@ use futures::{channel::mpsc, pin_mut, StreamExt}; use remain::sorted; use thiserror::Error as ThisError; -use base::{self, error, info, warn, AsRawDescriptor, Event, RawDescriptor}; +use base::{self, error, info, warn, AsRawDescriptor, AsyncTube, Event, RawDescriptor, Tube}; use cros_async::{select6, EventAsync, Executor}; use data_model::{DataInit, Le16, Le32, Le64}; -use msg_socket::MsgSender; -use vm_control::{ - BalloonControlCommand, BalloonControlResponseSocket, BalloonControlResult, BalloonStats, -}; +use vm_control::{BalloonControlCommand, BalloonControlResult, BalloonStats}; use vm_memory::{GuestAddress, GuestMemory}; use super::{ - copy_config, descriptor_utils, DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, - TYPE_BALLOON, + copy_config, descriptor_utils, DescriptorChain, Interrupt, Queue, Reader, SignalableInterrupt, + VirtioDevice, TYPE_BALLOON, }; #[sorted] @@ -31,10 +28,10 @@ use super::{ pub enum BalloonError { /// Failed to create async message receiver. #[error("failed to create async message receiver: {0}")] - CreatingMessageReceiver(msg_socket::MsgError), + CreatingMessageReceiver(base::TubeError), /// Failed to receive command message. #[error("failed to receive command message: {0}")] - ReceivingCommand(msg_socket::MsgError), + ReceivingCommand(base::TubeError), /// Failed to write config event. 
#[error("failed to write config event: {0}")] WritingConfigEvent(base::Error), @@ -194,7 +191,7 @@ async fn handle_stats_queue( mut queue: Queue, mut queue_event: EventAsync, mut stats_rx: mpsc::Receiver<()>, - command_socket: &BalloonControlResponseSocket, + command_tube: &Tube, config: Arc<BalloonConfig>, interrupt: Rc<RefCell<Interrupt>>, ) { @@ -229,13 +226,13 @@ async fn handle_stats_queue( balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT, stats, }; - if let Err(e) = command_socket.send(&result) { + if let Err(e) = command_tube.send(&result) { error!("failed to send stats result: {}", e); } // Wait for a request to read the stats again. if stats_rx.next().await.is_none() { - error!("stats signal channel was closed"); + error!("stats signal tube was closed"); break; } @@ -247,18 +244,14 @@ async fn handle_stats_queue( // Async task that handles the command socket. The command socket handles messages from the host // requesting that the guest balloon be adjusted or to report guest memory statistics. -async fn handle_command_socket( - ex: &Executor, - command_socket: &BalloonControlResponseSocket, +async fn handle_command_tube( + command_tube: &AsyncTube, interrupt: Rc<RefCell<Interrupt>>, config: Arc<BalloonConfig>, mut stats_tx: mpsc::Sender<()>, ) -> Result<()> { - let mut async_messages = command_socket - .async_receiver(ex) - .map_err(BalloonError::CreatingMessageReceiver)?; loop { - match async_messages.next().await { + match command_tube.next().await { Ok(command) => match command { BalloonControlCommand::Adjust { num_bytes } => { let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as usize; @@ -283,14 +276,20 @@ async fn handle_command_socket( // Async task that resamples the status of the interrupt when the guest sends a request by // signalling the resample event associated with the interrupt. 
async fn handle_irq_resample(ex: &Executor, interrupt: Rc<RefCell<Interrupt>>) { - let resample_evt = interrupt - .borrow_mut() - .get_resample_evt() - .try_clone() - .unwrap(); - let resample_evt = EventAsync::new(resample_evt.0, ex).unwrap(); - while resample_evt.next_val().await.is_ok() { - interrupt.borrow_mut().do_interrupt_resample(); + let resample_evt = if let Some(resample_evt) = interrupt.borrow_mut().get_resample_evt() { + let resample_evt = resample_evt.try_clone().unwrap(); + let resample_evt = EventAsync::new(resample_evt.0, ex).unwrap(); + Some(resample_evt) + } else { + None + }; + if let Some(resample_evt) = resample_evt { + while resample_evt.next_val().await.is_ok() { + interrupt.borrow_mut().do_interrupt_resample(); + } + } else { + // no resample event, park the future. + let () = futures::future::pending().await; } } @@ -306,95 +305,98 @@ async fn wait_kill(kill_evt: EventAsync) { fn run_worker( mut queue_evts: Vec<Event>, mut queues: Vec<Queue>, - command_socket: &BalloonControlResponseSocket, + command_tube: Tube, interrupt: Interrupt, kill_evt: Event, mem: GuestMemory, config: Arc<BalloonConfig>, -) { +) -> Tube { // Wrap the interrupt in a `RefCell` so it can be shared between async functions. let interrupt = Rc::new(RefCell::new(interrupt)); let ex = Executor::new().unwrap(); + let command_tube = command_tube.into_async_tube(&ex).unwrap(); + + // We need a block to release all references to command_tube at the end before returning it. 
+ { + // The first queue is used for inflate messages + let inflate_event = EventAsync::new(queue_evts.remove(0).0, &ex) + .expect("failed to set up the inflate event"); + let inflate = handle_queue( + &mem, + queues.remove(0), + inflate_event, + interrupt.clone(), + |guest_address, len| { + if let Err(e) = mem.remove_range(guest_address, len) { + warn!("Marking pages unused failed: {}, addr={}", e, guest_address); + } + }, + ); + pin_mut!(inflate); + + // The second queue is used for deflate messages + let deflate_event = EventAsync::new(queue_evts.remove(0).0, &ex) + .expect("failed to set up the deflate event"); + let deflate = handle_queue( + &mem, + queues.remove(0), + deflate_event, + interrupt.clone(), + |_, _| {}, // Ignore these. + ); + pin_mut!(deflate); + + // The third queue is used for stats messages + let (stats_tx, stats_rx) = mpsc::channel::<()>(1); + let stats_event = + EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the stats event"); + let stats = handle_stats_queue( + &mem, + queues.remove(0), + stats_event, + stats_rx, + &command_tube, + config.clone(), + interrupt.clone(), + ); + pin_mut!(stats); - // The first queue is used for inflate messages - let inflate_event = - EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the inflate event"); - let inflate = handle_queue( - &mem, - queues.remove(0), - inflate_event, - interrupt.clone(), - |guest_address, len| { - if let Err(e) = mem.remove_range(guest_address, len) { - warn!("Marking pages unused failed: {}, addr={}", e, guest_address); - } - }, - ); - pin_mut!(inflate); - - // The second queue is used for deflate messages - let deflate_event = - EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the deflate event"); - let deflate = handle_queue( - &mem, - queues.remove(0), - deflate_event, - interrupt.clone(), - |_, _| {}, // Ignore these. 
- ); - pin_mut!(deflate); - - // The third queue is used for stats messages - let (stats_tx, stats_rx) = mpsc::channel::<()>(1); - let stats_event = - EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the stats event"); - let stats = handle_stats_queue( - &mem, - queues.remove(0), - stats_event, - stats_rx, - command_socket, - config.clone(), - interrupt.clone(), - ); - pin_mut!(stats); - - // Future to handle command messages that resize the balloon. - let command = handle_command_socket(&ex, command_socket, interrupt.clone(), config, stats_tx); - pin_mut!(command); - - // Process any requests to resample the irq value. - let resample = handle_irq_resample(&ex, interrupt.clone()); - pin_mut!(resample); - - // Exit if the kill event is triggered. - let kill_evt = EventAsync::new(kill_evt.0, &ex).expect("failed to set up the kill event"); - let kill = wait_kill(kill_evt); - pin_mut!(kill); - - if let Err(e) = ex.run_until(select6(inflate, deflate, stats, command, resample, kill)) { - error!("error happened in executor: {}", e); + // Future to handle command messages that resize the balloon. + let command = handle_command_tube(&command_tube, interrupt.clone(), config, stats_tx); + pin_mut!(command); + + // Process any requests to resample the irq value. + let resample = handle_irq_resample(&ex, interrupt.clone()); + pin_mut!(resample); + + // Exit if the kill event is triggered. + let kill_evt = EventAsync::new(kill_evt.0, &ex).expect("failed to set up the kill event"); + let kill = wait_kill(kill_evt); + pin_mut!(kill); + + if let Err(e) = ex.run_until(select6(inflate, deflate, stats, command, resample, kill)) { + error!("error happened in executor: {}", e); + } } + + command_tube.into() } /// Virtio device for memory balloon inflation/deflation. 
pub struct Balloon { - command_socket: Option<BalloonControlResponseSocket>, + command_tube: Option<Tube>, config: Arc<BalloonConfig>, features: u64, kill_evt: Option<Event>, - worker_thread: Option<thread::JoinHandle<BalloonControlResponseSocket>>, + worker_thread: Option<thread::JoinHandle<Tube>>, } impl Balloon { /// Creates a new virtio balloon device. - pub fn new( - base_features: u64, - command_socket: BalloonControlResponseSocket, - ) -> Result<Balloon> { + pub fn new(base_features: u64, command_tube: Tube) -> Result<Balloon> { Ok(Balloon { - command_socket: Some(command_socket), + command_tube: Some(command_tube), config: Arc::new(BalloonConfig { num_pages: AtomicUsize::new(0), actual_pages: AtomicUsize::new(0), @@ -433,7 +435,7 @@ impl Drop for Balloon { impl VirtioDevice for Balloon { fn keep_rds(&self) -> Vec<RawDescriptor> { - vec![self.command_socket.as_ref().unwrap().as_raw_descriptor()] + vec![self.command_tube.as_ref().unwrap().as_raw_descriptor()] } fn device_type(&self) -> u32 { @@ -485,20 +487,19 @@ impl VirtioDevice for Balloon { self.kill_evt = Some(self_kill_evt); let config = self.config.clone(); - let command_socket = self.command_socket.take().unwrap(); + let command_tube = self.command_tube.take().unwrap(); let worker_result = thread::Builder::new() .name("virtio_balloon".to_string()) .spawn(move || { run_worker( queue_evts, queues, - &command_socket, + command_tube, interrupt, kill_evt, mem, config, - ); - command_socket // Return the command socket so it can be re-used. 
+ ) }); match worker_result { @@ -525,8 +526,8 @@ impl VirtioDevice for Balloon { error!("{}: failed to get back resources", self.debug_label()); return false; } - Ok(command_socket) => { - self.command_socket = Some(command_socket); + Ok(command_tube) => { + self.command_tube = Some(command_tube); return true; } } diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs index eba3b0126..e9a202427 100644 --- a/devices/src/virtio/block.rs +++ b/devices/src/virtio/block.rs @@ -15,19 +15,19 @@ use std::u32; use base::Error as SysError; use base::Result as SysResult; use base::{ - error, info, iov_max, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, Timer, + error, info, iov_max, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, Timer, Tube, WaitContext, }; use data_model::{DataInit, Le16, Le32, Le64}; use disk::DiskFile; -use msg_socket::{MsgReceiver, MsgSender}; + use sync::Mutex; -use vm_control::{DiskControlCommand, DiskControlResponseSocket, DiskControlResult}; +use vm_control::{DiskControlCommand, DiskControlResult}; use vm_memory::GuestMemory; use super::{ - copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, - TYPE_BLOCK, + copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, + VirtioDevice, Writer, TYPE_BLOCK, }; const QUEUE_SIZE: u16 = 256; @@ -257,7 +257,7 @@ struct Worker { read_only: bool, sparse: bool, id: Option<BlockId>, - control_socket: Option<DiskControlResponseSocket>, + control_tube: Option<Tube>, } impl Worker { @@ -397,12 +397,17 @@ impl Worker { let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[ (&flush_timer, Token::FlushTimer), (&queue_evt, Token::QueueAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&kill_evt, Token::Kill), ]) + .and_then(|wc| { + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wc.add(resample_evt, Token::InterruptResample)?; + } + Ok(wc) + }) 
.and_then(|pc| { - if let Some(control_socket) = self.control_socket.as_ref() { - pc.add(control_socket, Token::ControlRequest)? + if let Some(control_tube) = self.control_tube.as_ref() { + pc.add(control_tube, Token::ControlRequest)? } Ok(pc) }) { @@ -443,14 +448,14 @@ impl Worker { self.process_queue(0, &mut flush_timer, &mut flush_timer_armed); } Token::ControlRequest => { - let control_socket = match self.control_socket.as_ref() { + let control_tube = match self.control_tube.as_ref() { Some(cs) => cs, None => { error!("received control socket request with no control socket"); break 'wait; } }; - let req = match control_socket.recv() { + let req = match control_tube.recv() { Ok(req) => req, Err(e) => { error!("control socket failed recv: {}", e); @@ -468,8 +473,8 @@ impl Worker { } }; - // We already know there is Some control_socket used to recv a request. - if let Err(e) = self.control_socket.as_ref().unwrap().send(&resp) { + // We already know there is Some control_tube used to recv a request. 
+ if let Err(e) = self.control_tube.as_ref().unwrap().send(&resp) { error!("control socket failed send: {}", e); break 'wait; } @@ -499,7 +504,7 @@ pub struct Block { seg_max: u32, block_size: u32, id: Option<BlockId>, - control_socket: Option<DiskControlResponseSocket>, + control_tube: Option<Tube>, } fn build_config_space(disk_size: u64, seg_max: u32, block_size: u32) -> virtio_blk_config { @@ -527,7 +532,7 @@ impl Block { sparse: bool, block_size: u32, id: Option<BlockId>, - control_socket: Option<DiskControlResponseSocket>, + control_tube: Option<Tube>, ) -> SysResult<Block> { if block_size % SECTOR_SIZE as u32 != 0 { error!( @@ -576,7 +581,7 @@ impl Block { seg_max, block_size, id, - control_socket, + control_tube, }) } @@ -751,8 +756,8 @@ impl VirtioDevice for Block { keep_rds.extend(disk_image.as_raw_descriptors()); } - if let Some(control_socket) = &self.control_socket { - keep_rds.push(control_socket.as_raw_descriptor()); + if let Some(control_tube) = &self.control_tube { + keep_rds.push(control_tube.as_raw_descriptor()); } keep_rds @@ -803,7 +808,7 @@ impl VirtioDevice for Block { let disk_size = self.disk_size.clone(); let id = self.id.take(); if let Some(disk_image) = self.disk_image.take() { - let control_socket = self.control_socket.take(); + let control_tube = self.control_tube.take(); let worker_result = thread::Builder::new() .name("virtio_blk".to_string()) @@ -817,7 +822,7 @@ impl VirtioDevice for Block { read_only, sparse, id, - control_socket, + control_tube, }; worker.run(queue_evts.remove(0), kill_evt); worker @@ -851,7 +856,7 @@ impl VirtioDevice for Block { } Ok(worker) => { self.disk_image = Some(worker.disk_image); - self.control_socket = worker.control_socket; + self.control_tube = worker.control_tube; return true; } } diff --git a/devices/src/virtio/block_async.rs b/devices/src/virtio/block_async.rs index 0ea2e0265..dc04227c7 100644 --- a/devices/src/virtio/block_async.rs +++ b/devices/src/virtio/block_async.rs @@ -15,23 +15,26 @@ use 
std::u32; use futures::pin_mut; use futures::stream::{FuturesUnordered, StreamExt}; -use libchromeos::sync::Mutex as AsyncMutex; use remain::sorted; use thiserror::Error as ThisError; use base::Error as SysError; use base::Result as SysResult; -use base::{error, info, iov_max, warn, AsRawDescriptor, Event, RawDescriptor, Timer}; -use cros_async::{select5, AsyncError, EventAsync, Executor, SelectResult, TimerAsync}; +use base::{ + error, info, iov_max, warn, AsRawDescriptor, AsyncTube, Event, RawDescriptor, Timer, Tube, + TubeError, +}; +use cros_async::{ + select5, sync::Mutex as AsyncMutex, AsyncError, EventAsync, Executor, SelectResult, TimerAsync, +}; use data_model::{DataInit, Le16, Le32, Le64}; use disk::{AsyncDisk, ToAsyncDisk}; -use msg_socket::{MsgError, MsgSender}; -use vm_control::{DiskControlCommand, DiskControlResponseSocket, DiskControlResult}; +use vm_control::{DiskControlCommand, DiskControlResult}; use vm_memory::GuestMemory; use super::{ - copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, - TYPE_BLOCK, + copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, + VirtioDevice, Writer, TYPE_BLOCK, }; const QUEUE_SIZE: u16 = 256; @@ -48,9 +51,17 @@ const MAX_WRITE_ZEROES_SEG: u32 = 32; // but this should probably be based on cluster size for qcow. const DISCARD_SECTOR_ALIGNMENT: u32 = 128; +const ID_LEN: usize = 20; + +/// Virtio block device identifier. +/// This is an ASCII string terminated by a \0, unless all 20 bytes are used, +/// in which case the \0 terminator is omitted. 
+pub type BlockId = [u8; ID_LEN]; + const VIRTIO_BLK_T_IN: u32 = 0; const VIRTIO_BLK_T_OUT: u32 = 1; const VIRTIO_BLK_T_FLUSH: u32 = 4; +const VIRTIO_BLK_T_GET_ID: u32 = 8; const VIRTIO_BLK_T_DISCARD: u32 = 11; const VIRTIO_BLK_T_WRITE_ZEROES: u32 = 13; @@ -91,7 +102,7 @@ unsafe impl DataInit for virtio_blk_topology {} #[derive(Copy, Clone, Debug, Default)] #[repr(C, packed)] -struct virtio_blk_config { +pub(crate) struct virtio_blk_config { capacity: Le64, size_max: Le32, seg_max: Le32, @@ -100,7 +111,7 @@ struct virtio_blk_config { topology: virtio_blk_topology, writeback: u8, unused0: u8, - num_queues: Le16, + pub num_queues: Le16, max_discard_sectors: Le32, max_discard_seg: Le32, discard_sector_alignment: Le32, @@ -140,8 +151,8 @@ unsafe impl DataInit for virtio_blk_discard_write_zeroes {} #[sorted] #[derive(ThisError, Debug)] enum ExecuteError { - #[error("couldn't create a message receiver: {0}")] - CreatingMessageReceiver(MsgError), + #[error("failed to copy ID string: {0}")] + CopyId(io::Error), #[error("virtio descriptor error: {0}")] Descriptor(DescriptorError), #[error("failed to perform discard or write zeroes; sector={sector} num_sectors={num_sectors} flags={flags}; {ioerr:?}")] @@ -167,10 +178,10 @@ enum ExecuteError { }, #[error("read only; request_type={request_type}")] ReadOnly { request_type: u32 }, - #[error("failed to read command message: {0}")] - ReceivingCommand(MsgError), + #[error("failed to recieve command message: {0}")] + ReceivingCommand(TubeError), #[error("failed to send command response: {0}")] - SendingResponse(MsgError), + SendingResponse(TubeError), #[error("couldn't reset the timer: {0}")] TimerReset(base::Error), #[error("unsupported ({0})")] @@ -188,7 +199,7 @@ enum ExecuteError { impl ExecuteError { fn status(&self) -> u8 { match self { - ExecuteError::CreatingMessageReceiver(_) => VIRTIO_BLK_S_IOERR, + ExecuteError::CopyId(_) => VIRTIO_BLK_S_IOERR, ExecuteError::Descriptor(_) => VIRTIO_BLK_S_IOERR, 
ExecuteError::DiscardWriteZeroes { .. } => VIRTIO_BLK_S_IOERR, ExecuteError::Flush(_) => VIRTIO_BLK_S_IOERR, @@ -227,6 +238,7 @@ struct DiskState { disk_size: Arc<AtomicU64>, read_only: bool, sparse: bool, + id: Option<BlockId>, } async fn process_one_request( @@ -294,7 +306,7 @@ async fn process_one_request_task( let mut queue = queue.borrow_mut(); queue.add_used(&mem, descriptor_index, len as u32); - queue.trigger_interrupt(&mem, &interrupt.borrow()); + queue.trigger_interrupt(&mem, &*interrupt.borrow()); queue.update_int_required(&mem); } @@ -335,19 +347,28 @@ async fn handle_irq_resample( ex: &Executor, interrupt: Rc<RefCell<Interrupt>>, ) -> result::Result<(), OtherError> { - let resample_evt = interrupt - .borrow_mut() - .get_resample_evt() - .try_clone() - .map_err(OtherError::CloneResampleEvent)?; - let resample_evt = - EventAsync::new(resample_evt.0, ex).map_err(OtherError::AsyncResampleCreate)?; - loop { - let _ = resample_evt - .next_val() - .await - .map_err(OtherError::ReadResampleEvent)?; - interrupt.borrow_mut().do_interrupt_resample(); + let resample_evt = if let Some(resample_evt) = interrupt.borrow().get_resample_evt() { + let resample_evt = resample_evt + .try_clone() + .map_err(OtherError::CloneResampleEvent)?; + let resample_evt = + EventAsync::new(resample_evt.0, ex).map_err(OtherError::AsyncResampleCreate)?; + Some(resample_evt) + } else { + None + }; + if let Some(resample_evt) = resample_evt { + loop { + let _ = resample_evt + .next_val() + .await + .map_err(OtherError::ReadResampleEvent)?; + interrupt.borrow().do_interrupt_resample(); + } + } else { + // no resample event, park the future. 
+ let () = futures::future::pending().await; + Ok(()) } } @@ -357,24 +378,20 @@ async fn wait_kill(kill_evt: EventAsync) { let _ = kill_evt.next_val().await; } -async fn handle_command_socket( - ex: &Executor, - command_socket: &Option<DiskControlResponseSocket>, +async fn handle_command_tube( + command_tube: &Option<AsyncTube>, interrupt: Rc<RefCell<Interrupt>>, disk_state: Rc<AsyncMutex<DiskState>>, ) -> Result<(), ExecuteError> { - let command_socket = match command_socket { + let command_tube = match command_tube { Some(c) => c, None => { let () = futures::future::pending().await; return Ok(()); } }; - let mut async_messages = command_socket - .async_receiver(ex) - .map_err(ExecuteError::CreatingMessageReceiver)?; loop { - match async_messages.next().await { + match command_tube.next().await { Ok(command) => { let resp = match command { DiskControlCommand::Resize { new_size } => { @@ -382,7 +399,7 @@ async fn handle_command_socket( } }; - command_socket + command_tube .send(&resp) .map_err(ExecuteError::SendingResponse)?; if let DiskControlResult::Ok = resp { @@ -462,7 +479,7 @@ fn run_worker( queues: Vec<Queue>, mem: GuestMemory, disk_state: &Rc<AsyncMutex<DiskState>>, - control_socket: &Option<DiskControlResponseSocket>, + control_tube: &Option<AsyncTube>, queue_evts: Vec<Event>, kill_evt: Event, ) -> Result<(), String> { @@ -514,7 +531,7 @@ fn run_worker( pin_mut!(disk_flush); // Handles control requests. - let control = handle_command_socket(&ex, control_socket, interrupt.clone(), disk_state.clone()); + let control = handle_command_tube(control_tube, interrupt.clone(), disk_state.clone()); pin_mut!(control); // Process any requests to resample the irq value. @@ -546,8 +563,7 @@ fn run_worker( /// Virtio device for exposing block level read/write operations on a host file. 
pub struct BlockAsync { kill_evt: Option<Event>, - worker_thread: - Option<thread::JoinHandle<(Box<dyn ToAsyncDisk>, Option<DiskControlResponseSocket>)>>, + worker_thread: Option<thread::JoinHandle<(Box<dyn ToAsyncDisk>, Option<Tube>)>>, disk_image: Option<Box<dyn ToAsyncDisk>>, disk_size: Arc<AtomicU64>, avail_features: u64, @@ -555,7 +571,8 @@ pub struct BlockAsync { sparse: bool, seg_max: u32, block_size: u32, - control_socket: Option<DiskControlResponseSocket>, + id: Option<BlockId>, + control_tube: Option<Tube>, } fn build_config_space(disk_size: u64, seg_max: u32, block_size: u32) -> virtio_blk_config { @@ -583,7 +600,8 @@ impl BlockAsync { read_only: bool, sparse: bool, block_size: u32, - control_socket: Option<DiskControlResponseSocket>, + id: Option<BlockId>, + control_tube: Option<Tube>, ) -> SysResult<BlockAsync> { if block_size % SECTOR_SIZE as u32 != 0 { error!( @@ -632,7 +650,8 @@ impl BlockAsync { sparse, seg_max, block_size, - control_socket, + id, + control_tube, }) } @@ -655,7 +674,7 @@ impl BlockAsync { let req_type = req_header.req_type.to_native(); let sector = req_header.sector.to_native(); - if disk_state.read_only && req_type != VIRTIO_BLK_T_IN { + if disk_state.read_only && req_type != VIRTIO_BLK_T_IN && req_type != VIRTIO_BLK_T_GET_ID { return Err(ExecuteError::ReadOnly { request_type: req_type, }); @@ -790,6 +809,13 @@ impl BlockAsync { .await .map_err(ExecuteError::Flush)?; } + VIRTIO_BLK_T_GET_ID => { + if let Some(id) = disk_state.id { + writer.write_all(&id).map_err(ExecuteError::CopyId)?; + } else { + return Err(ExecuteError::Unsupported(req_type)); + } + } t => return Err(ExecuteError::Unsupported(t)), }; Ok(()) @@ -817,8 +843,8 @@ impl VirtioDevice for BlockAsync { keep_rds.extend(disk_image.as_raw_descriptors()); } - if let Some(control_socket) = &self.control_socket { - keep_rds.push(control_socket.as_raw_descriptor()); + if let Some(control_tube) = &self.control_tube { + keep_rds.push(control_tube.as_raw_descriptor()); } 
keep_rds @@ -863,13 +889,16 @@ impl VirtioDevice for BlockAsync { let read_only = self.read_only; let sparse = self.sparse; let disk_size = self.disk_size.clone(); + let id = self.id.take(); if let Some(disk_image) = self.disk_image.take() { - let control_socket = self.control_socket.take(); + let control_tube = self.control_tube.take(); let worker_result = thread::Builder::new() .name("virtio_blk".to_string()) .spawn(move || { let ex = Executor::new().expect("Failed to create an executor"); + let async_control = control_tube + .map(|c| c.into_async_tube(&ex).expect("failed to create async tube")); let async_image = match disk_image.to_async_disk(&ex) { Ok(d) => d, Err(e) => panic!("Failed to create async disk {}", e), @@ -879,6 +908,7 @@ impl VirtioDevice for BlockAsync { disk_size, read_only, sparse, + id, })); if let Err(err_string) = run_worker( ex, @@ -886,7 +916,7 @@ impl VirtioDevice for BlockAsync { queues, mem, &disk_state, - &control_socket, + &async_control, queue_evts, kill_evt, ) { @@ -897,7 +927,10 @@ impl VirtioDevice for BlockAsync { Ok(d) => d.into_inner(), Err(_) => panic!("too many refs to the disk"), }; - (disk_state.disk_image.into_inner(), control_socket) + ( + disk_state.disk_image.into_inner(), + async_control.map(|c| c.into()), + ) }); match worker_result { @@ -926,9 +959,9 @@ impl VirtioDevice for BlockAsync { error!("{}: failed to get back resources", self.debug_label()); return false; } - Ok((disk_image, control_socket)) => { + Ok((disk_image, control_tube)) => { self.disk_image = Some(disk_image); - self.control_socket = control_socket; + self.control_tube = control_tube; return true; } } @@ -962,7 +995,7 @@ mod tests { f.set_len(0x1000).unwrap(); let features = base_features(ProtectionType::Unprotected); - let b = BlockAsync::new(features, Box::new(f), true, false, 512, None).unwrap(); + let b = BlockAsync::new(features, Box::new(f), true, false, 512, None, None).unwrap(); let mut num_sectors = [0u8; 4]; b.read_config(0, &mut 
num_sectors); // size is 0x1000, so num_sectors is 8 (4096/512). @@ -982,7 +1015,7 @@ mod tests { f.set_len(0x1000).unwrap(); let features = base_features(ProtectionType::Unprotected); - let b = BlockAsync::new(features, Box::new(f), true, false, 4096, None).unwrap(); + let b = BlockAsync::new(features, Box::new(f), true, false, 4096, None, None).unwrap(); let mut blk_size = [0u8; 4]; b.read_config(20, &mut blk_size); // blk_size should be 4096 (0x1000). @@ -999,7 +1032,7 @@ mod tests { { let f = File::create(&path).unwrap(); let features = base_features(ProtectionType::Unprotected); - let b = BlockAsync::new(features, Box::new(f), false, true, 512, None).unwrap(); + let b = BlockAsync::new(features, Box::new(f), false, true, 512, None, None).unwrap(); // writable device should set VIRTIO_BLK_F_FLUSH + VIRTIO_BLK_F_DISCARD // + VIRTIO_BLK_F_WRITE_ZEROES + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE // + VIRTIO_BLK_F_SEG_MAX + VIRTIO_BLK_F_MQ @@ -1010,7 +1043,7 @@ mod tests { { let f = File::create(&path).unwrap(); let features = base_features(ProtectionType::Unprotected); - let b = BlockAsync::new(features, Box::new(f), false, false, 512, None).unwrap(); + let b = BlockAsync::new(features, Box::new(f), false, false, 512, None, None).unwrap(); // read-only device should set VIRTIO_BLK_F_FLUSH and VIRTIO_BLK_F_RO // + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE + VIRTIO_BLK_F_SEG_MAX // + VIRTIO_BLK_F_MQ @@ -1021,7 +1054,7 @@ mod tests { { let f = File::create(&path).unwrap(); let features = base_features(ProtectionType::Unprotected); - let b = BlockAsync::new(features, Box::new(f), true, true, 512, None).unwrap(); + let b = BlockAsync::new(features, Box::new(f), true, true, 512, None, None).unwrap(); // read-only device should set VIRTIO_BLK_F_FLUSH and VIRTIO_BLK_F_RO // + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE + VIRTIO_BLK_F_SEG_MAX // + VIRTIO_BLK_F_MQ @@ -1086,6 +1119,7 @@ mod tests { disk_size: Arc::new(AtomicU64::new(disk_size)), read_only: false, sparse: 
true, + id: None, })); let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem); @@ -1154,6 +1188,7 @@ mod tests { disk_size: Arc::new(AtomicU64::new(disk_size)), read_only: false, sparse: true, + id: None, })); let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem); @@ -1166,4 +1201,79 @@ mod tests { let status = mem.read_obj_from_addr::<u8>(status_offset).unwrap(); assert_eq!(status, VIRTIO_BLK_S_IOERR); } + + #[test] + fn get_id() { + let ex = Executor::new().expect("creating an executor failed"); + + let tempdir = TempDir::new().unwrap(); + let mut path = tempdir.path().to_owned(); + path.push("disk_image"); + let f = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .open(&path) + .unwrap(); + let disk_size = 0x1000; + f.set_len(disk_size).unwrap(); + + let mem = GuestMemory::new(&[(GuestAddress(0u64), 4 * 1024 * 1024)]) + .expect("Creating guest memory failed."); + + let req_hdr = virtio_blk_req_header { + req_type: Le32::from(VIRTIO_BLK_T_GET_ID), + reserved: Le32::from(0), + sector: Le64::from(0), + }; + mem.write_obj_at_addr(req_hdr, GuestAddress(0x1000)) + .expect("writing req failed"); + + let avail_desc = create_descriptor_chain( + &mem, + GuestAddress(0x100), // Place descriptor chain at 0x100. + GuestAddress(0x1000), // Describe buffer at 0x1000. 
+ vec![ + // Request header + (DescriptorType::Readable, size_of_val(&req_hdr) as u32), + // I/O buffer (20 bytes for serial) + (DescriptorType::Writable, 20), + // Request status + (DescriptorType::Writable, 1), + ], + 0, + ) + .expect("create_descriptor_chain failed"); + + let af = SingleFileDisk::new(f, &ex).expect("Failed to create SFD"); + let timer = Timer::new().expect("Failed to create a timer"); + let flush_timer = Rc::new(RefCell::new( + TimerAsync::new(timer.0, &ex).expect("Failed to create an async timer"), + )); + let flush_timer_armed = Rc::new(RefCell::new(false)); + + let id = b"a20-byteserialnumber"; + + let disk_state = Rc::new(AsyncMutex::new(DiskState { + disk_image: Box::new(af), + disk_size: Arc::new(AtomicU64::new(disk_size)), + read_only: false, + sparse: true, + id: Some(*id), + })); + + let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem); + + ex.run_until(fut) + .expect("running executor failed") + .expect("execute failed"); + + let status_offset = GuestAddress((0x1000 + size_of_val(&req_hdr) + 512) as u64); + let status = mem.read_obj_from_addr::<u8>(status_offset).unwrap(); + assert_eq!(status, VIRTIO_BLK_S_OK); + + let id_offset = GuestAddress(0x1000 + size_of_val(&req_hdr) as u64); + let returned_id = mem.read_obj_from_addr::<[u8; 20]>(id_offset).unwrap(); + assert_eq!(returned_id, *id); + } } diff --git a/devices/src/virtio/console.rs b/devices/src/virtio/console.rs index 5eb7cbcf4..9aa789a5d 100644 --- a/devices/src/virtio/console.rs +++ b/devices/src/virtio/console.rs @@ -11,7 +11,8 @@ use data_model::{DataInit, Le16, Le32}; use vm_memory::GuestMemory; use super::{ - base_features, copy_config, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_CONSOLE, + base_features, copy_config, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, + Writer, TYPE_CONSOLE, }; use crate::{ProtectionType, SerialDevice}; @@ -239,7 +240,6 @@ impl Worker { (&transmit_evt, Token::TransmitQueueAvailable), 
(&receive_evt, Token::ReceiveQueueAvailable), (&in_avail_evt, Token::InputAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&kill_evt, Token::Kill), ]) { Ok(pc) => pc, @@ -248,6 +248,15 @@ impl Worker { return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed adding resample event to WaitContext."); + return; + } + } let mut output: Box<dyn io::Write> = match self.output.take() { Some(o) => o, diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index a8be0e779..fc839b610 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -18,6 +18,7 @@ use base::{FileReadWriteAtVolatile, FileReadWriteVolatile}; use cros_async::MemRegion; use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice}; use disk::AsyncDisk; +use smallvec::SmallVec; use vm_memory::{GuestAddress, GuestMemory}; use super::DescriptorChain; @@ -56,7 +57,7 @@ impl std::error::Error for Error {} #[derive(Clone)] struct DescriptorChainRegions { - regions: Vec<MemRegion>, + regions: SmallVec<[MemRegion; 16]>, current: usize, bytes_consumed: usize, } @@ -87,7 +88,7 @@ impl DescriptorChainRegions { /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`. /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple /// calls to `get` with no intervening calls to `consume` will return the same data. 
- fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> Vec<VolatileSlice<'mem>> { + fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> { self.get_remaining_regions() .iter() .filter_map(|region| { @@ -134,7 +135,7 @@ impl DescriptorChainRegions { &self, mem: &'mem GuestMemory, count: usize, - ) -> Vec<VolatileSlice<'mem>> { + ) -> SmallVec<[VolatileSlice<'mem>; 16]> { self.get_remaining_regions_with_count(count) .iter() .filter_map(|region| { @@ -259,7 +260,7 @@ impl Reader { len: desc.len.try_into().expect("u32 doesn't fit in usize"), }) }) - .collect::<Result<Vec<MemRegion>>>()?; + .collect::<Result<SmallVec<[MemRegion; 16]>>>()?; Ok(Reader { mem, regions: DescriptorChainRegions { @@ -442,7 +443,7 @@ impl Reader { /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`. /// Calling this method does not actually consume any data from the `Reader` and callers should /// call `consume` to advance the `Reader`. - pub fn get_remaining(&self) -> Vec<VolatileSlice> { + pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> { self.regions.get_remaining(&self.mem) } @@ -527,7 +528,7 @@ impl Writer { len: desc.len.try_into().expect("u32 doesn't fit in usize"), }) }) - .collect::<Result<Vec<MemRegion>>>()?; + .collect::<Result<SmallVec<[MemRegion; 16]>>>()?; Ok(Writer { mem, regions: DescriptorChainRegions { diff --git a/devices/src/virtio/fs/caps.rs b/devices/src/virtio/fs/caps.rs new file mode 100644 index 000000000..f4964c6a6 --- /dev/null +++ b/devices/src/virtio/fs/caps.rs @@ -0,0 +1,163 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::ffi::c_void; +use std::io; +use std::os::raw::c_int; + +#[allow(non_camel_case_types)] +type cap_t = *mut c_void; + +#[allow(non_camel_case_types)] +pub type cap_value_t = u32; + +#[allow(non_camel_case_types)] +type cap_flag_t = u32; + +#[allow(non_camel_case_types)] +type cap_flag_value_t = i32; + +#[link(name = "cap")] +extern "C" { + fn cap_free(ptr: *mut c_void) -> c_int; + + fn cap_set_flag( + c: cap_t, + f: cap_flag_t, + ncap: c_int, + caps: *const cap_value_t, + val: cap_flag_value_t, + ) -> c_int; + + fn cap_get_proc() -> cap_t; + fn cap_set_proc(cap: cap_t) -> c_int; +} + +#[repr(u32)] +pub enum Capability { + Chown = 0, + DacOverride = 1, + DacReadSearch = 2, + Fowner = 3, + Fsetid = 4, + Kill = 5, + Setgid = 6, + Setuid = 7, + Setpcap = 8, + LinuxImmutable = 9, + NetBindService = 10, + NetBroadcast = 11, + NetAdmin = 12, + NetRaw = 13, + IpcLock = 14, + IpcOwner = 15, + SysModule = 16, + SysRawio = 17, + SysChroot = 18, + SysPtrace = 19, + SysPacct = 20, + SysAdmin = 21, + SysBoot = 22, + SysNice = 23, + SysResource = 24, + SysTime = 25, + SysTtyConfig = 26, + Mknod = 27, + Lease = 28, + AuditWrite = 29, + AuditControl = 30, + Setfcap = 31, + MacOverride = 32, + MacAdmin = 33, + Syslog = 34, + WakeAlarm = 35, + BlockSuspend = 36, + AuditRead = 37, + Last, +} + +impl From<Capability> for cap_value_t { + fn from(c: Capability) -> cap_value_t { + c as cap_value_t + } +} + +#[repr(u32)] +pub enum Set { + Effective = 0, + Permitted = 1, + Inheritable = 2, +} + +impl From<Set> for cap_flag_t { + fn from(s: Set) -> cap_flag_t { + s as cap_flag_t + } +} + +#[repr(i32)] +pub enum Value { + Clear = 0, + Set = 1, +} + +impl From<Value> for cap_flag_value_t { + fn from(v: Value) -> cap_flag_value_t { + v as cap_flag_value_t + } +} + +pub struct Caps(cap_t); + +impl Caps { + /// Get the capabilities for the current thread. + pub fn for_current_thread() -> io::Result<Caps> { + // Safe because this doesn't modify any memory and we check the return value. 
+ let caps = unsafe { cap_get_proc() }; + if caps.is_null() { + Err(io::Error::last_os_error()) + } else { + Ok(Caps(caps)) + } + } + + /// Update the capabilities described by `self` by setting or clearing `caps` in `set`. + pub fn update(&mut self, caps: &[Capability], set: Set, value: Value) -> io::Result<()> { + // Safe because this only modifies the memory pointed to by `self.0` and we check the return + // value. + let ret = unsafe { + cap_set_flag( + self.0, + set.into(), + caps.len() as c_int, + // It's safe to cast this pointer because `Capability` is #[repr(u32)] + caps.as_ptr() as *const cap_value_t, + value.into(), + ) + }; + + if ret == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } + } + + /// Apply the capabilities described by `self` to the current thread. + pub fn apply(&self) -> io::Result<()> { + if unsafe { cap_set_proc(self.0) } == 0 { + Ok(()) + } else { + Err(io::Error::last_os_error()) + } + } +} + +impl Drop for Caps { + fn drop(&mut self) { + unsafe { + cap_free(self.0); + } + } +} diff --git a/devices/src/virtio/fs/mod.rs b/devices/src/virtio/fs/mod.rs index 7ec90cbf4..bc8c28f01 100644 --- a/devices/src/virtio/fs/mod.rs +++ b/devices/src/virtio/fs/mod.rs @@ -8,11 +8,10 @@ use std::mem; use std::sync::{Arc, Mutex}; use std::thread; -use base::{error, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor}; +use base::{error, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor, Tube}; use data_model::{DataInit, Le32}; -use msg_socket::{MsgReceiver, MsgSender}; use resources::Alloc; -use vm_control::{FsMappingRequest, FsMappingRequestSocket, VmResponse}; +use vm_control::{FsMappingRequest, VmResponse}; use vm_memory::GuestMemory; use crate::pci::{ @@ -23,6 +22,7 @@ use crate::virtio::{ VirtioPciShmCap, TYPE_FS, }; +mod caps; mod multikey; pub mod passthrough; mod read_dir; @@ -33,7 +33,7 @@ use passthrough::PassthroughFs; use worker::Worker; // The fs device does not have a fixed number of queues. 
-const QUEUE_SIZE: u16 = 1024; +pub const QUEUE_SIZE: u16 = 1024; const FS_BAR_NUM: u8 = 4; const FS_BAR_OFFSET: u64 = 0; @@ -48,15 +48,15 @@ pub const FS_MAX_TAG_LEN: usize = 36; /// kernel/include/uapi/linux/virtio_fs.h #[repr(C, packed)] #[derive(Clone, Copy)] -struct Config { +pub(crate) struct virtio_fs_config { /// Filesystem name (UTF-8, not NUL-terminated, padded with NULs) - tag: [u8; FS_MAX_TAG_LEN], + pub tag: [u8; FS_MAX_TAG_LEN], /// Number of request queues - num_queues: Le32, + pub num_request_queues: Le32, } // Safe because all members are plain old data and any value is valid. -unsafe impl DataInit for Config {} +unsafe impl DataInit for virtio_fs_config {} /// Errors that may occur during the creation or operation of an Fs device. #[derive(Debug)] @@ -81,6 +81,10 @@ pub enum Error { InvalidDescriptorChain(DescriptorError), /// Error happened in FUSE. FuseError(fuse::Error), + /// Failed to get the securebits for the worker thread. + GetSecurebits(io::Error), + /// Failed to set the securebits for the worker thread. 
+ SetSecurebits(io::Error), } impl ::std::error::Error for Error {} @@ -109,6 +113,12 @@ impl fmt::Display for Error { SignalUsedQueue(err) => write!(f, "failed to signal used queue: {}", err), InvalidDescriptorChain(err) => write!(f, "DescriptorChain is invalid: {}", err), FuseError(err) => write!(f, "fuse error: {}", err), + GetSecurebits(err) => { + write!(f, "failed to get securebits for the worker thread: {}", err) + } + SetSecurebits(err) => { + write!(f, "failed to set securebits for the worker thread: {}", err) + } } } } @@ -116,13 +126,13 @@ impl fmt::Display for Error { pub type Result<T> = ::std::result::Result<T, Error>; pub struct Fs { - cfg: Config, + cfg: virtio_fs_config, fs: Option<PassthroughFs>, queue_sizes: Box<[u16]>, avail_features: u64, acked_features: u64, pci_bar: Option<Alloc>, - socket: Option<FsMappingRequestSocket>, + tube: Option<Tube>, workers: Vec<(Event, thread::JoinHandle<Result<()>>)>, } @@ -132,7 +142,7 @@ impl Fs { tag: &str, num_workers: usize, fs_cfg: passthrough::Config, - socket: FsMappingRequestSocket, + tube: Tube, ) -> Result<Fs> { if tag.len() > FS_MAX_TAG_LEN { return Err(Error::TagTooLong(tag.len())); @@ -141,9 +151,9 @@ impl Fs { let mut cfg_tag = [0u8; FS_MAX_TAG_LEN]; cfg_tag[..tag.len()].copy_from_slice(tag.as_bytes()); - let cfg = Config { + let cfg = virtio_fs_config { tag: cfg_tag, - num_queues: Le32::from(num_workers as u32), + num_request_queues: Le32::from(num_workers as u32), }; let fs = PassthroughFs::new(fs_cfg).map_err(Error::CreateFs)?; @@ -158,7 +168,7 @@ impl Fs { avail_features: base_features, acked_features: 0, pci_bar: None, - socket: Some(socket), + tube: Some(tube), workers: Vec::with_capacity(num_workers + 1), }) } @@ -190,7 +200,7 @@ impl VirtioDevice for Fs { .as_ref() .map(PassthroughFs::keep_rds) .unwrap_or_else(Vec::new); - if let Some(rd) = self.socket.as_ref().map(|s| s.as_raw_descriptor()) { + if let Some(rd) = self.tube.as_ref().map(|s| s.as_raw_descriptor()) { fds.push(rd); } @@ -240,7 
+250,7 @@ impl VirtioDevice for Fs { let server = Arc::new(Server::new(fs)); let irq = Arc::new(interrupt); - let socket = self.socket.take().expect("missing mapping socket"); + let socket = self.tube.take().expect("missing mapping socket"); let mut slot = 0; // Set up shared memory for DAX. diff --git a/devices/src/virtio/fs/passthrough.rs b/devices/src/virtio/fs/passthrough.rs index 71415e101..dfeee627a 100644 --- a/devices/src/virtio/fs/passthrough.rs +++ b/devices/src/virtio/fs/passthrough.rs @@ -6,12 +6,11 @@ use std::borrow::Cow; use std::cmp; use std::collections::btree_map; use std::collections::BTreeMap; -use std::ffi::{c_void, CStr, CString}; +use std::ffi::{CStr, CString}; use std::fs::File; use std::io; use std::mem::{self, size_of, MaybeUninit}; use std::os::raw::{c_int, c_long}; -use std::ptr; use std::str::FromStr; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; @@ -27,10 +26,11 @@ use fuse::filesystem::{ IoctlReply, ListxattrReply, OpenOptions, RemoveMappingOne, SetattrValid, ZeroCopyReader, ZeroCopyWriter, ROOT_ID, }; +use fuse::sys::WRITE_KILL_PRIV; use fuse::Mapper; -use rand_ish::SimpleRng; use sync::Mutex; +use crate::virtio::fs::caps::{Capability, Caps, Set as CapSet, Value as CapValue}; use crate::virtio::fs::multikey::MultikeyBTreeMap; use crate::virtio::fs::read_dir::ReadDir; @@ -254,6 +254,62 @@ fn set_creds( ScopedGid::new(gid, oldgid).and_then(|gid| Ok((ScopedUid::new(uid, olduid)?, gid))) } +struct ScopedUmask<'a> { + old: libc::mode_t, + mask: libc::mode_t, + _factory: &'a mut Umask, +} + +impl<'a> Drop for ScopedUmask<'a> { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and always succeeds. 
+ let previous = unsafe { libc::umask(self.old) }; + debug_assert_eq!( + previous, self.mask, + "umask changed while holding ScopedUmask" + ); + } +} + +struct Umask; + +impl Umask { + fn set(&mut self, mask: libc::mode_t) -> ScopedUmask { + ScopedUmask { + // Safe because this doesn't modify any memory and always succeeds. + old: unsafe { libc::umask(mask) }, + mask, + _factory: self, + } + } +} + +struct ScopedFsetid(Caps); +impl Drop for ScopedFsetid { + fn drop(&mut self) { + if let Err(e) = raise_cap_fsetid(&mut self.0) { + error!( + "Failed to restore CAP_FSETID: {}. Some operations may be broken.", + e + ) + } + } +} + +fn raise_cap_fsetid(c: &mut Caps) -> io::Result<()> { + c.update(&[Capability::Fsetid], CapSet::Effective, CapValue::Set)?; + c.apply() +} + +// Drops CAP_FSETID from the effective set for the current thread and returns an RAII guard that +// adds the capability back when it is dropped. +fn drop_cap_fsetid() -> io::Result<ScopedFsetid> { + let mut caps = Caps::for_current_thread()?; + caps.update(&[Capability::Fsetid], CapSet::Effective, CapValue::Clear)?; + caps.apply()?; + Ok(ScopedFsetid(caps)) +} + fn ebadf() -> io::Error { io::Error::from_raw_os_error(libc::EBADF) } @@ -442,6 +498,10 @@ pub struct PassthroughFs { // process-wide CWD, we cannot allow more than one thread to do it at the same time. chdir_mutex: Mutex<()>, + // Used when creating files / directories / nodes. Since the umask is process-wide, we can only + // allow one thread at a time to change it. + umask: Mutex<Umask>, + cfg: Config, } @@ -479,6 +539,7 @@ impl PassthroughFs { zero_message_opendir: AtomicBool::new(false), chdir_mutex: Mutex::new(()), + umask: Mutex::new(Umask), cfg, }) } @@ -570,7 +631,7 @@ impl PassthroughFs { } // Creates a new entry for `f` or increases the refcount of the existing entry for `f`. 
- fn add_entry(&self, f: File, st: libc::stat64, open_flags: libc::c_int) -> io::Result<Entry> { + fn add_entry(&self, f: File, st: libc::stat64, open_flags: libc::c_int) -> Entry { let altkey = InodeAltKey { ino: st.st_ino, dev: st.st_dev, @@ -603,13 +664,13 @@ impl PassthroughFs { inode }; - Ok(Entry { + Entry { inode, generation: 0, attr: st, attr_timeout: self.cfg.attr_timeout, entry_timeout: self.cfg.entry_timeout, - }) + } } // Performs an ascii case insensitive lookup. @@ -651,7 +712,7 @@ impl PassthroughFs { // Safe because we just opened this fd. let f = unsafe { File::from_raw_descriptor(fd) }; - self.add_entry(f, st, flags) + Ok(self.add_entry(f, st, flags)) } fn do_open(&self, inode: Inode, flags: u32) -> io::Result<(Option<Handle>, OpenOptions)> { @@ -684,68 +745,6 @@ impl PassthroughFs { Ok((Some(handle), opts)) } - fn do_tmpfile( - &self, - ctx: &Context, - dir: &InodeData, - flags: u32, - mut mode: u32, - umask: u32, - ) -> io::Result<(File, libc::c_int)> { - // We don't want to use `O_EXCL` with `O_TMPFILE` as it has a different meaning when used in - // that combination. - let mut tmpflags = (flags as i32 | libc::O_TMPFILE | libc::O_CLOEXEC | libc::O_NOFOLLOW) - & !(libc::O_EXCL | libc::O_CREAT); - - // O_TMPFILE requires that we use O_RDWR or O_WRONLY. - if flags as i32 & libc::O_ACCMODE == libc::O_RDONLY { - tmpflags &= !libc::O_ACCMODE; - tmpflags |= libc::O_RDWR; - } - - // The presence of a default posix acl xattr in the parent directory completely changes the - // meaning of the mode parameter so only apply the umask if it doesn't have one. - if !self.has_default_posix_acl(&dir)? { - mode &= !umask; - } - - // Safe because this is a valid c string. - let current_dir = unsafe { CStr::from_bytes_with_nul_unchecked(b".\0") }; - - // Safe because this doesn't modify any memory and we check the return value. 
- let fd = unsafe { - libc::openat( - dir.as_raw_descriptor(), - current_dir.as_ptr(), - tmpflags, - mode, - ) - }; - if fd < 0 { - return Err(io::Error::last_os_error()); - } - - // Safe because we just opened this fd. - let tmpfile = unsafe { File::from_raw_descriptor(fd) }; - - // We need to respect the setgid bit in the parent directory if it is set. - let st = stat(dir)?; - let gid = if st.st_mode & libc::S_ISGID != 0 { - st.st_gid - } else { - ctx.gid - }; - - // Now set the uid and gid for the file. Safe because this doesn't modify any memory and we - // check the return value. - let ret = unsafe { libc::fchown(tmpfile.as_raw_descriptor(), ctx.uid, gid) }; - if ret < 0 { - return Err(io::Error::last_os_error()); - } - - Ok((tmpfile, tmpflags)) - } - fn do_release(&self, inode: Inode, handle: Handle) -> io::Result<()> { let mut handles = self.handles.lock(); @@ -868,21 +867,6 @@ impl PassthroughFs { } } - // Checks whether `inode` has a default posix acl xattr. - fn has_default_posix_acl(&self, inode: &InodeData) -> io::Result<bool> { - // Safe because this is a valid c string with no interior nul-bytes. - let acl = unsafe { CStr::from_bytes_with_nul_unchecked(b"system.posix_acl_default\0") }; - - if let Err(e) = self.do_getxattr(inode, acl, &mut []) { - match e.raw_os_error() { - Some(libc::ENODATA) | Some(libc::EOPNOTSUPP) => Ok(false), - _ => Err(e), - } - } else { - Ok(true) - } - } - fn get_encryption_policy_ex<R: io::Read>( &self, inode: Inode, @@ -1015,8 +999,8 @@ fn forget_one( // Synchronizes with the acquire load in `do_lookup`. if data .refcount - .compare_and_swap(refcount, new_count, Ordering::Release) - == refcount + .compare_exchange_weak(refcount, new_count, Ordering::Release, Ordering::Relaxed) + .is_ok() { if new_count == 0 { // We just removed the last refcount for this inode. 
There's no need for an @@ -1062,133 +1046,6 @@ fn strip_xattr_prefix(buf: &mut Vec<u8>) { } } -// Like mkdtemp but also takes a mode parameter rather than always using 0o700. This is needed -// because if the parent has a default posix acl set then the meaning of the mode parameter in the -// mkdir call completely changes: the actual mode is inherited from the default acls set in the -// parent and the mode is treated like a umask (the real umask is ignored in this case). -// Additionally, this only happens when the inode is first created and not on subsequent fchmod -// calls so we really need to use the requested mode from the very beginning and not the default -// 0o700 mode that mkdtemp uses. -fn create_temp_dir<D: AsRawDescriptor>(parent: &D, mode: libc::mode_t) -> io::Result<CString> { - const MAX_ATTEMPTS: usize = 64; - let mut seed = 0u64.to_ne_bytes(); - // Safe because this will only modify `seed` and we check the return value. - let ret = unsafe { - libc::syscall( - libc::SYS_getrandom, - seed.as_mut_ptr() as *mut c_void, - seed.len(), - 0, - ) - }; - if ret < 0 { - return Err(io::Error::last_os_error()); - } - - let mut rng = SimpleRng::new(u64::from_ne_bytes(seed)); - - // Set an upper bound so that we don't end up spinning here forever. - for _ in 0..MAX_ATTEMPTS { - let mut name = String::from("."); - name.push_str(&rng.str(6)); - let name = CString::new(name).expect("SimpleRng produced string with nul-bytes"); - - // Safe because this doesn't modify any memory and we check the return value. - let ret = unsafe { libc::mkdirat(parent.as_raw_descriptor(), name.as_ptr(), mode) }; - if ret == 0 { - return Ok(name); - } - - let e = io::Error::last_os_error(); - if let Some(libc::EEXIST) = e.raw_os_error() { - continue; - } else { - return Err(e); - } - } - - Err(io::Error::from_raw_os_error(libc::EAGAIN)) -} - -// A temporary directory that is automatically deleted when dropped unless `into_inner()` is called. 
-// This isn't a general-purpose temporary directory and is only intended to be used to ensure that -// there are no leaks when initializing a newly created directory with the correct metadata (see the -// implementation of `mkdir()` below). The directory is removed via a call to `unlinkat` so callers -// are not allowed to actually populate this temporary directory with any entries (or else deleting -// the directory will fail). -struct TempDir<'a, D: AsRawDescriptor> { - parent: &'a D, - name: CString, - file: File, -} - -impl<'a, D: AsRawDescriptor> TempDir<'a, D> { - // Creates a new temporary directory in `parent` with a randomly generated name. `parent` must - // be a directory. - fn new(parent: &'a D, mode: libc::mode_t) -> io::Result<Self> { - let name = create_temp_dir(parent, mode)?; - - // Safe because this doesn't modify any memory and we check the return value. - let raw_descriptor = unsafe { - libc::openat( - parent.as_raw_descriptor(), - name.as_ptr(), - libc::O_DIRECTORY | libc::O_CLOEXEC, - ) - }; - if raw_descriptor < 0 { - return Err(io::Error::last_os_error()); - } - - Ok(TempDir { - parent, - name, - // Safe because we just opened this descriptor. - file: unsafe { File::from_raw_descriptor(raw_descriptor) }, - }) - } - - fn basename(&self) -> &CStr { - &self.name - } - - // Consumes the `TempDir`, returning the inner `File` without deleting the temporary - // directory. - fn into_inner(self) -> (CString, File) { - // Safe because this is a valid pointer and we are going to call `mem::forget` on `self` so - // we will not be aliasing memory. 
- let _parent = unsafe { ptr::read(&self.parent) }; - let name = unsafe { ptr::read(&self.name) }; - let file = unsafe { ptr::read(&self.file) }; - mem::forget(self); - - (name, file) - } -} - -impl<'a, D: AsRawDescriptor> AsRawDescriptor for TempDir<'a, D> { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.file.as_raw_descriptor() - } -} - -impl<'a, D: AsRawDescriptor> Drop for TempDir<'a, D> { - fn drop(&mut self) { - // Safe because this doesn't modify any memory and we check the return value. - let ret = unsafe { - libc::unlinkat( - self.parent.as_raw_descriptor(), - self.name.as_ptr(), - libc::AT_REMOVEDIR, - ) - }; - if ret < 0 { - println!("Failed to remove tempdir: {}", io::Error::last_os_error()); - error!("Failed to remove tempdir: {}", io::Error::last_os_error()); - } - } -} - impl FileSystem for PassthroughFs { type Inode = Inode; type Handle = Handle; @@ -1331,65 +1188,24 @@ impl FileSystem for PassthroughFs { ctx: Context, parent: Inode, name: &CStr, - mut mode: u32, + mode: u32, umask: u32, ) -> io::Result<Entry> { - // This method has the same issues as `create()`: namely that the kernel may have allowed a - // process to make a directory due to one of its supplementary groups but that information - // is not forwarded to us. However, there is no `O_TMPDIR` equivalent for directories so - // instead we create a "hidden" directory with a randomly generated name in the parent - // directory, modify the uid/gid and mode to the proper values, and then rename it to the - // requested name. This ensures that even in the case of a power loss the directory is not - // visible in the filesystem with the requested name but incorrect metadata. The only thing - // left would be a empty hidden directory with a random name. let data = self.find_inode(parent)?; - // The presence of a default posix acl xattr in the parent directory completely changes the - // meaning of the mode parameter so only apply the umask if it doesn't have one. 
- if !self.has_default_posix_acl(&data)? { - mode &= !umask; - } - - let tmpdir = TempDir::new(&*data, mode)?; - - // We need to respect the setgid bit in the parent directory if it is set. - let st = stat(&data.file.lock().0)?; - let gid = if st.st_mode & libc::S_ISGID != 0 { - st.st_gid - } else { - ctx.gid - }; - - // Set the uid and gid for the directory. Safe because this doesn't modify any memory and we - // check the return value. - let ret = unsafe { libc::fchown(tmpdir.as_raw_descriptor(), ctx.uid, gid) }; - if ret < 0 { - return Err(io::Error::last_os_error()); - } + let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; + let res = { + let mut um = self.umask.lock(); + let _scoped_umask = um.set(umask); - // Now rename it into place. Safe because this doesn't modify any memory and we check the - // return value. TODO: Switch to libc::renameat2 once - // https://github.com/rust-lang/libc/pull/1508 lands and we have glibc 2.28. - let ret = unsafe { - libc::syscall( - libc::SYS_renameat2, - data.as_raw_descriptor(), - tmpdir.basename().as_ptr(), - data.as_raw_descriptor(), - name.as_ptr(), - libc::RENAME_NOREPLACE, - ) + // Safe because this doesn't modify any memory and we check the return value. + unsafe { libc::mkdirat(data.as_raw_descriptor(), name.as_ptr(), mode) } }; - if ret < 0 { - return Err(io::Error::last_os_error()); + if res == 0 { + self.do_lookup(&data, name) + } else { + Err(io::Error::last_os_error()) } - - // Now that we've moved the directory make sure we don't try to delete the now non-existent - // `tmpdir`. 
- let (_, dir) = tmpdir.into_inner(); - - let st = stat(&dir)?; - self.add_entry(dir, st, libc::O_DIRECTORY | libc::O_CLOEXEC) } fn rmdir(&self, _ctx: Context, parent: Inode, name: &CStr) -> io::Result<()> { @@ -1458,10 +1274,36 @@ impl FileSystem for PassthroughFs { ) -> io::Result<Entry> { let data = self.find_inode(parent)?; - let (tmpfile, flags) = self.do_tmpfile(&ctx, &data, 0, mode, umask)?; + let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; + + let tmpflags = libc::O_RDWR | libc::O_TMPFILE | libc::O_CLOEXEC | libc::O_NOFOLLOW; + + // Safe because this is a valid c string. + let current_dir = unsafe { CStr::from_bytes_with_nul_unchecked(b".\0") }; + + let fd = { + let mut um = self.umask.lock(); + let _scoped_umask = um.set(umask); + + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::openat( + data.as_raw_descriptor(), + current_dir.as_ptr(), + tmpflags, + mode, + ) + } + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + // Safe because we just opened this fd. + let tmpfile = unsafe { File::from_raw_descriptor(fd) }; let st = stat(&tmpfile)?; - self.add_entry(tmpfile, st, flags) + Ok(self.add_entry(tmpfile, st, tmpflags)) } fn create( @@ -1473,40 +1315,31 @@ impl FileSystem for PassthroughFs { flags: u32, umask: u32, ) -> io::Result<(Entry, Option<Handle>, OpenOptions)> { - // The `Context` may not contain all the information we need to create the file here. For - // example, a process may be part of several groups, one of which gives it permission to - // create a file in `parent`, but is not the gid of the process. This information is not - // forwarded to the server so we don't know when this is happening. Instead, we just rely on - // the access checks in the kernel driver: if we received this request then the kernel has - // determined that the process is allowed to create the file and we shouldn't reject it now - // based on acls. 
- // - // To ensure that the file is created atomically with the proper uid/gid we use `O_TMPFILE` - // + `linkat` as described in the `open(2)` manpage. let data = self.find_inode(parent)?; - let (tmpfile, tmpflags) = self.do_tmpfile(&ctx, &data, flags, mode, umask)?; + let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; - let proc_path = CString::new(format!("self/fd/{}", tmpfile.as_raw_descriptor())) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let create_flags = + (flags as i32 | libc::O_CREAT | libc::O_CLOEXEC | libc::O_NOFOLLOW) & !libc::O_DIRECT; - // Finally link it into the file system tree so that it's visible to other processes. Safe - // because this doesn't modify any memory and we check the return value. - let ret = unsafe { - libc::linkat( - self.proc.as_raw_descriptor(), - proc_path.as_ptr(), - data.as_raw_descriptor(), - name.as_ptr(), - libc::AT_SYMLINK_FOLLOW, - ) + let fd = { + let mut um = self.umask.lock(); + let _scoped_umask = um.set(umask); + + // Safe because this doesn't modify any memory and we check the return value. We don't + // really check `flags` because if the kernel can't handle poorly specified flags then + // we have much bigger problems. + unsafe { libc::openat(data.as_raw_descriptor(), name.as_ptr(), create_flags, mode) } }; - if ret < 0 { + if fd < 0 { return Err(io::Error::last_os_error()); } - let st = stat(&tmpfile)?; - let entry = self.add_entry(tmpfile, st, tmpflags)?; + // Safe because we just opened this fd. 
+ let file = unsafe { File::from_raw_descriptor(fd) }; + + let st = stat(&file)?; + let entry = self.add_entry(file, st, create_flags); let (handle, opts) = if self.zero_message_open.load(Ordering::Relaxed) { (None, OpenOptions::KEEP_CACHE) @@ -1570,7 +1403,7 @@ impl FileSystem for PassthroughFs { fn write<R: io::Read + ZeroCopyReader>( &self, - ctx: Context, + _ctx: Context, inode: Inode, handle: Handle, mut r: R, @@ -1578,11 +1411,15 @@ impl FileSystem for PassthroughFs { offset: u64, _lock_owner: Option<u64>, _delayed_write: bool, - _flags: u32, + flags: u32, ) -> io::Result<usize> { - // We need to change credentials during a write so that the kernel will remove setuid or - // setgid bits from the file if it was written to by someone other than the owner. - let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; + // When the WRITE_KILL_PRIV flag is set, drop CAP_FSETID so that the kernel will + // automatically clear the setuid and setgid bits for us. + let _fsetid = if flags & WRITE_KILL_PRIV != 0 { + Some(drop_cap_fsetid()?) + } else { + None + }; if self.zero_message_open.load(Ordering::Relaxed) { let data = self.find_inode(inode)?; @@ -1788,28 +1625,27 @@ impl FileSystem for PassthroughFs { ctx: Context, parent: Inode, name: &CStr, - mut mode: u32, + mode: u32, rdev: u32, umask: u32, ) -> io::Result<Entry> { - let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; - let data = self.find_inode(parent)?; - // The presence of a default posix acl xattr in the parent directory completely changes the - // meaning of the mode parameter so only apply the umask if it doesn't have one. - if !self.has_default_posix_acl(&data)? { - mode &= !umask; - } + let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?; - // Safe because this doesn't modify any memory and we check the return value. 
- let res = unsafe { - libc::mknodat( - data.as_raw_descriptor(), - name.as_ptr(), - mode as libc::mode_t, - rdev as libc::dev_t, - ) + let res = { + let mut um = self.umask.lock(); + let _scoped_umask = um.set(umask); + + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::mknodat( + data.as_raw_descriptor(), + name.as_ptr(), + mode as libc::mode_t, + rdev as libc::dev_t, + ) + } }; if res < 0 { @@ -1912,7 +1748,7 @@ impl FileSystem for PassthroughFs { // behavior by doing the same thing (dup-ing the fd and then immediately closing it). Safe // because this doesn't modify any memory and we check the return values. unsafe { - let newfd = libc::dup(data.as_raw_descriptor()); + let newfd = libc::fcntl(data.as_raw_descriptor(), libc::F_DUPFD_CLOEXEC, 0); if newfd < 0 { return Err(io::Error::last_os_error()); @@ -2380,72 +2216,6 @@ impl FileSystem for PassthroughFs { mod tests { use super::*; - use std::env; - use std::os::unix::ffi::OsStringExt; - - #[test] - fn create_temp_dir() { - let testdir = CString::new(env::temp_dir().into_os_string().into_vec()) - .expect("env::temp_dir() is not a valid c-string"); - let fd = unsafe { - libc::openat( - libc::AT_FDCWD, - testdir.as_ptr(), - libc::O_PATH | libc::O_CLOEXEC, - ) - }; - assert!(fd >= 0, "Failed to open env::temp_dir()"); - let parent = unsafe { File::from_raw_descriptor(fd) }; - let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory"); - - let basename = t.basename().to_string_lossy(); - let path = env::temp_dir().join(&*basename); - assert!(path.exists()); - assert!(path.is_dir()); - } - - #[test] - fn remove_temp_dir() { - let testdir = CString::new(env::temp_dir().into_os_string().into_vec()) - .expect("env::temp_dir() is not a valid c-string"); - let fd = unsafe { - libc::openat( - libc::AT_FDCWD, - testdir.as_ptr(), - libc::O_PATH | libc::O_CLOEXEC, - ) - }; - assert!(fd >= 0, "Failed to open env::temp_dir()"); - let parent = unsafe { 
File::from_raw_descriptor(fd) }; - let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory"); - - let basename = t.basename().to_string_lossy(); - let path = env::temp_dir().join(&*basename); - mem::drop(t); - assert!(!path.exists()); - } - - #[test] - fn temp_dir_into_inner() { - let testdir = CString::new(env::temp_dir().into_os_string().into_vec()) - .expect("env::temp_dir() is not a valid c-string"); - let fd = unsafe { - libc::openat( - libc::AT_FDCWD, - testdir.as_ptr(), - libc::O_PATH | libc::O_CLOEXEC, - ) - }; - assert!(fd >= 0, "Failed to open env::temp_dir()"); - let parent = unsafe { File::from_raw_descriptor(fd) }; - let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory"); - - let (basename_cstr, _) = t.into_inner(); - let basename = basename_cstr.to_string_lossy(); - let path = env::temp_dir().join(&*basename); - assert!(path.exists()); - } - #[test] fn rewrite_xattr_names() { let cfg = Config { diff --git a/devices/src/virtio/fs/worker.rs b/devices/src/virtio/fs/worker.rs index ded93b2c7..65a47a9c6 100644 --- a/devices/src/virtio/fs/worker.rs +++ b/devices/src/virtio/fs/worker.rs @@ -2,20 +2,19 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. 
-use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io; use std::os::unix::io::AsRawFd; use std::sync::{Arc, Mutex}; -use base::{error, Event, PollToken, WaitContext}; +use base::{error, Event, PollToken, SafeDescriptor, Tube, WaitContext}; use fuse::filesystem::{FileSystem, ZeroCopyReader, ZeroCopyWriter}; -use msg_socket::{MsgReceiver, MsgSender}; -use vm_control::{FsMappingRequest, FsMappingRequestSocket, MaybeOwnedDescriptor, VmResponse}; +use vm_control::{FsMappingRequest, VmResponse}; use vm_memory::GuestMemory; use crate::virtio::fs::{Error, Result}; -use crate::virtio::{Interrupt, Queue, Reader, Writer}; +use crate::virtio::{Interrupt, Queue, Reader, SignalableInterrupt, Writer}; impl fuse::Reader for Reader {} @@ -46,27 +45,27 @@ impl ZeroCopyWriter for Writer { } struct Mapper { - socket: Arc<Mutex<FsMappingRequestSocket>>, + tube: Arc<Mutex<Tube>>, slot: u32, } impl Mapper { - fn new(socket: Arc<Mutex<FsMappingRequestSocket>>, slot: u32) -> Self { - Self { socket, slot } + fn new(tube: Arc<Mutex<Tube>>, slot: u32) -> Self { + Self { tube, slot } } fn process_request(&self, request: &FsMappingRequest) -> io::Result<()> { - let socket = self.socket.lock().map_err(|e| { - error!("failed to lock socket: {}", e); + let tube = self.tube.lock().map_err(|e| { + error!("failed to lock tube: {}", e); io::Error::from_raw_os_error(libc::EINVAL) })?; - socket.send(request).map_err(|e| { + tube.send(request).map_err(|e| { error!("failed to send request {:?}: {}", request, e); io::Error::from_raw_os_error(libc::EINVAL) })?; - match socket.recv() { + match tube.recv() { Ok(VmResponse::Ok) => Ok(()), Ok(VmResponse::Err(e)) => Err(e.into()), r => { @@ -91,9 +90,11 @@ impl fuse::Mapper for Mapper { io::Error::from_raw_os_error(libc::EINVAL) })?; + let fd = SafeDescriptor::try_from(fd)?; + let request = FsMappingRequest::CreateMemoryMapping { slot: self.slot, - fd: MaybeOwnedDescriptor::Borrowed(fd.as_raw_fd()), + fd, size, 
file_offset, prot, @@ -128,7 +129,7 @@ pub struct Worker<F: FileSystem + Sync> { queue: Queue, server: Arc<fuse::Server<F>>, irq: Arc<Interrupt>, - socket: Arc<Mutex<FsMappingRequestSocket>>, + tube: Arc<Mutex<Tube>>, slot: u32, } @@ -138,7 +139,7 @@ impl<F: FileSystem + Sync> Worker<F> { queue: Queue, server: Arc<fuse::Server<F>>, irq: Arc<Interrupt>, - socket: Arc<Mutex<FsMappingRequestSocket>>, + tube: Arc<Mutex<Tube>>, slot: u32, ) -> Worker<F> { Worker { @@ -146,7 +147,7 @@ impl<F: FileSystem + Sync> Worker<F> { queue, server, irq, - socket, + tube, slot, } } @@ -154,7 +155,7 @@ impl<F: FileSystem + Sync> Worker<F> { fn process_queue(&mut self) -> Result<()> { let mut needs_interrupt = false; - let mapper = Mapper::new(Arc::clone(&self.socket), self.slot); + let mapper = Mapper::new(Arc::clone(&self.tube), self.slot); while let Some(avail_desc) = self.queue.pop(&self.mem) { let reader = Reader::new(self.mem.clone(), avail_desc.clone()) .map_err(Error::InvalidDescriptorChain)?; @@ -182,6 +183,29 @@ impl<F: FileSystem + Sync> Worker<F> { kill_evt: Event, watch_resample_event: bool, ) -> Result<()> { + // We need to set the no setuid fixup secure bit so that we don't drop capabilities when + // changing the thread uid/gid. Without this, creating new entries can fail in some corner + // cases. + const SECBIT_NO_SETUID_FIXUP: i32 = 1 << 2; + + // TODO(crbug.com/1199487): Remove this once libc provides the wrapper for all targets. + #[cfg(target_os = "linux")] + { + // Safe because this doesn't modify any memory and we check the return value. + let mut securebits = unsafe { libc::prctl(libc::PR_GET_SECUREBITS) }; + if securebits < 0 { + return Err(Error::GetSecurebits(io::Error::last_os_error())); + } + + securebits |= SECBIT_NO_SETUID_FIXUP; + + // Safe because this doesn't modify any memory and we check the return value. 
+ let ret = unsafe { libc::prctl(libc::PR_SET_SECUREBITS, securebits) }; + if ret < 0 { + return Err(Error::SetSecurebits(io::Error::last_os_error())); + } + } + #[derive(PollToken)] enum Token { // A request is ready on the queue. @@ -197,9 +221,11 @@ impl<F: FileSystem + Sync> Worker<F> { .map_err(Error::CreateWaitContext)?; if watch_resample_event { - wait_ctx - .add(self.irq.get_resample_evt(), Token::InterruptResample) - .map_err(Error::CreateWaitContext)?; + if let Some(resample_evt) = self.irq.get_resample_evt() { + wait_ctx + .add(resample_evt, Token::InterruptResample) + .map_err(Error::CreateWaitContext)?; + } } loop { diff --git a/devices/src/virtio/gpu/mod.rs b/devices/src/virtio/gpu/mod.rs index 9646612e8..1d77e4ff0 100644 --- a/devices/src/virtio/gpu/mod.rs +++ b/devices/src/virtio/gpu/mod.rs @@ -3,6 +3,8 @@ // found in the LICENSE file. mod protocol; +mod udmabuf; +mod udmabuf_bindings; mod virtio_gpu; use std::cell::RefCell; @@ -19,22 +21,15 @@ use std::thread; use std::time::Duration; use base::{ - debug, error, warn, AsRawDescriptor, Event, ExternalMapping, PollToken, RawDescriptor, - WaitContext, + debug, error, warn, AsRawDescriptor, AsRawDescriptors, Event, ExternalMapping, PollToken, + RawDescriptor, Tube, WaitContext, }; use data_model::*; pub use gpu_display::EventDevice; use gpu_display::*; -use rutabaga_gfx::{ - DrmFormat, GfxstreamFlags, ResourceCreate3D, ResourceCreateBlob, RutabagaBuilder, - RutabagaChannel, RutabagaComponentType, RutabagaFenceData, Transfer3D, VirglRendererFlags, - RUTABAGA_CHANNEL_TYPE_CAMERA, RUTABAGA_CHANNEL_TYPE_WAYLAND, RUTABAGA_PIPE_BIND_RENDER_TARGET, - RUTABAGA_PIPE_TEXTURE_2D, -}; - -use msg_socket::{MsgReceiver, MsgSender}; +use rutabaga_gfx::*; use resources::Alloc; @@ -42,8 +37,8 @@ use sync::Mutex; use vm_memory::{GuestAddress, GuestMemory}; use super::{ - copy_config, resource_bridge::*, DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, - Writer, TYPE_GPU, + copy_config, resource_bridge::*, 
DescriptorChain, Interrupt, Queue, Reader, + SignalableInterrupt, VirtioDevice, Writer, TYPE_GPU, }; use super::{PciCapabilityType, VirtioPciShmCap}; @@ -55,15 +50,13 @@ use crate::pci::{ PciAddress, PciBarConfiguration, PciBarPrefetchable, PciBarRegionType, PciCapability, }; -use vm_control::VmMemoryControlRequestSocket; - pub const DEFAULT_DISPLAY_WIDTH: u32 = 1280; pub const DEFAULT_DISPLAY_HEIGHT: u32 = 1024; #[derive(Copy, Clone, Debug, PartialEq)] pub enum GpuMode { Mode2D, - Mode3D, + ModeVirglRenderer, ModeGfxstream, } @@ -77,7 +70,8 @@ pub struct GpuParameters { pub renderer_use_surfaceless: bool, pub gfxstream_use_guest_angle: bool, pub gfxstream_use_syncfd: bool, - pub gfxstream_support_vulkan: bool, + pub use_vulkan: bool, + pub udmabuf: bool, pub mode: GpuMode, pub cache_path: Option<String>, pub cache_size: Option<String>, @@ -103,10 +97,11 @@ impl Default for GpuParameters { renderer_use_surfaceless: true, gfxstream_use_guest_angle: false, gfxstream_use_syncfd: true, - gfxstream_support_vulkan: true, - mode: GpuMode::Mode3D, + use_vulkan: false, + mode: GpuMode::ModeVirglRenderer, cache_path: None, cache_size: None, + udmabuf: false, } } } @@ -127,10 +122,11 @@ fn build( display_height: u32, rutabaga_builder: RutabagaBuilder, event_devices: Vec<EventDevice>, - gpu_device_socket: VmMemoryControlRequestSocket, + gpu_device_tube: Tube, pci_bar: Alloc, map_request: Arc<Mutex<Option<ExternalMapping>>>, external_blob: bool, + udmabuf: bool, ) -> Option<VirtioGpu> { let mut display_opt = None; for display in possible_displays { @@ -157,10 +153,11 @@ fn build( display_height, rutabaga_builder, event_devices, - gpu_device_socket, + gpu_device_tube, pci_bar, map_request, external_blob, + udmabuf, ) } @@ -228,7 +225,7 @@ impl Frontend { self.virtio_gpu.process_display() } - fn process_resource_bridge(&mut self, resource_bridge: &ResourceResponseSocket) { + fn process_resource_bridge(&mut self, resource_bridge: &Tube) { let response = match 
resource_bridge.recv() { Ok(ResourceRequest::GetBuffer { id }) => self.virtio_gpu.export_resource(id), Ok(ResourceRequest::GetFence { seqno }) => { @@ -685,7 +682,7 @@ struct Worker { ctrl_evt: Event, cursor_queue: Queue, cursor_evt: Event, - resource_bridges: Vec<ResourceResponseSocket>, + resource_bridges: Vec<Tube>, kill_evt: Event, state: Frontend, } @@ -706,7 +703,6 @@ impl Worker { (&self.ctrl_evt, Token::CtrlQueue), (&self.cursor_evt, Token::CursorQueue), (&*self.state.display().borrow(), Token::Display), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&self.kill_evt, Token::Kill), ]) { Ok(pc) => pc, @@ -715,6 +711,15 @@ impl Worker { return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed creating WaitContext"); + return; + } + } for (index, bridge) in self.resource_bridges.iter().enumerate() { if let Err(e) = wait_ctx.add(bridge, Token::ResourceBridge { index }) { @@ -864,8 +869,8 @@ impl DisplayBackend { pub struct Gpu { exit_evt: Event, - gpu_device_socket: Option<VmMemoryControlRequestSocket>, - resource_bridges: Vec<ResourceResponseSocket>, + gpu_device_tube: Option<Tube>, + resource_bridges: Vec<Tube>, event_devices: Vec<EventDevice>, kill_evt: Option<Event>, config_event: bool, @@ -880,14 +885,16 @@ pub struct Gpu { external_blob: bool, rutabaga_component: RutabagaComponentType, base_features: u64, + mem: GuestMemory, + udmabuf: bool, } impl Gpu { pub fn new( exit_evt: Event, - gpu_device_socket: Option<VmMemoryControlRequestSocket>, + gpu_device_tube: Option<Tube>, num_scanouts: NonZeroU8, - resource_bridges: Vec<ResourceResponseSocket>, + resource_bridges: Vec<Tube>, display_backends: Vec<DisplayBackend>, gpu_parameters: &GpuParameters, event_devices: Vec<EventDevice>, @@ -895,13 +902,15 @@ impl Gpu { external_blob: bool, base_features: u64, channels: BTreeMap<String, PathBuf>, + mem: GuestMemory, ) -> Gpu { let 
virglrenderer_flags = VirglRendererFlags::new() .use_egl(gpu_parameters.renderer_use_egl) .use_gles(gpu_parameters.renderer_use_gles) .use_glx(gpu_parameters.renderer_use_glx) .use_surfaceless(gpu_parameters.renderer_use_surfaceless) - .use_external_blob(external_blob); + .use_external_blob(external_blob) + .use_venus(gpu_parameters.use_vulkan); let gfxstream_flags = GfxstreamFlags::new() .use_egl(gpu_parameters.renderer_use_egl) .use_gles(gpu_parameters.renderer_use_gles) @@ -909,7 +918,7 @@ impl Gpu { .use_surfaceless(gpu_parameters.renderer_use_surfaceless) .use_guest_angle(gpu_parameters.gfxstream_use_guest_angle) .use_syncfd(gpu_parameters.gfxstream_use_syncfd) - .support_vulkan(gpu_parameters.gfxstream_support_vulkan); + .use_vulkan(gpu_parameters.use_vulkan); let mut rutabaga_channels: Vec<RutabagaChannel> = Vec::new(); for (channel_name, path) in &channels { @@ -929,7 +938,7 @@ impl Gpu { let rutabaga_channels_opt = Some(rutabaga_channels); let component = match gpu_parameters.mode { GpuMode::Mode2D => RutabagaComponentType::Rutabaga2D, - GpuMode::Mode3D => RutabagaComponentType::VirglRenderer, + GpuMode::ModeVirglRenderer => RutabagaComponentType::VirglRenderer, GpuMode::ModeGfxstream => RutabagaComponentType::Gfxstream, }; @@ -942,7 +951,7 @@ impl Gpu { Gpu { exit_evt, - gpu_device_socket, + gpu_device_tube, num_scanouts, resource_bridges, event_devices, @@ -958,6 +967,8 @@ impl Gpu { external_blob, rutabaga_component: component, base_features, + mem, + udmabuf: gpu_parameters.udmabuf, } } @@ -970,8 +981,10 @@ impl Gpu { let num_capsets = match self.rutabaga_component { RutabagaComponentType::Rutabaga2D => 0, _ => { + let mut num_capsets = 0; + // Cross-domain (like virtio_wl with llvmpipe) is always available. 
- let mut num_capsets = 1; + num_capsets += 1; // Three capsets for virgl_renderer #[cfg(feature = "virgl_renderer")] @@ -1021,14 +1034,19 @@ impl VirtioDevice for Gpu { keep_rds.push(libc::STDERR_FILENO); } - if let Some(ref gpu_device_socket) = self.gpu_device_socket { - keep_rds.push(gpu_device_socket.as_raw_descriptor()); + if self.udmabuf { + keep_rds.append(&mut self.mem.as_raw_descriptors()); + } + + if let Some(ref gpu_device_tube) = self.gpu_device_tube { + keep_rds.push(gpu_device_tube.as_raw_descriptor()); } keep_rds.push(self.exit_evt.as_raw_descriptor()); for bridge in &self.resource_bridges { keep_rds.push(bridge.as_raw_descriptor()); } + keep_rds } @@ -1044,10 +1062,19 @@ impl VirtioDevice for Gpu { let rutabaga_features = match self.rutabaga_component { RutabagaComponentType::Rutabaga2D => 0, _ => { - 1 << VIRTIO_GPU_F_VIRGL + let mut features_3d = 0; + + features_3d |= 1 << VIRTIO_GPU_F_VIRGL | 1 << VIRTIO_GPU_F_RESOURCE_UUID | 1 << VIRTIO_GPU_F_RESOURCE_BLOB | 1 << VIRTIO_GPU_F_CONTEXT_INIT + | 1 << VIRTIO_GPU_F_RESOURCE_SYNC; + + if self.udmabuf { + features_3d |= 1 << VIRTIO_GPU_F_CREATE_GUEST_HANDLE; + } + + features_3d } }; @@ -1110,8 +1137,9 @@ impl VirtioDevice for Gpu { let event_devices = self.event_devices.split_off(0); let map_request = Arc::clone(&self.map_request); let external_blob = self.external_blob; - if let (Some(gpu_device_socket), Some(pci_bar), Some(rutabaga_builder)) = ( - self.gpu_device_socket.take(), + let udmabuf = self.udmabuf; + if let (Some(gpu_device_tube), Some(pci_bar), Some(rutabaga_builder)) = ( + self.gpu_device_tube.take(), self.pci_bar.take(), self.rutabaga_builder.take(), ) { @@ -1125,10 +1153,11 @@ impl VirtioDevice for Gpu { display_height, rutabaga_builder, event_devices, - gpu_device_socket, + gpu_device_tube, pci_bar, map_request, external_blob, + udmabuf, ) { Some(backend) => backend, None => return, diff --git a/devices/src/virtio/gpu/protocol.rs b/devices/src/virtio/gpu/protocol.rs index 
77ab43a64..c2f9b3eb8 100644 --- a/devices/src/virtio/gpu/protocol.rs +++ b/devices/src/virtio/gpu/protocol.rs @@ -16,18 +16,21 @@ use std::str::from_utf8; use super::super::DescriptorError; use super::{Reader, Writer}; use base::Error as SysError; -use base::ExternalMappingError; +use base::{ExternalMappingError, TubeError}; use data_model::{DataInit, Le32, Le64}; use gpu_display::GpuDisplayError; -use msg_socket::MsgError; use rutabaga_gfx::RutabagaError; +use crate::virtio::gpu::udmabuf::UdmabufError; + pub const VIRTIO_GPU_F_VIRGL: u32 = 0; pub const VIRTIO_GPU_F_EDID: u32 = 1; pub const VIRTIO_GPU_F_RESOURCE_UUID: u32 = 2; pub const VIRTIO_GPU_F_RESOURCE_BLOB: u32 = 3; /* The following capabilities are not upstreamed. */ pub const VIRTIO_GPU_F_CONTEXT_INIT: u32 = 4; +pub const VIRTIO_GPU_F_CREATE_GUEST_HANDLE: u32 = 5; +pub const VIRTIO_GPU_F_RESOURCE_SYNC: u32 = 6; pub const VIRTIO_GPU_UNDEFINED: u32 = 0x0; @@ -88,6 +91,8 @@ pub const VIRTIO_GPU_BLOB_MEM_HOST3D_GUEST: u32 = 0x0003; pub const VIRTIO_GPU_BLOB_FLAG_USE_MAPPABLE: u32 = 0x0001; pub const VIRTIO_GPU_BLOB_FLAG_USE_SHAREABLE: u32 = 0x0002; pub const VIRTIO_GPU_BLOB_FLAG_USE_CROSS_DEVICE: u32 = 0x0004; +/* Create a OS-specific handle from guest memory (not upstreamed). 
*/ +pub const VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE: u32 = 0x0008; pub const VIRTIO_GPU_SHM_ID_NONE: u8 = 0x0000; pub const VIRTIO_GPU_SHM_ID_HOST_VISIBLE: u8 = 0x0001; @@ -800,7 +805,7 @@ pub enum GpuResponse { map_info: u32, }, ErrUnspec, - ErrMsg(MsgError), + ErrMsg(TubeError), ErrSys(SysError), ErrRutabaga(RutabagaError), ErrDisplay(GpuDisplayError), @@ -813,10 +818,11 @@ pub enum GpuResponse { ErrInvalidResourceId, ErrInvalidContextId, ErrInvalidParameter, + ErrUdmabuf(UdmabufError), } -impl From<MsgError> for GpuResponse { - fn from(e: MsgError) -> GpuResponse { +impl From<TubeError> for GpuResponse { + fn from(e: TubeError) -> GpuResponse { GpuResponse::ErrMsg(e) } } @@ -839,6 +845,12 @@ impl From<ExternalMappingError> for GpuResponse { } } +impl From<UdmabufError> for GpuResponse { + fn from(e: UdmabufError) -> GpuResponse { + GpuResponse::ErrUdmabuf(e) + } +} + impl Display for GpuResponse { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { use self::GpuResponse::*; @@ -848,6 +860,7 @@ impl Display for GpuResponse { ErrRutabaga(e) => write!(f, "renderer error: {}", e), ErrDisplay(e) => write!(f, "display error: {}", e), ErrScanout { num_scanouts } => write!(f, "non-zero scanout: {}", num_scanouts), + ErrUdmabuf(e) => write!(f, "udmabuf error: {}", e), _ => Ok(()), } } @@ -1024,6 +1037,7 @@ impl GpuResponse { GpuResponse::ErrRutabaga(_) => VIRTIO_GPU_RESP_ERR_UNSPEC, GpuResponse::ErrDisplay(_) => VIRTIO_GPU_RESP_ERR_UNSPEC, GpuResponse::ErrMapping(_) => VIRTIO_GPU_RESP_ERR_UNSPEC, + GpuResponse::ErrUdmabuf(_) => VIRTIO_GPU_RESP_ERR_UNSPEC, GpuResponse::ErrScanout { num_scanouts: _ } => VIRTIO_GPU_RESP_ERR_UNSPEC, GpuResponse::ErrOutOfMemory => VIRTIO_GPU_RESP_ERR_OUT_OF_MEMORY, GpuResponse::ErrInvalidScanoutId => VIRTIO_GPU_RESP_ERR_INVALID_SCANOUT_ID, diff --git a/devices/src/virtio/gpu/udmabuf.rs b/devices/src/virtio/gpu/udmabuf.rs new file mode 100644 index 000000000..812080bcc --- /dev/null +++ b/devices/src/virtio/gpu/udmabuf.rs @@ -0,0 +1,266 
@@ +// Copyright 2021 The Chromium OS Authors. All rights reservsize. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#![allow(dead_code)] + +use std::fs::{File, OpenOptions}; +use std::os::raw::c_uint; + +use std::path::Path; +use std::{fmt, io}; + +use base::{ + ioctl_iow_nr, ioctl_with_ptr, pagesize, AsRawDescriptor, FromRawDescriptor, MappedRegion, + SafeDescriptor, +}; + +use data_model::{FlexibleArray, FlexibleArrayWrapper}; + +use rutabaga_gfx::{RutabagaHandle, RUTABAGA_MEM_HANDLE_TYPE_DMABUF}; + +use super::udmabuf_bindings::*; + +use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError}; + +const UDMABUF_IOCTL_BASE: c_uint = 0x75; + +ioctl_iow_nr!(UDMABUF_CREATE, UDMABUF_IOCTL_BASE, 0x42, udmabuf_create); +ioctl_iow_nr!( + UDMABUF_CREATE_LIST, + UDMABUF_IOCTL_BASE, + 0x43, + udmabuf_create_list +); + +// It's possible to make the flexible array trait implementation a macro one day... +impl FlexibleArray<udmabuf_create_item> for udmabuf_create_list { + fn set_len(&mut self, len: usize) { + self.count = len as u32; + } + + fn get_len(&self) -> usize { + self.count as usize + } + + fn get_slice(&self, len: usize) -> &[udmabuf_create_item] { + unsafe { self.list.as_slice(len) } + } + + fn get_mut_slice(&mut self, len: usize) -> &mut [udmabuf_create_item] { + unsafe { self.list.as_mut_slice(len) } + } +} + +type UdmabufCreateList = FlexibleArrayWrapper<udmabuf_create_list, udmabuf_create_item>; + +#[derive(Debug)] +pub enum UdmabufError { + DriverOpenFailed(io::Error), + NotPageAligned, + InvalidOffset(GuestMemoryError), + DmabufCreationFail(io::Error), +} + +impl fmt::Display for UdmabufError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::UdmabufError::*; + match self { + DriverOpenFailed(e) => write!(f, "failed to open udmabuf driver: {:?}", e), + NotPageAligned => write!(f, "All guest addresses must aligned to 4KiB"), + InvalidOffset(e) => write!(f, "failed to get region 
offset: {:?}", e), + DmabufCreationFail(e) => write!(f, "failed to create buffer: {:?}", e), + } + } +} + +/// The result of an operation in this file. +pub type UdmabufResult<T> = std::result::Result<T, UdmabufError>; + +// Returns absolute offset within the memory corresponding to a particular guest address. +// This offset is not relative to a particular mapping. + +// # Examples +// +// # fn test_memory_offsets() { +// # let start_addr1 = GuestAddress(0x100) +// # let start_addr2 = GuestAddress(0x1100); +// # let mem = GuestMemory::new(&vec![(start_addr1, 0x1000),(start_addr2, 0x1000)])?; +// # assert_eq!(memory_offset(&mem, GuestAddress(0x1100), 0x1000).unwrap(),0x1000); +// #} +fn memory_offset(mem: &GuestMemory, guest_addr: GuestAddress, len: u64) -> UdmabufResult<u64> { + mem.do_in_region(guest_addr, move |mapping, map_offset, memfd_offset| { + let map_offset = map_offset as u64; + if map_offset + .checked_add(len) + .map_or(true, |a| a > mapping.size() as u64) + { + return Err(GuestMemoryError::InvalidGuestAddress(guest_addr)); + } + + return Ok(memfd_offset + map_offset); + }) + .map_err(UdmabufError::InvalidOffset) +} + +/// A convenience wrapper for the Linux kernel's udmabuf driver. +/// +/// udmabuf is a kernel driver that turns memfd pages into dmabufs. It can be used for +/// zero-copy buffer sharing between the guest and host when guest memory is backed by +/// memfd pages. +pub struct UdmabufDriver { + driver_fd: File, +} + +impl UdmabufDriver { + /// Opens the udmabuf device on success. + pub fn new() -> UdmabufResult<UdmabufDriver> { + const UDMABUF_PATH: &str = "/dev/udmabuf"; + let path = Path::new(UDMABUF_PATH); + let fd = OpenOptions::new() + .read(true) + .write(true) + .open(path) + .map_err(UdmabufError::DriverOpenFailed)?; + + Ok(UdmabufDriver { driver_fd: fd }) + } + + /// Creates a dma-buf fd for the given scatter-gather list of guest memory pages (`iovecs`). 
+ pub fn create_udmabuf( + &self, + mem: &GuestMemory, + iovecs: &[(GuestAddress, usize)], + ) -> UdmabufResult<RutabagaHandle> { + let pgsize = pagesize(); + + let mut list = UdmabufCreateList::new(iovecs.len() as usize); + let mut items = list.mut_entries_slice(); + for (i, &(addr, len)) in iovecs.iter().enumerate() { + let offset = memory_offset(mem, addr, len as u64)?; + + if offset as usize % pgsize != 0 || len % pgsize != 0 { + return Err(UdmabufError::NotPageAligned); + } + + // `unwrap` can't panic if `memory_offset obove succeeds. + items[i].memfd = mem.shm_region(addr).unwrap().as_raw_descriptor() as u32; + items[i].__pad = 0; + items[i].offset = offset; + items[i].size = len as u64; + } + + // Safe because we always allocate enough space for `udmabuf_create_list`. + let fd = unsafe { + let create_list = list.as_mut_ptr(); + (*create_list).flags = UDMABUF_FLAGS_CLOEXEC; + ioctl_with_ptr(&self.driver_fd, UDMABUF_CREATE_LIST(), create_list) + }; + + if fd < 0 { + return Err(UdmabufError::DmabufCreationFail(io::Error::last_os_error())); + } + + // Safe because we validated the file exists. 
+ let os_handle = unsafe { SafeDescriptor::from_raw_descriptor(fd) }; + Ok(RutabagaHandle { + os_handle, + handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use base::kernel_has_memfd; + use vm_memory::GuestAddress; + + #[test] + fn test_memory_offsets() { + if !kernel_has_memfd() { + return; + } + + let start_addr1 = GuestAddress(0x100); + let start_addr2 = GuestAddress(0x1100); + let start_addr3 = GuestAddress(0x2100); + + let mem = GuestMemory::new(&vec![ + (start_addr1, 0x1000), + (start_addr2, 0x1000), + (start_addr3, 0x1000), + ]) + .unwrap(); + + assert_eq!(memory_offset(&mem, GuestAddress(0x300), 1).unwrap(), 0x200); + assert_eq!( + memory_offset(&mem, GuestAddress(0x1200), 1).unwrap(), + 0x1100 + ); + assert_eq!( + memory_offset(&mem, GuestAddress(0x1100), 0x1000).unwrap(), + 0x1000 + ); + assert!(memory_offset(&mem, GuestAddress(0x1100), 0x1001).is_err()); + } + + #[test] + fn test_udmabuf_create() { + if !kernel_has_memfd() { + return; + } + + let driver_result = UdmabufDriver::new(); + + // Most kernels will not have udmabuf support. 
+ if driver_result.is_err() { + return; + } + + let driver = driver_result.unwrap(); + + let start_addr1 = GuestAddress(0x100); + let start_addr2 = GuestAddress(0x1100); + let start_addr3 = GuestAddress(0x2100); + + let sg_list = vec![ + (start_addr1, 0x1000), + (start_addr2, 0x1000), + (start_addr3, 0x1000), + ]; + + let mem = GuestMemory::new(&sg_list[..]).unwrap(); + + let mut udmabuf_create_list = vec![ + (start_addr3, 0x1000 as usize), + (start_addr2, 0x1000 as usize), + (start_addr1, 0x1000 as usize), + (GuestAddress(0x4000), 0x1000 as usize), + ]; + + let result = driver.create_udmabuf(&mem, &udmabuf_create_list[..]); + assert_eq!(result.is_err(), true); + + udmabuf_create_list.pop(); + + let rutabaga_handle1 = driver + .create_udmabuf(&mem, &udmabuf_create_list[..]) + .unwrap(); + assert_eq!( + rutabaga_handle1.handle_type, + RUTABAGA_MEM_HANDLE_TYPE_DMABUF + ); + + udmabuf_create_list.pop(); + + // Multiple udmabufs with same memory backing is allowed. + let rutabaga_handle2 = driver + .create_udmabuf(&mem, &udmabuf_create_list[..]) + .unwrap(); + assert_eq!( + rutabaga_handle2.handle_type, + RUTABAGA_MEM_HANDLE_TYPE_DMABUF + ); + } +} diff --git a/devices/src/virtio/gpu/udmabuf_bindings.rs b/devices/src/virtio/gpu/udmabuf_bindings.rs new file mode 100644 index 000000000..dcd89a6f5 --- /dev/null +++ b/devices/src/virtio/gpu/udmabuf_bindings.rs @@ -0,0 +1,74 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/* automatically generated by rust-bindgen, though the exact command remains + * lost to history. Should be easy to duplicate, if needed. 
+ */ + +#![allow(dead_code)] +#![allow(non_camel_case_types)] + +#[repr(C)] +#[derive(Default)] +pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]); +impl<T> __IncompleteArrayField<T> { + #[inline] + pub fn new() -> Self { + __IncompleteArrayField(::std::marker::PhantomData, []) + } + #[inline] + pub unsafe fn as_ptr(&self) -> *const T { + ::std::mem::transmute(self) + } + #[inline] + pub unsafe fn as_mut_ptr(&mut self) -> *mut T { + ::std::mem::transmute(self) + } + #[inline] + pub unsafe fn as_slice(&self, len: usize) -> &[T] { + ::std::slice::from_raw_parts(self.as_ptr(), len) + } + #[inline] + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { + ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) + } +} +impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> { + fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + fmt.write_str("__IncompleteArrayField") + } +} +impl<T> ::std::clone::Clone for __IncompleteArrayField<T> { + #[inline] + fn clone(&self) -> Self { + Self::new() + } +} +pub const UDMABUF_FLAGS_CLOEXEC: u32 = 1; +pub type __u32 = ::std::os::raw::c_uint; +pub type __u64 = ::std::os::raw::c_ulonglong; +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct udmabuf_create { + pub memfd: __u32, + pub flags: __u32, + pub offset: __u64, + pub size: __u64, +} +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct udmabuf_create_item { + pub memfd: __u32, + pub __pad: __u32, + pub offset: __u64, + pub size: __u64, +} +#[repr(C)] +#[repr(align(8))] +#[derive(Debug, Default)] +pub struct udmabuf_create_list { + pub flags: __u32, + pub count: __u32, + pub list: __IncompleteArrayField<udmabuf_create_item>, +} diff --git a/devices/src/virtio/gpu/virtio_gpu.rs b/devices/src/virtio/gpu/virtio_gpu.rs index e1e29064d..00086903a 100644 --- a/devices/src/virtio/gpu/virtio_gpu.rs +++ b/devices/src/virtio/gpu/virtio_gpu.rs @@ -10,7 +10,7 @@ use std::result::Result; use std::sync::Arc; use 
crate::virtio::resource_bridge::{BufferInfo, PlaneInfo, ResourceInfo, ResourceResponse}; -use base::{error, AsRawDescriptor, ExternalMapping}; +use base::{error, AsRawDescriptor, ExternalMapping, Tube}; use data_model::VolatileSlice; @@ -20,21 +20,22 @@ use rutabaga_gfx::{ RutabagaIovec, Transfer3D, }; -use msg_socket::{MsgReceiver, MsgSender}; - use libc::c_void; use resources::Alloc; -use super::protocol::{GpuResponse::*, GpuResponsePlaneInfo, VirtioGpuResult}; +use super::protocol::{ + GpuResponse::{self, *}, + GpuResponsePlaneInfo, VirtioGpuResult, VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE, + VIRTIO_GPU_BLOB_MEM_HOST3D, +}; +use super::udmabuf::UdmabufDriver; use super::VirtioScanoutBlobData; use sync::Mutex; use vm_memory::{GuestAddress, GuestMemory}; -use vm_control::{ - MaybeOwnedDescriptor, MemSlot, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse, -}; +use vm_control::{MemSlot, VmMemoryRequest, VmMemoryResponse}; struct VirtioGpuResource { resource_id: u32, @@ -78,12 +79,13 @@ pub struct VirtioGpu { cursor_surface_id: Option<u32>, // Maps event devices to scanout number. 
event_devices: Map<u32, u32>, - gpu_device_socket: VmMemoryControlRequestSocket, + gpu_device_tube: Tube, pci_bar: Alloc, map_request: Arc<Mutex<Option<ExternalMapping>>>, rutabaga: Rutabaga, resources: Map<u32, VirtioGpuResource>, external_blob: bool, + udmabuf_driver: Option<UdmabufDriver>, } fn sglist_to_rutabaga_iovecs( @@ -116,15 +118,26 @@ impl VirtioGpu { display_height: u32, rutabaga_builder: RutabagaBuilder, event_devices: Vec<EventDevice>, - gpu_device_socket: VmMemoryControlRequestSocket, + gpu_device_tube: Tube, pci_bar: Alloc, map_request: Arc<Mutex<Option<ExternalMapping>>>, external_blob: bool, + udmabuf: bool, ) -> Option<VirtioGpu> { let rutabaga = rutabaga_builder .build() .map_err(|e| error!("failed to build rutabaga {}", e)) .ok()?; + + let mut udmabuf_driver = None; + if udmabuf { + udmabuf_driver = Some( + UdmabufDriver::new() + .map_err(|e| error!("failed to initialize udmabuf: {}", e)) + .ok()?, + ); + } + let mut virtio_gpu = VirtioGpu { display: Rc::new(RefCell::new(display)), display_width, @@ -134,12 +147,13 @@ impl VirtioGpu { scanout_surface_id: None, cursor_resource_id: None, cursor_surface_id: None, - gpu_device_socket, + gpu_device_tube, pci_bar, map_request, rutabaga, resources: Default::default(), external_blob, + udmabuf_driver, }; for event_device in event_devices { @@ -424,7 +438,7 @@ impl VirtioGpu { /// If supported, export the resource with the given `resource_id` to a file. 
pub fn export_resource(&mut self, resource_id: u32) -> ResourceResponse { let file = match self.rutabaga.export_blob(resource_id) { - Ok(handle) => handle.os_handle, + Ok(handle) => handle.os_handle.into(), Err(_) => return ResourceResponse::Invalid, }; @@ -453,6 +467,7 @@ impl VirtioGpu { stride: q.strides[3], }, ], + modifier: q.modifier, })) } @@ -460,7 +475,7 @@ impl VirtioGpu { pub fn export_fence(&self, fence_id: u32) -> ResourceResponse { match self.rutabaga.export_fence(fence_id) { Ok(handle) => ResourceResponse::Resource(ResourceInfo::Fence { - file: handle.os_handle, + file: handle.os_handle.into(), }), Err(_) => ResourceResponse::Invalid, } @@ -517,7 +532,7 @@ impl VirtioGpu { // Rely on rutabaga to check for duplicate resource ids. self.resources.insert(resource_id, resource); - self.result_from_query(resource_id) + Ok(self.result_from_query(resource_id)) } /// Attaches backing memory to the given resource, represented by a `Vec` of `(address, size)` @@ -588,19 +603,32 @@ impl VirtioGpu { vecs: Vec<(GuestAddress, usize)>, mem: &GuestMemory, ) -> VirtioGpuResult { - let rutabaga_iovecs = sglist_to_rutabaga_iovecs(&vecs[..], mem).map_err(|_| ErrUnspec)?; + let mut rutabaga_handle = None; + let mut rutabaga_iovecs = None; + + if resource_create_blob.blob_flags & VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE != 0 { + rutabaga_handle = match self.udmabuf_driver { + Some(ref driver) => Some(driver.create_udmabuf(mem, &vecs[..])?), + None => return Err(ErrUnspec), + } + } else if resource_create_blob.blob_mem != VIRTIO_GPU_BLOB_MEM_HOST3D { + rutabaga_iovecs = + Some(sglist_to_rutabaga_iovecs(&vecs[..], mem).map_err(|_| ErrUnspec)?); + } + self.rutabaga.resource_create_blob( ctx_id, resource_id, resource_create_blob, rutabaga_iovecs, + rutabaga_handle, )?; let resource = VirtioGpuResource::new(resource_id, 0, 0, resource_create_blob.size); // Rely on rutabaga to check for duplicate resource ids. 
self.resources.insert(resource_id, resource); - self.result_from_query(resource_id) + Ok(self.result_from_query(resource_id)) } /// Uses the hypervisor to map the rutabaga blob resource. @@ -611,15 +639,28 @@ impl VirtioGpu { .ok_or(ErrInvalidResourceId)?; let map_info = self.rutabaga.map_info(resource_id).map_err(|_| ErrUnspec)?; + let vulkan_info_opt = self.rutabaga.vulkan_info(resource_id).ok(); + let export = self.rutabaga.export_blob(resource_id); let request = match export { - Ok(ref export) => VmMemoryRequest::RegisterFdAtPciBarOffset( - self.pci_bar, - MaybeOwnedDescriptor::Borrowed(export.os_handle.as_raw_descriptor()), - resource.size as usize, - offset, - ), + Ok(export) => match vulkan_info_opt { + Some(vulkan_info) => VmMemoryRequest::RegisterVulkanMemoryAtPciBarOffset { + alloc: self.pci_bar, + descriptor: export.os_handle, + handle_type: export.handle_type, + memory_idx: vulkan_info.memory_idx, + physical_device_idx: vulkan_info.physical_device_idx, + offset, + size: resource.size, + }, + None => VmMemoryRequest::RegisterFdAtPciBarOffset( + self.pci_bar, + export.os_handle, + resource.size as usize, + offset, + ), + }, Err(_) => { if self.external_blob { return Err(ErrUnspec); @@ -638,8 +679,8 @@ impl VirtioGpu { } }; - self.gpu_device_socket.send(&request)?; - let response = self.gpu_device_socket.recv()?; + self.gpu_device_tube.send(&request)?; + let response = self.gpu_device_tube.recv()?; match response { VmMemoryResponse::RegisterMemory { pfn: _, slot } => { @@ -660,8 +701,8 @@ impl VirtioGpu { let slot = resource.slot.ok_or(ErrUnspec)?; let request = VmMemoryRequest::UnregisterMemory(slot); - self.gpu_device_socket.send(&request)?; - let response = self.gpu_device_socket.recv()?; + self.gpu_device_tube.send(&request)?; + let response = self.gpu_device_tube.recv()?; match response { VmMemoryResponse::Ok => { @@ -704,7 +745,7 @@ impl VirtioGpu { } // Non-public function -- no doc comment needed! 
- fn result_from_query(&mut self, resource_id: u32) -> VirtioGpuResult { + fn result_from_query(&mut self, resource_id: u32) -> GpuResponse { match self.rutabaga.query(resource_id) { Ok(query) => { let mut plane_info = Vec::with_capacity(4); @@ -715,12 +756,12 @@ impl VirtioGpu { }); } let format_modifier = query.modifier; - Ok(OkResourcePlaneInfo { + OkResourcePlaneInfo { format_modifier, plane_info, - }) + } } - Err(_) => Ok(OkNoData), + Err(_) => OkNoData, } } } diff --git a/devices/src/virtio/input/event_source.rs b/devices/src/virtio/input/event_source.rs index 650b19398..e17812298 100644 --- a/devices/src/virtio/input/event_source.rs +++ b/devices/src/virtio/input/event_source.rs @@ -242,7 +242,7 @@ mod tests { use std::cmp::min; use std::io::{Read, Write}; - use data_model::{DataInit, Le16, Le32}; + use data_model::{DataInit, Le16, SLe32}; use linux_input_sys::InputEventDecoder; use crate::virtio::input::event_source::{input_event, virtio_input_event, EventSourceImpl}; @@ -317,7 +317,11 @@ mod tests { timestamp_fields: [0, 0], type_: 3 * (idx as u16) + 1, code: 3 * (idx as u16) + 2, - value: 3 * (idx as u32) + 3, + value: if idx % 2 == 0 { + 3 * (idx as i32) + 3 + } else { + -3 * (idx as i32) - 3 + }, }); } ret @@ -326,7 +330,7 @@ mod tests { fn assert_events_match(e1: &virtio_input_event, e2: &input_event) { assert_eq!(e1.type_, Le16::from(e2.type_), "type should match"); assert_eq!(e1.code, Le16::from(e2.code), "code should match"); - assert_eq!(e1.value, Le32::from(e2.value), "value should match"); + assert_eq!(e1.value, SLe32::from(e2.value), "value should match"); } #[test] diff --git a/devices/src/virtio/input/mod.rs b/devices/src/virtio/input/mod.rs index 3848431cf..30cd724b2 100644 --- a/devices/src/virtio/input/mod.rs +++ b/devices/src/virtio/input/mod.rs @@ -16,8 +16,8 @@ use vm_memory::GuestMemory; use self::event_source::{EvdevEventSource, EventSource, SocketEventSource}; use super::{ - copy_config, DescriptorChain, DescriptorError, Interrupt, 
Queue, Reader, VirtioDevice, Writer, - TYPE_INPUT, + copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, + VirtioDevice, Writer, TYPE_INPUT, }; use linux_input_sys::{virtio_input_event, InputEventDecoder}; use std::collections::BTreeMap; @@ -433,7 +433,7 @@ impl<T: EventSource> Worker<T> { Ok(count) => count, Err(e) => { error!("Input: failed to read events from virtqueue: {}", e); - break; + return Err(e); } }; @@ -463,7 +463,6 @@ impl<T: EventSource> Worker<T> { (&event_queue_evt, Token::EventQAvailable), (&status_queue_evt, Token::StatusQAvailable), (&self.event_source, Token::InputEventsAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&kill_evt, Token::Kill), ]) { Ok(wait_ctx) => wait_ctx, @@ -472,6 +471,15 @@ impl<T: EventSource> Worker<T> { return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed adding resample event to WaitContext."); + return; + } + } 'wait: loop { let wait_events = match wait_ctx.wait() { diff --git a/devices/src/virtio/interrupt.rs b/devices/src/virtio/interrupt.rs index 2ddff08f6..5b52dc505 100644 --- a/devices/src/virtio/interrupt.rs +++ b/devices/src/virtio/interrupt.rs @@ -9,31 +9,35 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use sync::Mutex; +pub trait SignalableInterrupt { + /// Writes to the irqfd to VMM to deliver virtual interrupt to the guest. + fn signal(&self, vector: u16, interrupt_status_mask: u32); + + /// Notify the driver that buffers have been placed in the used queue. + fn signal_used_queue(&self, vector: u16) { + self.signal(vector, INTERRUPT_STATUS_USED_RING) + } + + /// Notify the driver that the device configuration has changed. + fn signal_config_changed(&self); + + /// Get the event to signal resampling is needed if it exists. 
+ fn get_resample_evt(&self) -> Option<&Event>; + + /// Reads the status and writes to the interrupt event. Doesn't read the resample event, it + /// assumes the resample has been requested. + fn do_interrupt_resample(&self); +} + pub struct Interrupt { interrupt_status: Arc<AtomicUsize>, interrupt_evt: Event, interrupt_resample_evt: Event, - pub msix_config: Option<Arc<Mutex<MsixConfig>>>, + msix_config: Option<Arc<Mutex<MsixConfig>>>, config_msix_vector: u16, } -impl Interrupt { - pub fn new( - interrupt_status: Arc<AtomicUsize>, - interrupt_evt: Event, - interrupt_resample_evt: Event, - msix_config: Option<Arc<Mutex<MsixConfig>>>, - config_msix_vector: u16, - ) -> Interrupt { - Interrupt { - interrupt_status, - interrupt_evt, - interrupt_resample_evt, - msix_config, - config_msix_vector, - } - } - +impl SignalableInterrupt for Interrupt { /// Virtqueue Interrupts From The Device /// /// If MSI-X is enabled in this device, MSI-X interrupt is preferred. @@ -62,33 +66,46 @@ impl Interrupt { } } - /// Notify the driver that buffers have been placed in the used queue. - pub fn signal_used_queue(&self, vector: u16) { - self.signal(vector, INTERRUPT_STATUS_USED_RING) - } - - /// Notify the driver that the device configuration has changed. - pub fn signal_config_changed(&self) { + fn signal_config_changed(&self) { self.signal(self.config_msix_vector, INTERRUPT_STATUS_CONFIG_CHANGED) } - /// Handle interrupt resampling event, reading the value from the event and doing the resample. - pub fn interrupt_resample(&self) { - let _ = self.interrupt_resample_evt.read(); - self.do_interrupt_resample(); + fn get_resample_evt(&self) -> Option<&Event> { + Some(&self.interrupt_resample_evt) } - /// Read the status and write to the interrupt event. Don't read the resample event, assume the - /// resample has been requested. 
- pub fn do_interrupt_resample(&self) { + fn do_interrupt_resample(&self) { if self.interrupt_status.load(Ordering::SeqCst) != 0 { self.interrupt_evt.write(1).unwrap(); } } +} + +impl Interrupt { + pub fn new( + interrupt_status: Arc<AtomicUsize>, + interrupt_evt: Event, + interrupt_resample_evt: Event, + msix_config: Option<Arc<Mutex<MsixConfig>>>, + config_msix_vector: u16, + ) -> Interrupt { + Interrupt { + interrupt_status, + interrupt_evt, + interrupt_resample_evt, + msix_config, + config_msix_vector, + } + } + + /// Handle interrupt resampling event, reading the value from the event and doing the resample. + pub fn interrupt_resample(&self) { + let _ = self.interrupt_resample_evt.read(); + self.do_interrupt_resample(); + } - /// Return the reference of interrupt_resample_evt - /// To keep the interface clean, this member is private. - pub fn get_resample_evt(&self) -> &Event { - &self.interrupt_resample_evt + /// Get a reference to the msix configuration + pub fn get_msix_config(&self) -> &Option<Arc<Mutex<MsixConfig>>> { + &self.msix_config } } diff --git a/devices/src/virtio/mod.rs b/devices/src/virtio/mod.rs index a859e13bb..26966815b 100644 --- a/devices/src/virtio/mod.rs +++ b/devices/src/virtio/mod.rs @@ -92,7 +92,7 @@ const MAX_VIRTIO_DEVICE_ID: u32 = 63; const TYPE_WL: u32 = MAX_VIRTIO_DEVICE_ID; const TYPE_TPM: u32 = MAX_VIRTIO_DEVICE_ID - 1; -const VIRTIO_F_VERSION_1: u32 = 32; +pub const VIRTIO_F_VERSION_1: u32 = 32; const VIRTIO_F_ACCESS_PLATFORM: u32 = 33; const INTERRUPT_STATUS_USED_RING: u32 = 0x1; diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs index 1c1c225cc..d1f35dca5 100644 --- a/devices/src/virtio/net.rs +++ b/devices/src/virtio/net.rs @@ -23,7 +23,8 @@ use virtio_sys::virtio_net::{ use vm_memory::GuestMemory; use super::{ - copy_config, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_NET, + copy_config, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, + Writer, 
TYPE_NET, }; const QUEUE_SIZE: u16 = 256; @@ -134,7 +135,7 @@ fn virtio_features_to_tap_offload(features: u64) -> c_uint { #[derive(Debug, Clone, Copy, Default)] #[repr(C)] -struct VirtioNetConfig { +pub(crate) struct VirtioNetConfig { mac: [u8; 6], status: Le16, max_vq_pairs: Le16, @@ -351,9 +352,11 @@ where .add(ctrl_evt, Token::CtrlQueue) .map_err(NetError::CreateWaitContext)?; // Let CtrlQueue's thread handle InterruptResample also. - wait_ctx - .add(self.interrupt.get_resample_evt(), Token::InterruptResample) - .map_err(NetError::CreateWaitContext)?; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wait_ctx + .add(resample_evt, Token::InterruptResample) + .map_err(NetError::CreateWaitContext)?; + } } let mut tap_polling_enabled = true; diff --git a/devices/src/virtio/p9.rs b/devices/src/virtio/p9.rs index a082afa06..670645d88 100644 --- a/devices/src/virtio/p9.rs +++ b/devices/src/virtio/p9.rs @@ -12,7 +12,8 @@ use base::{error, warn, Error as SysError, Event, PollToken, RawDescriptor, Wait use vm_memory::GuestMemory; use super::{ - copy_config, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_9P, + copy_config, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, + Writer, TYPE_9P, }; const QUEUE_SIZE: u16 = 128; @@ -115,12 +116,14 @@ impl Worker { Kill, } - let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[ - (&queue_evt, Token::QueueReady), - (self.interrupt.get_resample_evt(), Token::InterruptResample), - (&kill_evt, Token::Kill), - ]) - .map_err(P9Error::CreateWaitContext)?; + let wait_ctx: WaitContext<Token> = + WaitContext::build_with(&[(&queue_evt, Token::QueueReady), (&kill_evt, Token::Kill)]) + .map_err(P9Error::CreateWaitContext)?; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wait_ctx + .add(resample_evt, Token::InterruptResample) + .map_err(P9Error::CreateWaitContext)?; + } loop { let events = wait_ctx.wait().map_err(P9Error::WaitError)?; diff --git 
a/devices/src/virtio/pmem.rs b/devices/src/virtio/pmem.rs index 6412268d7..9a3ca39cb 100644 --- a/devices/src/virtio/pmem.rs +++ b/devices/src/virtio/pmem.rs @@ -7,19 +7,15 @@ use std::fs::File; use std::io; use std::thread; -use base::{error, AsRawDescriptor, Event, PollToken, RawDescriptor, WaitContext}; +use base::{error, AsRawDescriptor, Event, PollToken, RawDescriptor, Tube, WaitContext}; use base::{Error as SysError, Result as SysResult}; -use vm_memory::{GuestAddress, GuestMemory}; - use data_model::{DataInit, Le32, Le64}; - -use msg_socket::{MsgReceiver, MsgSender}; - -use vm_control::{MemSlot, VmMsyncRequest, VmMsyncRequestSocket, VmMsyncResponse}; +use vm_control::{MemSlot, VmMsyncRequest, VmMsyncResponse}; +use vm_memory::{GuestAddress, GuestMemory}; use super::{ - copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, - TYPE_PMEM, + copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, + VirtioDevice, Writer, TYPE_PMEM, }; const QUEUE_SIZE: u16 = 256; @@ -87,7 +83,7 @@ struct Worker { interrupt: Interrupt, queue: Queue, memory: GuestMemory, - pmem_device_socket: VmMsyncRequestSocket, + pmem_device_tube: Tube, mapping_arena_slot: MemSlot, mapping_size: usize, } @@ -102,12 +98,12 @@ impl Worker { size: self.mapping_size, }; - if let Err(e) = self.pmem_device_socket.send(&request) { + if let Err(e) = self.pmem_device_tube.send(&request) { error!("failed to send request: {}", e); return VIRTIO_PMEM_RESP_TYPE_EIO; } - match self.pmem_device_socket.recv() { + match self.pmem_device_tube.recv() { Ok(response) => match response { VmMsyncResponse::Ok => VIRTIO_PMEM_RESP_TYPE_OK, VmMsyncResponse::Err(e) => { @@ -177,7 +173,6 @@ impl Worker { let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[ (&queue_evt, Token::QueueAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&kill_evt, Token::Kill), ]) { Ok(pc) => pc, @@ -186,6 +181,15 @@ impl Worker 
{ return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed adding resample event to WaitContext."); + return; + } + } 'wait: loop { let events = match wait_ctx.wait() { @@ -227,7 +231,7 @@ pub struct Pmem { mapping_address: GuestAddress, mapping_arena_slot: MemSlot, mapping_size: u64, - pmem_device_socket: Option<VmMsyncRequestSocket>, + pmem_device_tube: Option<Tube>, } impl Pmem { @@ -237,7 +241,7 @@ impl Pmem { mapping_address: GuestAddress, mapping_arena_slot: MemSlot, mapping_size: u64, - pmem_device_socket: Option<VmMsyncRequestSocket>, + pmem_device_tube: Option<Tube>, ) -> SysResult<Pmem> { if mapping_size > usize::max_value() as u64 { return Err(SysError::new(libc::EOVERFLOW)); @@ -251,7 +255,7 @@ impl Pmem { mapping_address, mapping_arena_slot, mapping_size, - pmem_device_socket, + pmem_device_tube, }) } } @@ -276,8 +280,8 @@ impl VirtioDevice for Pmem { keep_rds.push(disk_image.as_raw_descriptor()); } - if let Some(ref pmem_device_socket) = self.pmem_device_socket { - keep_rds.push(pmem_device_socket.as_raw_descriptor()); + if let Some(ref pmem_device_tube) = self.pmem_device_tube { + keep_rds.push(pmem_device_tube.as_raw_descriptor()); } keep_rds } @@ -320,7 +324,7 @@ impl VirtioDevice for Pmem { // We checked that this fits in a usize in `Pmem::new`. 
let mapping_size = self.mapping_size as usize; - if let Some(pmem_device_socket) = self.pmem_device_socket.take() { + if let Some(pmem_device_tube) = self.pmem_device_tube.take() { let (self_kill_event, kill_event) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) { Ok(v) => v, @@ -338,7 +342,7 @@ impl VirtioDevice for Pmem { interrupt, memory, queue, - pmem_device_socket, + pmem_device_tube, mapping_arena_slot, mapping_size, }; diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs index 3f48ad64a..537abaeee 100644 --- a/devices/src/virtio/queue.rs +++ b/devices/src/virtio/queue.rs @@ -11,7 +11,7 @@ use cros_async::{AsyncError, EventAsync}; use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX; use vm_memory::{GuestAddress, GuestMemory}; -use super::{Interrupt, VIRTIO_MSI_NO_VECTOR}; +use super::{SignalableInterrupt, VIRTIO_MSI_NO_VECTOR}; const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; @@ -536,7 +536,11 @@ impl Queue { /// inject interrupt into guest on this queue /// return true: interrupt is injected into guest for this queue /// false: interrupt isn't injected - pub fn trigger_interrupt(&mut self, mem: &GuestMemory, interrupt: &Interrupt) -> bool { + pub fn trigger_interrupt( + &mut self, + mem: &GuestMemory, + interrupt: &dyn SignalableInterrupt, + ) -> bool { if self.available_interrupt_enabled(mem) { self.last_used = self.next_used; interrupt.signal_used_queue(self.vector); @@ -554,6 +558,7 @@ impl Queue { #[cfg(test)] mod tests { + use super::super::Interrupt; use super::*; use base::Event; use data_model::{DataInit, Le16, Le32, Le64}; diff --git a/devices/src/virtio/resource_bridge.rs b/devices/src/virtio/resource_bridge.rs index ec2c2b590..a89843764 100644 --- a/devices/src/virtio/resource_bridge.rs +++ b/devices/src/virtio/resource_bridge.rs @@ -8,53 +8,51 @@ use std::fmt; use std::fs::File; -use base::RawDescriptor; -use msg_on_socket_derive::MsgOnSocket; -use msg_socket::{MsgError, MsgReceiver, 
MsgSender, MsgSocket}; +use serde::{Deserialize, Serialize}; -#[derive(MsgOnSocket, Debug)] +use base::{with_as_descriptor, Tube, TubeError}; + +#[derive(Debug, Serialize, Deserialize)] pub enum ResourceRequest { GetBuffer { id: u32 }, GetFence { seqno: u64 }, } -#[derive(MsgOnSocket, Clone, Copy, Default)] +#[derive(Serialize, Deserialize, Clone, Copy, Default)] pub struct PlaneInfo { pub offset: u32, pub stride: u32, } -#[derive(MsgOnSocket)] +#[derive(Serialize, Deserialize)] pub struct BufferInfo { + #[serde(with = "with_as_descriptor")] pub file: File, pub planes: [PlaneInfo; RESOURE_PLANE_NUM], + pub modifier: u64, } pub const RESOURE_PLANE_NUM: usize = 4; -#[derive(MsgOnSocket)] +#[derive(Serialize, Deserialize)] pub enum ResourceInfo { Buffer(BufferInfo), - Fence { file: File }, + Fence { + #[serde(with = "with_as_descriptor")] + file: File, + }, } -#[derive(MsgOnSocket)] +#[derive(Serialize, Deserialize)] pub enum ResourceResponse { Resource(ResourceInfo), Invalid, } -pub type ResourceRequestSocket = MsgSocket<ResourceRequest, ResourceResponse>; -pub type ResourceResponseSocket = MsgSocket<ResourceResponse, ResourceRequest>; - -pub fn pair() -> std::io::Result<(ResourceRequestSocket, ResourceResponseSocket)> { - msg_socket::pair() -} - #[derive(Debug)] pub enum ResourceBridgeError { InvalidResource(ResourceRequest), - SendFailure(ResourceRequest, MsgError), - RecieveFailure(ResourceRequest, MsgError), + SendFailure(ResourceRequest, TubeError), + RecieveFailure(ResourceRequest, TubeError), } impl fmt::Display for ResourceRequest { @@ -89,14 +87,14 @@ impl fmt::Display for ResourceBridgeError { impl std::error::Error for ResourceBridgeError {} pub fn get_resource_info( - sock: &ResourceRequestSocket, + tube: &Tube, request: ResourceRequest, ) -> std::result::Result<ResourceInfo, ResourceBridgeError> { - if let Err(e) = sock.send(&request) { + if let Err(e) = tube.send(&request) { return Err(ResourceBridgeError::SendFailure(request, e)); } - match sock.recv() 
{ + match tube.recv() { Ok(ResourceResponse::Resource(info)) => Ok(info), Ok(ResourceResponse::Invalid) => Err(ResourceBridgeError::InvalidResource(request)), Err(e) => Err(ResourceBridgeError::RecieveFailure(request, e)), diff --git a/devices/src/virtio/rng.rs b/devices/src/virtio/rng.rs index 466ffd247..65d671e6a 100644 --- a/devices/src/virtio/rng.rs +++ b/devices/src/virtio/rng.rs @@ -10,7 +10,7 @@ use std::thread; use base::{error, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, WaitContext}; use vm_memory::GuestMemory; -use super::{Interrupt, Queue, VirtioDevice, Writer, TYPE_RNG}; +use super::{Interrupt, Queue, SignalableInterrupt, VirtioDevice, Writer, TYPE_RNG}; const QUEUE_SIZE: u16 = 256; const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE]; @@ -75,7 +75,6 @@ impl Worker { let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[ (&queue_evt, Token::QueueAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&kill_evt, Token::Kill), ]) { Ok(pc) => pc, @@ -84,6 +83,15 @@ impl Worker { return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed adding resample event to WaitContext."); + return; + } + } 'wait: loop { let events = match wait_ctx.wait() { diff --git a/devices/src/virtio/snd/constants.rs b/devices/src/virtio/snd/constants.rs index 38b06c496..28a73a083 100644 --- a/devices/src/virtio/snd/constants.rs +++ b/devices/src/virtio/snd/constants.rs @@ -13,14 +13,17 @@ pub const STREAM_STOP: u32 = 0x0100 + 5; pub const CHANNEL_MAP_INFO: u32 = 0x0200; -pub const VIRTIO_SND_S_OK: u32 = 0x8000; -pub const VIRTIO_SND_S_BAD_MSG: u32 = 0x8001; -pub const VIRTIO_SND_S_NOT_SUPP: u32 = 0x8002; -pub const VIRTIO_SND_S_IO_ERR: u32 = 0x8003; - pub const VIRTIO_SND_D_OUTPUT: u8 = 0; pub const VIRTIO_SND_D_INPUT: u8 = 1; +/* supported PCM stream features */ +pub const VIRTIO_SND_PCM_F_SHMEM_HOST: u8 = 0; +pub const 
VIRTIO_SND_PCM_F_SHMEM_GUEST: u8 = 1; +pub const VIRTIO_SND_PCM_F_MSG_POLLING: u8 = 2; +pub const VIRTIO_SND_PCM_F_EVT_SHMEM_PERIODS: u8 = 3; +pub const VIRTIO_SND_PCM_F_EVT_XRUNS: u8 = 4; + +/* supported PCM sample formats */ pub const VIRTIO_SND_PCM_FMT_IMA_ADPCM: u8 = 0; pub const VIRTIO_SND_PCM_FMT_MU_LAW: u8 = 1; pub const VIRTIO_SND_PCM_FMT_A_LAW: u8 = 2; @@ -61,3 +64,76 @@ pub const VIRTIO_SND_PCM_RATE_96000: u8 = 10; pub const VIRTIO_SND_PCM_RATE_176400: u8 = 11; pub const VIRTIO_SND_PCM_RATE_192000: u8 = 12; pub const VIRTIO_SND_PCM_RATE_384000: u8 = 13; + +// From https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex +/* jack control request types */ +pub const VIRTIO_SND_R_JACK_INFO: u32 = 1; +pub const VIRTIO_SND_R_JACK_REMAP: u32 = 2; + +/* PCM control request types */ +pub const VIRTIO_SND_R_PCM_INFO: u32 = 0x0100; +pub const VIRTIO_SND_R_PCM_SET_PARAMS: u32 = 0x0101; +pub const VIRTIO_SND_R_PCM_PREPARE: u32 = 0x0102; +pub const VIRTIO_SND_R_PCM_RELEASE: u32 = 0x0103; +pub const VIRTIO_SND_R_PCM_START: u32 = 0x0104; +pub const VIRTIO_SND_R_PCM_STOP: u32 = 0x0105; + +/* channel map control request types */ +pub const VIRTIO_SND_R_CHMAP_INFO: u32 = 0x0200; + +/* jack event types */ +pub const VIRTIO_SND_EVT_JACK_CONNECTED: u32 = 0x1000; +pub const VIRTIO_SND_EVT_JACK_DISCONNECTED: u32 = 0x1001; + +/* PCM event types */ +pub const VIRTIO_SND_EVT_PCM_PERIOD_ELAPSED: u32 = 0x1100; +pub const VIRTIO_SND_EVT_PCM_XRUN: u32 = 0x1101; + +/* common status codes */ +pub const VIRTIO_SND_S_OK: u32 = 0x8000; +pub const VIRTIO_SND_S_BAD_MSG: u32 = 0x8001; +pub const VIRTIO_SND_S_NOT_SUPP: u32 = 0x8002; +pub const VIRTIO_SND_S_IO_ERR: u32 = 0x8003; + +pub const VIRTIO_SND_JACK_F_REMAP: u32 = 0; + +/* standard channel position definition */ +pub const VIRTIO_SND_CHMAP_NONE: u32 = 0; /* undefined */ +pub const VIRTIO_SND_CHMAP_NA: u32 = 1; /* silent */ +pub const VIRTIO_SND_CHMAP_MONO: u32 = 2; /* mono stream */ +pub const VIRTIO_SND_CHMAP_FL: u32 = 3; 
/* front left */ +pub const VIRTIO_SND_CHMAP_FR: u32 = 4; /* front right */ +pub const VIRTIO_SND_CHMAP_RL: u32 = 5; /* rear left */ +pub const VIRTIO_SND_CHMAP_RR: u32 = 6; /* rear right */ +pub const VIRTIO_SND_CHMAP_FC: u32 = 7; /* front center */ +pub const VIRTIO_SND_CHMAP_LFE: u32 = 8; /* low frequency (LFE) */ +pub const VIRTIO_SND_CHMAP_SL: u32 = 9; /* side left */ +pub const VIRTIO_SND_CHMAP_SR: u32 = 10; /* side right */ +pub const VIRTIO_SND_CHMAP_RC: u32 = 11; /* rear center */ +pub const VIRTIO_SND_CHMAP_FLC: u32 = 12; /* front left center */ +pub const VIRTIO_SND_CHMAP_FRC: u32 = 13; /* front right center */ +pub const VIRTIO_SND_CHMAP_RLC: u32 = 14; /* rear left center */ +pub const VIRTIO_SND_CHMAP_RRC: u32 = 15; /* rear right center */ +pub const VIRTIO_SND_CHMAP_FLW: u32 = 16; /* front left wide */ +pub const VIRTIO_SND_CHMAP_FRW: u32 = 17; /* front right wide */ +pub const VIRTIO_SND_CHMAP_FLH: u32 = 18; /* front left high */ +pub const VIRTIO_SND_CHMAP_FCH: u32 = 19; /* front center high */ +pub const VIRTIO_SND_CHMAP_FRH: u32 = 20; /* front right high */ +pub const VIRTIO_SND_CHMAP_TC: u32 = 21; /* top center */ +pub const VIRTIO_SND_CHMAP_TFL: u32 = 22; /* top front left */ +pub const VIRTIO_SND_CHMAP_TFR: u32 = 23; /* top front right */ +pub const VIRTIO_SND_CHMAP_TFC: u32 = 24; /* top front center */ +pub const VIRTIO_SND_CHMAP_TRL: u32 = 25; /* top rear left */ +pub const VIRTIO_SND_CHMAP_TRR: u32 = 26; /* top rear right */ +pub const VIRTIO_SND_CHMAP_TRC: u32 = 27; /* top rear center */ +pub const VIRTIO_SND_CHMAP_TFLC: u32 = 28; /* top front left center */ +pub const VIRTIO_SND_CHMAP_TFRC: u32 = 29; /* top front right center */ +pub const VIRTIO_SND_CHMAP_TSL: u32 = 34; /* top side left */ +pub const VIRTIO_SND_CHMAP_TSR: u32 = 35; /* top side right */ +pub const VIRTIO_SND_CHMAP_LLFE: u32 = 36; /* left LFE */ +pub const VIRTIO_SND_CHMAP_RLFE: u32 = 37; /* right LFE */ +pub const VIRTIO_SND_CHMAP_BC: u32 = 38; /* bottom center */ +pub 
const VIRTIO_SND_CHMAP_BLC: u32 = 39; /* bottom left center */ +pub const VIRTIO_SND_CHMAP_BRC: u32 = 40; /* bottom right center */ + +pub const VIRTIO_SND_CHMAP_MAX_SIZE: usize = 18; diff --git a/devices/src/virtio/snd/layout.rs b/devices/src/virtio/snd/layout.rs index c78a9d7a4..8881af0b4 100644 --- a/devices/src/virtio/snd/layout.rs +++ b/devices/src/virtio/snd/layout.rs @@ -2,9 +2,20 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use crate::virtio::snd::constants::VIRTIO_SND_CHMAP_MAX_SIZE; use data_model::{DataInit, Le32, Le64}; #[derive(Copy, Clone, Default)] +#[repr(C, packed)] +pub struct virtio_snd_config { + pub jacks: Le32, + pub streams: Le32, + pub chmaps: Le32, +} +// Safe because it only has data and has no implicit padding. +unsafe impl DataInit for virtio_snd_config {} + +#[derive(Copy, Clone, Default)] #[repr(C)] pub struct virtio_snd_hdr { pub code: Le32, @@ -12,6 +23,24 @@ pub struct virtio_snd_hdr { // Safe because it only has data and has no implicit padding. unsafe impl DataInit for virtio_snd_hdr {} +#[derive(Copy, Clone, Default)] +#[repr(C)] +pub struct virtio_snd_jack_hdr { + pub hdr: virtio_snd_hdr, + pub jack_id: Le32, +} +// Safe because it only has data and has no implicit padding. +unsafe impl DataInit for virtio_snd_jack_hdr {} + +#[derive(Copy, Clone, Default)] +#[repr(C)] +pub struct virtio_snd_event { + pub hdr: virtio_snd_hdr, + pub data: Le32, +} +// Safe because it only has data and has no implicit padding. 
+unsafe impl DataInit for virtio_snd_event {} + #[derive(Copy, Clone)] #[repr(C)] pub struct virtio_snd_query_info { @@ -35,9 +64,9 @@ unsafe impl DataInit for virtio_snd_info {} #[repr(C)] pub struct virtio_snd_pcm_info { pub hdr: virtio_snd_info, - pub features: Le32, - pub formats: Le64, - pub rates: Le64, + pub features: Le32, /* 1 << VIRTIO_SND_PCM_F_XXX */ + pub formats: Le64, /* 1 << VIRTIO_SND_PCM_FMT_XXX */ + pub rates: Le64, /* 1 << VIRTIO_SND_PCM_RATE_XXX */ pub direction: u8, pub channels_min: u8, pub channels_max: u8, @@ -62,7 +91,7 @@ pub struct virtio_snd_pcm_set_params { pub hdr: virtio_snd_pcm_hdr, pub buffer_bytes: Le32, pub period_bytes: Le32, - pub features: Le32, + pub features: Le32, /* 1 << VIRTIO_SND_PCM_F_XXX */ pub channels: u8, pub format: u8, pub rate: u8, @@ -79,7 +108,7 @@ pub struct virtio_snd_pcm_xfer { // Safe because it only has data and has no implicit padding. unsafe impl DataInit for virtio_snd_pcm_xfer {} -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Default)] #[repr(C)] pub struct virtio_snd_pcm_status { pub status: Le32, @@ -87,3 +116,37 @@ pub struct virtio_snd_pcm_status { } // Safe because it only has data and has no implicit padding. unsafe impl DataInit for virtio_snd_pcm_status {} + +#[derive(Copy, Clone)] +#[repr(C)] +pub struct virtio_snd_jack_info { + pub hdr: virtio_snd_info, + pub features: Le32, /* 1 << VIRTIO_SND_JACK_F_XXX */ + pub hda_reg_defconf: Le32, + pub hda_reg_caps: Le32, + pub connected: u8, + pub padding: [u8; 7], +} +// Safe because it only has data and has no implicit padding. +unsafe impl DataInit for virtio_snd_jack_info {} + +#[derive(Copy, Clone)] +#[repr(C)] +pub struct virtio_snd_jack_remap { + pub hdr: virtio_snd_jack_hdr, /* .code = VIRTIO_SND_R_JACK_REMAP */ + pub association: Le32, + pub sequence: Le32, +} +// Safe because it only has data and has no implicit padding. 
+unsafe impl DataInit for virtio_snd_jack_remap {} + +#[derive(Copy, Clone)] +#[repr(C)] +pub struct virtio_snd_chmap_info { + pub hdr: virtio_snd_info, + pub direction: u8, + pub channels: u8, + pub positions: [u8; VIRTIO_SND_CHMAP_MAX_SIZE], +} +// Safe because it only has data and has no implicit padding. +unsafe impl DataInit for virtio_snd_chmap_info {} diff --git a/devices/src/virtio/snd/vios_backend/mod.rs b/devices/src/virtio/snd/vios_backend/mod.rs index 416fe7725..602ec4dcd 100644 --- a/devices/src/virtio/snd/vios_backend/mod.rs +++ b/devices/src/virtio/snd/vios_backend/mod.rs @@ -5,7 +5,7 @@ mod shm_streams; mod shm_vios; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "android"))] pub use self::shm_streams::*; pub use self::shm_vios::*; diff --git a/devices/src/virtio/snd/vios_backend/shm_streams.rs b/devices/src/virtio/snd/vios_backend/shm_streams.rs index c83ae86ef..5b5e0f641 100644 --- a/devices/src/virtio/snd/vios_backend/shm_streams.rs +++ b/devices/src/virtio/snd/vios_backend/shm_streams.rs @@ -22,7 +22,6 @@ use std::path::Path; use std::sync::Arc; use std::time::{Duration, Instant}; -use sync::Mutex; use sys_util::{Error as SysError, SharedMemory as SysSharedMemory}; use super::shm_vios::{Error, Result}; @@ -33,17 +32,14 @@ type GenericResult<T> = std::result::Result<T, BoxError>; /// Adapter that provides the ShmStreamSource trait around the VioS backend. pub struct VioSShmStreamSource { - // Reference counting is needed because the streams also need a reference to the client to push - // buffers and release the stream on drop and it has to implement Send because ShmStreamSource - // requires it, so the only possibility is Arc. - vios_client: Arc<Mutex<VioSClient>>, + vios_client: Arc<VioSClient>, } impl VioSShmStreamSource { /// Creates a new stream source given the path to the audio server's socket. 
pub fn new<P: AsRef<Path>>(server: P) -> Result<VioSShmStreamSource> { Ok(Self { - vios_client: Arc::new(Mutex::new(VioSClient::try_new(server)?)), + vios_client: Arc::new(VioSClient::try_new(server)?), }) } } @@ -93,39 +89,41 @@ impl ShmStreamSource for VioSShmStreamSource { client_shm: &SysSharedMemory, _buffer_offsets: [u64; 2], ) -> GenericResult<Box<dyn ShmStream>> { - let mut vios_client = self.vios_client.lock(); - match direction { - StreamDirection::Playback => { - match vios_client.get_unused_stream_id(VIRTIO_SND_D_OUTPUT) { - Some(stream_id) => { - vios_client.prepare_stream(stream_id)?; - let frame_size = num_channels * format.sample_bytes(); - let period_bytes = (frame_size * buffer_size) as u32; - let params = VioSStreamParams { - buffer_bytes: 2 * period_bytes, - period_bytes, - features: 0u32, - channels: num_channels as u8, - format: from_sample_format(format), - rate: virtio_frame_rate(frame_rate)?, - }; - vios_client.set_stream_parameters(stream_id, params)?; - vios_client.start_stream(stream_id)?; - VioSndShmStream::new( - buffer_size, - num_channels, - format, - frame_rate, - stream_id, - self.vios_client.clone(), - client_shm, - ) - } - None => Err(Box::new(Error::NoStreamsAvailable)), - } - } - StreamDirection::Capture => panic!("Capture not yet supported"), - } + self.vios_client.ensure_bg_thread_started()?; + let virtio_dir = match direction { + StreamDirection::Playback => VIRTIO_SND_D_OUTPUT, + StreamDirection::Capture => VIRTIO_SND_D_INPUT, + }; + let frame_size = num_channels * format.sample_bytes(); + let period_bytes = (frame_size * buffer_size) as u32; + let stream_id = self + .vios_client + .get_unused_stream_id(virtio_dir) + .ok_or(Box::new(Error::NoStreamsAvailable))?; + // Create the stream object before any errors can be returned to guarantee the stream will + // be released in all cases + let stream_box = VioSndShmStream::new( + buffer_size, + num_channels, + format, + frame_rate, + stream_id, + direction, + 
self.vios_client.clone(), + client_shm, + ); + self.vios_client.prepare_stream(stream_id)?; + let params = VioSStreamParams { + buffer_bytes: 2 * period_bytes, + period_bytes, + features: 0u32, + channels: num_channels as u8, + format: from_sample_format(format), + rate: virtio_frame_rate(frame_rate)?, + }; + self.vios_client.set_stream_parameters(stream_id, params)?; + self.vios_client.start_stream(stream_id)?; + stream_box } /// Get a list of file descriptors used by the implementation. @@ -134,7 +132,7 @@ impl ShmStreamSource for VioSShmStreamSource { /// This list helps users of the ShmStreamSource enter Linux jails without /// closing needed file descriptors. fn keep_fds(&self) -> Vec<RawFd> { - self.vios_client.lock().keep_fds() + self.vios_client.keep_fds() } } @@ -148,7 +146,8 @@ pub struct VioSndShmStream { next_frame: Duration, start_time: Instant, stream_id: u32, - vios_client: Arc<Mutex<VioSClient>>, + direction: StreamDirection, + vios_client: Arc<VioSClient>, client_shm: SharedMemory, } @@ -160,21 +159,22 @@ impl VioSndShmStream { format: SampleFormat, frame_rate: u32, stream_id: u32, - vios_client: Arc<Mutex<VioSClient>>, + direction: StreamDirection, + vios_client: Arc<VioSClient>, client_shm: &SysSharedMemory, ) -> GenericResult<Box<dyn ShmStream>> { let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64); let dup_fd = unsafe { - // Safe because dup doesn't affect memory and client_shm should wrap a known valid file - // descriptor - libc::dup(client_shm.as_raw_fd()) + // Safe because fcntl doesn't affect memory and client_shm should wrap a known valid + // file descriptor. 
+ libc::fcntl(client_shm.as_raw_fd(), libc::F_DUPFD_CLOEXEC, 0) }; if dup_fd < 0 { return Err(Box::new(Error::DupError(SysError::last()))); } let file = unsafe { - // safe because we checked the result of libc::dup() + // safe because we checked the result of libc::fcntl() File::from_raw_fd(dup_fd) }; let client_shm_clone = @@ -189,6 +189,7 @@ impl VioSndShmStream { next_frame: interval, start_time: Instant::now(), stream_id, + direction, vios_client, client_shm: client_shm_clone, })) @@ -211,10 +212,10 @@ impl ShmStream for VioSndShmStream { /// Waits until the next time a frame should be sent to the server. The server may release the /// previous buffer much sooner than it needs the next one, so this function may sleep to wait /// for the right time. - fn wait_for_next_action_with_timeout<'b>( - &'b mut self, + fn wait_for_next_action_with_timeout( + &mut self, timeout: Duration, - ) -> GenericResult<Option<ServerRequest<'b>>> { + ) -> GenericResult<Option<ServerRequest>> { let elapsed = self.start_time.elapsed(); if elapsed < self.next_frame { if timeout < self.next_frame - elapsed { @@ -231,12 +232,24 @@ impl ShmStream for VioSndShmStream { impl BufferSet for VioSndShmStream { fn callback(&mut self, offset: usize, frames: usize) -> GenericResult<()> { - self.vios_client.lock().inject_audio_data( - self.stream_id, - &mut self.client_shm, - offset, - frames * self.frame_size, - )?; + match self.direction { + StreamDirection::Playback => { + self.vios_client.inject_audio_data( + self.stream_id, + &mut self.client_shm, + offset, + frames * self.frame_size, + )?; + } + StreamDirection::Capture => { + self.vios_client.request_audio_data( + self.stream_id, + &mut self.client_shm, + offset, + frames * self.frame_size, + )?; + } + } Ok(()) } @@ -247,11 +260,11 @@ impl BufferSet for VioSndShmStream { impl Drop for VioSndShmStream { fn drop(&mut self) { - let mut client = self.vios_client.lock(); let stream_id = self.stream_id; - if let Err(e) = client + if let Err(e) = 
self + .vios_client .stop_stream(stream_id) - .and_then(|_| client.release_stream(stream_id)) + .and_then(|_| self.vios_client.release_stream(stream_id)) { error!("Failed to stop and release stream {}: {}", stream_id, e); } diff --git a/devices/src/virtio/snd/vios_backend/shm_vios.rs b/devices/src/virtio/snd/vios_backend/shm_vios.rs index d6372e68c..02ffa1ce5 100644 --- a/devices/src/virtio/snd/vios_backend/shm_vios.rs +++ b/devices/src/virtio/snd/vios_backend/shm_vios.rs @@ -6,15 +6,22 @@ use crate::virtio::snd::constants::*; use crate::virtio::snd::layout::*; use base::{ - net::UnixSeqpacket, Error as BaseError, FromRawDescriptor, IntoRawDescriptor, MemoryMapping, - MemoryMappingBuilder, MmapError, SafeDescriptor, ScmSocket, SharedMemory, + error, net::UnixSeqpacket, Error as BaseError, Event, FromRawDescriptor, IntoRawDescriptor, + MemoryMapping, MemoryMappingBuilder, MmapError, PollToken, SafeDescriptor, ScmSocket, + SharedMemory, WaitContext, }; use data_model::{DataInit, VolatileMemory, VolatileMemoryError}; +use std::collections::HashMap; use std::fs::File; use std::io::{Error as IOError, ErrorKind as IOErrorKind, Seek, SeekFrom}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::path::Path; +use std::sync::mpsc::{channel, Receiver, RecvError, Sender}; +use std::sync::Arc; +use std::thread::JoinHandle; + +use sync::Mutex; use thiserror::Error as ThisError; @@ -58,6 +65,20 @@ pub enum Error { PlatformNotSupported, #[error("Command failed with status {0}")] CommandFailed(u32), + #[error("IO buffer operation failed: status = {0}")] + IOBufferError(u32), + #[error("Failed to duplicate UnixSeqpacket: {0}")] + UnixSeqpacketDupError(IOError), + #[error("Sender was dropped without sending buffer status, the recv thread may have exited")] + BufferStatusSenderLost(RecvError), + #[error("Failed to create Recv event: {0}")] + EventCreateError(BaseError), + #[error("Failed to dup Recv event: {0}")] + EventDupError(BaseError), + #[error("Failed to create Recv 
thread's WaitContext: {0}")] + WaitContextCreateError(BaseError), + #[error("Error waiting for events")] + WaitError(BaseError), } #[derive(ThisError, Debug)] @@ -75,14 +96,21 @@ pub enum ProtocolErrorKind { /// The client for the VioS backend /// /// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for -/// notifications. +/// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared +/// between threads. pub struct VioSClient { config: VioSConfig, - streams: Vec<VioSStreamInfo>, - control_socket: UnixSeqpacket, - event_socket: UnixSeqpacket, - tx: IoBufferQueue, - rx: IoBufferQueue, + // These mutexes should almost never be held simultaneously. If at some point they have to the + // locking order should match the order in which they are declared here. + streams: Mutex<Vec<VioSStreamInfo>>, + control_socket: Mutex<UnixSeqpacket>, + event_socket: Mutex<UnixSeqpacket>, + tx: Mutex<IoBufferQueue>, + rx: Mutex<IoBufferQueue>, + rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>>, + recv_running: Arc<Mutex<bool>>, + recv_event: Mutex<Event>, + recv_thread: Mutex<Option<JoinHandle<Result<()>>>>, } impl VioSClient { @@ -93,14 +121,14 @@ impl VioSClient { let mut config: VioSConfig = Default::default(); let mut fds: Vec<RawFd> = Vec::new(); const NUM_FDS: usize = 5; - fds.resize(NUM_FDS, 0 as RawFd); + fds.resize(NUM_FDS, 0); let (recv_size, fd_count) = client_socket .recv_with_fds(config.as_mut_slice(), &mut fds) .map_err(|e| Error::ServerError(e))?; // Resize the vector to the actual number of file descriptors received and wrap them in // SafeDescriptors to prevent leaks - fds.resize(fd_count, -1 as RawFd); + fds.resize(fd_count, -1); let mut safe_fds: Vec<SafeDescriptor> = fds .into_iter() .map(|fd| unsafe { @@ -153,30 +181,61 @@ impl VioSClient { )); } - let control_socket = client_socket; + let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>> = + 
Arc::new(Mutex::new(HashMap::new())); + let recv_running = Arc::new(Mutex::new(true)); + let recv_event = Event::new().map_err(|e| Error::EventCreateError(e))?; let mut client = VioSClient { config, - streams: Vec::new(), - control_socket, - event_socket, - tx: IoBufferQueue::new(tx_socket, tx_shm_file)?, - rx: IoBufferQueue::new(rx_socket, rx_shm_file)?, + streams: Mutex::new(Vec::new()), + control_socket: Mutex::new(client_socket), + event_socket: Mutex::new(event_socket), + tx: Mutex::new(IoBufferQueue::new(tx_socket, tx_shm_file)?), + rx: Mutex::new(IoBufferQueue::new(rx_socket, rx_shm_file)?), + rx_subscribers, + recv_running, + recv_event: Mutex::new(recv_event), + recv_thread: Mutex::new(None), }; client.request_and_cache_streams_info()?; - Ok(client) } - /// Get a description of the available sound streams. - pub fn streams(&self) -> &Vec<VioSStreamInfo> { - &self.streams + pub fn ensure_bg_thread_started(&self) -> Result<()> { + if self.recv_thread.lock().is_some() { + return Ok(()); + } + let event_socket = self + .recv_event + .lock() + .try_clone() + .map_err(|e| Error::EventDupError(e))?; + let rx_socket = self + .rx + .lock() + .socket + .try_clone() + .map_err(|e| Error::UnixSeqpacketDupError(e))?; + let mut opt = self.recv_thread.lock(); + // The lock on recv_thread was released above to avoid holding more than one lock at a time + // while duplicating the fds. So we have to check again the condition. + if opt.is_none() { + *opt = Some(spawn_recv_thread( + self.rx_subscribers.clone(), + event_socket, + self.recv_running.clone(), + rx_socket, + )); + } + Ok(()) } /// Gets an unused stream id of the specified direction. `direction` must be one of /// VIRTIO_SND_D_INPUT OR VIRTIO_SND_D_OUTPUT. 
pub fn get_unused_stream_id(&self, direction: u8) -> Option<u32> { self.streams + .lock() .iter() .filter(|s| s.state == StreamState::Available && s.direction == direction as u8) .map(|s| s.id) @@ -184,21 +243,20 @@ impl VioSClient { } /// Configures a stream with the given parameters. - pub fn set_stream_parameters( - &mut self, - stream_id: u32, - params: VioSStreamParams, - ) -> Result<()> { - self.validate_stream_id(stream_id, &[StreamState::Available, StreamState::Acquired])?; + pub fn set_stream_parameters(&self, stream_id: u32, params: VioSStreamParams) -> Result<()> { + self.validate_stream_id( + stream_id, + &[StreamState::Available, StreamState::Acquired], + None, + )?; let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into(); - seq_socket_send(&self.control_socket, raw_params)?; - self.recv_cmd_status()?; - self.streams[stream_id as usize].state = StreamState::Acquired; + self.send_cmd(raw_params)?; + self.streams.lock()[stream_id as usize].state = StreamState::Acquired; Ok(()) } /// Send the PREPARE_STREAM command to the server. - pub fn prepare_stream(&mut self, stream_id: u32) -> Result<()> { + pub fn prepare_stream(&self, stream_id: u32) -> Result<()> { self.common_stream_op( stream_id, &[StreamState::Available, StreamState::Acquired], @@ -208,7 +266,7 @@ impl VioSClient { } /// Send the RELEASE_STREAM command to the server. - pub fn release_stream(&mut self, stream_id: u32) -> Result<()> { + pub fn release_stream(&self, stream_id: u32) -> Result<()> { self.common_stream_op( stream_id, &[StreamState::Acquired], @@ -218,7 +276,7 @@ impl VioSClient { } /// Send the START_STREAM command to the server. - pub fn start_stream(&mut self, stream_id: u32) -> Result<()> { + pub fn start_stream(&self, stream_id: u32) -> Result<()> { self.common_stream_op( stream_id, &[StreamState::Acquired], @@ -228,7 +286,7 @@ impl VioSClient { } /// Send the STOP_STREAM command to the server. 
- pub fn stop_stream(&mut self, stream_id: u32) -> Result<()> { + pub fn stop_stream(&self, stream_id: u32) -> Result<()> { self.common_stream_op( stream_id, &[StreamState::Active], @@ -239,77 +297,121 @@ impl VioSClient { /// Send audio frames to the server. The audio data is taken from a shared memory resource. pub fn inject_audio_data( - &mut self, + &self, stream_id: u32, buffer: &mut SharedMemory, src_offset: usize, size: usize, ) -> Result<()> { - self.validate_stream_id(stream_id, &[StreamState::Active])?; - if self.streams[stream_id as usize].direction != VIRTIO_SND_D_OUTPUT { - return Err(Error::WrongDirection( - self.streams[stream_id as usize].direction, - )); - } - let dst_offset = self.tx.push_buffer(buffer, src_offset, size)?; + self.validate_stream_id(stream_id, &[StreamState::Active], Some(VIRTIO_SND_D_OUTPUT))?; + let mut tx_lock = self.tx.lock(); + let tx = &mut *tx_lock; + let dst_offset = tx.push_buffer(buffer, src_offset, size)?; let msg = IoTransferMsg::new(stream_id, dst_offset, size); - seq_socket_send(&self.tx.socket, msg) + seq_socket_send(&tx.socket, msg) + } + + pub fn request_audio_data( + &self, + stream_id: u32, + buffer: &mut SharedMemory, + dst_offset: usize, + size: usize, + ) -> Result<usize> { + self.validate_stream_id(stream_id, &[StreamState::Active], Some(VIRTIO_SND_D_INPUT))?; + let (src_offset, status_promise) = { + let mut rx_lock = self.rx.lock(); + let rx = &mut *rx_lock; + let src_offset = rx.allocate_buffer(size)?; + // Register to receive the status before sending the buffer to the server + let (sender, receiver): (Sender<(u32, usize)>, Receiver<(u32, usize)>) = channel(); + // It's OK to acquire rx_subscriber's lock after rx_lock + self.rx_subscribers.lock().insert(src_offset, sender); + let msg = IoTransferMsg::new(stream_id, src_offset, size); + seq_socket_send(&rx.socket, msg)?; + (src_offset, receiver) + }; + // Make sure no mutexes are held while awaiting for the buffer to be written to + let recv_size = 
await_status(status_promise)?; + { + let mut rx_lock = self.rx.lock(); + rx_lock + .pop_buffer(buffer, dst_offset, recv_size, src_offset) + .map(|()| recv_size) + } } /// Get a list of file descriptors used by the implementation. pub fn keep_fds(&self) -> Vec<RawFd> { + let control_fd = self.control_socket.lock().as_raw_fd(); + let event_fd = self.event_socket.lock().as_raw_fd(); + let (tx_socket_fd, tx_shm_fd) = { + let lock = self.tx.lock(); + (lock.socket.as_raw_fd(), lock.file.as_raw_fd()) + }; + let (rx_socket_fd, rx_shm_fd) = { + let lock = self.rx.lock(); + (lock.socket.as_raw_fd(), lock.file.as_raw_fd()) + }; vec![ - self.control_socket.as_raw_fd(), - self.event_socket.as_raw_fd(), - self.tx.socket.as_raw_fd(), - self.rx.socket.as_raw_fd(), - self.tx.file.as_raw_fd(), - self.rx.file.as_raw_fd(), + control_fd, + event_fd, + tx_socket_fd, + tx_shm_fd, + rx_socket_fd, + rx_shm_fd, ] } - fn validate_stream_id(&self, stream_id: u32, permitted_states: &[StreamState]) -> Result<()> { - if stream_id >= self.streams.len() as u32 { + fn send_cmd<T: DataInit>(&self, data: T) -> Result<()> { + let mut control_socket_lock = self.control_socket.lock(); + seq_socket_send(&mut *control_socket_lock, data)?; + recv_cmd_status(&mut *control_socket_lock) + } + + fn validate_stream_id( + &self, + stream_id: u32, + permitted_states: &[StreamState], + direction: Option<u8>, + ) -> Result<()> { + let streams_lock = self.streams.lock(); + let stream_idx = stream_id as usize; + if stream_idx >= streams_lock.len() { return Err(Error::InvalidStreamId(stream_id)); } - if !permitted_states.contains(&self.streams[stream_id as usize].state) { - return Err(Error::UnexpectedState( - self.streams[stream_id as usize].state, - )); + if !permitted_states.contains(&streams_lock[stream_idx].state) { + return Err(Error::UnexpectedState(streams_lock[stream_idx].state)); + } + match direction { + None => Ok(()), + Some(d) => { + if d == streams_lock[stream_idx].direction { + Ok(()) + } else { + 
Err(Error::WrongDirection(streams_lock[stream_idx].direction)) + } + } } - Ok(()) } fn common_stream_op( - &mut self, + &self, stream_id: u32, expected_states: &[StreamState], new_state: StreamState, op: u32, ) -> Result<()> { - self.validate_stream_id(stream_id, expected_states)?; + self.validate_stream_id(stream_id, expected_states, None)?; let msg = virtio_snd_pcm_hdr { hdr: virtio_snd_hdr { code: op.into() }, stream_id: stream_id.into(), }; - seq_socket_send(&self.control_socket, msg)?; - self.recv_cmd_status()?; - self.streams[stream_id as usize].state = new_state; + self.send_cmd(msg)?; + self.streams.lock()[stream_id as usize].state = new_state; Ok(()) } - fn recv_cmd_status(&mut self) -> Result<()> { - let mut status: virtio_snd_hdr = Default::default(); - self.control_socket - .recv(status.as_mut_slice()) - .map_err(|e| Error::ServerIOError(e))?; - if status.code.to_native() == VIRTIO_SND_S_OK { - Ok(()) - } else { - Err(Error::CommandFailed(status.code.to_native())) - } - } - fn request_and_cache_streams_info(&mut self) -> Result<()> { let num_streams = self.config.streams as usize; let info_size = std::mem::size_of::<virtio_snd_pcm_info>(); @@ -321,10 +423,9 @@ impl VioSClient { count: (num_streams as u32).into(), size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(), }; - seq_socket_send(&self.control_socket, req)?; - self.recv_cmd_status()?; - let info_vec = self - .control_socket + self.send_cmd(req)?; + let control_socket_lock = self.control_socket.lock(); + let info_vec = control_socket_lock .recv_as_vec() .map_err(|e| Error::ServerIOError(e))?; if info_vec.len() != num_streams * info_size { @@ -332,19 +433,123 @@ impl VioSClient { ProtocolErrorKind::UnexpectedMessageSize(num_streams * info_size, info_vec.len()), )); } - self.streams = info_vec - .chunks(info_size) - .enumerate() - .map(|(id, info_buffer)| { - // unwrap is safe because we checked the size of the vector above - let virtio_stream_info = 
virtio_snd_pcm_info::from_slice(&info_buffer).unwrap(); - VioSStreamInfo::new(id as u32, &virtio_stream_info) - }) - .collect(); + self.streams = Mutex::new( + info_vec + .chunks(info_size) + .enumerate() + .map(|(id, info_buffer)| { + // unwrap is safe because we checked the size of the vector + let virtio_stream_info = virtio_snd_pcm_info::from_slice(&info_buffer).unwrap(); + VioSStreamInfo::new(id as u32, &virtio_stream_info) + }) + .collect(), + ); Ok(()) } } +impl Drop for VioSClient { + fn drop(&mut self) { + // Stop the recv thread + *self.recv_running.lock() = false; + if let Err(e) = self.recv_event.lock().write(1u64) { + error!("Failed to notify recv thread: {:?}", e); + } + if let Some(handle) = self.recv_thread.lock().take() { + match handle.join() { + Ok(r) => { + if let Err(e) = r { + error!("Error detected on Recv Thread: {}", e); + } + } + Err(e) => error!("Recv thread panicked: {:?}", e), + }; + } + } +} + +#[derive(PollToken)] +enum Token { + Notification, + RxBufferMsg, +} + +fn spawn_recv_thread( + rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>>, + event: Event, + running: Arc<Mutex<bool>>, + rx_socket: UnixSeqpacket, +) -> JoinHandle<Result<()>> { + std::thread::spawn(move || { + let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[ + (&rx_socket, Token::RxBufferMsg), + (&event, Token::Notification), + ]) + .map_err(|e| Error::WaitContextCreateError(e))?; + while *running.lock() { + let events = wait_ctx.wait().map_err(|e| Error::WaitError(e))?; + for evt in events { + match evt.token { + Token::RxBufferMsg => { + let mut msg: IoStatusMsg = Default::default(); + let size = rx_socket + .recv(msg.as_mut_slice()) + .map_err(|e| Error::ServerIOError(e))?; + if size != std::mem::size_of::<IoStatusMsg>() { + return Err(Error::ProtocolError( + ProtocolErrorKind::UnexpectedMessageSize( + std::mem::size_of::<IoStatusMsg>(), + size, + ), + )); + } + let mut status = msg.status.status.into(); + if status == u32::MAX { + // Anyone 
waiting for this would continue to wait for as long as status is + // u32::MAX + status -= 1; + } + let offset = msg.buffer_offset as usize; + let consumed_len = msg.consumed_len as usize; + // Acquire and immediately release the mutex protecting the hashmap + let promise_opt = rx_subscribers.lock().remove(&offset); + match promise_opt { + None => error!( + "Received an unexpected buffer status message: {}. This is a BUG!!", + offset + ), + Some(sender) => { + if let Err(e) = sender.send((status, consumed_len)) { + error!("Failed to notify waiting thread: {:?}", e); + } + } + } + } + Token::Notification => { + // Just consume the notification and check for termination on the next + // iteration + if let Err(e) = event.read() { + error!("Failed to consume notification from recv thread: {:?}", e); + } + } + } + } + } + Ok(()) + }) +} + +fn await_status(promise: Receiver<(u32, usize)>) -> Result<usize> { + let (status, consumed_len) = promise + .recv() + .map_err(|e| Error::BufferStatusSenderLost(e))?; + if status == VIRTIO_SND_S_OK { + Ok(consumed_len) + } else { + Err(Error::IOBufferError(status)) + } +} + struct IoBufferQueue { socket: UnixSeqpacket, file: File, @@ -360,7 +565,7 @@ impl IoBufferQueue { .map_err(|e| Error::FileSizeError(e))? 
as usize; let mmap = MemoryMappingBuilder::new(size) - .from_descriptor(&file) + .from_file(&file) .build() .map_err(|e| Error::ServerMmapError(e))?; @@ -373,17 +578,22 @@ impl IoBufferQueue { }) } - fn push_buffer(&mut self, src: &mut SharedMemory, offset: usize, size: usize) -> Result<usize> { + fn allocate_buffer(&mut self, size: usize) -> Result<usize> { if size > self.size { return Err(Error::OutOfSpace); } - let shm_offset = if size > self.size - self.next { + let offset = if size > self.size - self.next { // Can't fit the new buffer at the end of the area, so put it at the beginning 0 } else { self.next }; + self.next = offset + size; + Ok(offset) + } + fn push_buffer(&mut self, src: &mut SharedMemory, offset: usize, size: usize) -> Result<usize> { + let shm_offset = self.allocate_buffer(size)?; let (src_mmap, mmap_offset) = mmap_buffer(src, offset, size)?; let src_slice = src_mmap .get_slice(mmap_offset, size) @@ -393,9 +603,27 @@ impl IoBufferQueue { .get_slice(shm_offset, size) .map_err(|e| Error::VolatileMemoryError(e))?; src_slice.copy_to_volatile_slice(dst_slice); - self.next = shm_offset + size; Ok(shm_offset) } + + fn pop_buffer( + &mut self, + dst: &mut SharedMemory, + dst_offset: usize, + size: usize, + src_offset: usize, + ) -> Result<()> { + let (dst_mmap, mmap_offset) = mmap_buffer(dst, dst_offset, size)?; + let dst_slice = dst_mmap + .get_slice(mmap_offset, size) + .map_err(|e| Error::VolatileMemoryError(e))?; + let src_slice = self + .mmap + .get_slice(src_offset, size) + .map_err(|e| Error::VolatileMemoryError(e))?; + src_slice.copy_to_volatile_slice(dst_slice); + Ok(()) + } } /// Description of a stream made available by the server. 
@@ -479,13 +707,25 @@ fn mmap_buffer( let mmap = MemoryMappingBuilder::new(extended_size) .offset(aligned_offset as u64) - .from_descriptor(src) + .from_shared_memory(src) .build() .map_err(|e| Error::GuestMmapError(e))?; Ok((mmap, offset_from_mapping_start)) } +fn recv_cmd_status(control_socket: &mut UnixSeqpacket) -> Result<()> { + let mut status: virtio_snd_hdr = Default::default(); + control_socket + .recv(status.as_mut_slice()) + .map_err(|e| Error::ServerIOError(e))?; + if status.code.to_native() == VIRTIO_SND_S_OK { + Ok(()) + } else { + Err(Error::CommandFailed(status.code.to_native())) + } +} + fn seq_socket_send<T: DataInit>(socket: &UnixSeqpacket, data: T) -> Result<()> { loop { let send_res = socket.send(data.as_slice()); @@ -538,7 +778,7 @@ impl IoTransferMsg { } #[repr(C)] -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Default)] struct IoStatusMsg { status: virtio_snd_pcm_status, buffer_offset: u32, diff --git a/devices/src/virtio/tpm.rs b/devices/src/virtio/tpm.rs index ee68ea1ea..e9e07c90a 100644 --- a/devices/src/virtio/tpm.rs +++ b/devices/src/virtio/tpm.rs @@ -14,7 +14,8 @@ use base::{error, Event, PollToken, RawDescriptor, WaitContext}; use vm_memory::GuestMemory; use super::{ - DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_TPM, + DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, + Writer, TYPE_TPM, }; // A single queue of size 2. 
The guest kernel driver will enqueue a single @@ -112,9 +113,14 @@ impl Worker { let wait_ctx = match WaitContext::build_with(&[ (&self.queue_evt, Token::QueueAvailable), - (self.interrupt.get_resample_evt(), Token::InterruptResample), (&self.kill_evt, Token::Kill), - ]) { + ]) + .and_then(|wc| { + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wc.add(resample_evt, Token::InterruptResample)?; + } + Ok(wc) + }) { Ok(pc) => pc, Err(e) => { error!("vtpm failed creating WaitContext: {}", e); diff --git a/devices/src/virtio/vhost/control_socket.rs b/devices/src/virtio/vhost/control_socket.rs index fc1d0f302..18553885d 100644 --- a/devices/src/virtio/vhost/control_socket.rs +++ b/devices/src/virtio/vhost/control_socket.rs @@ -2,10 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use base::{Error as SysError, RawDescriptor}; -use msg_socket::{MsgOnSocket, MsgSocket}; +use serde::{Deserialize, Serialize}; -#[derive(MsgOnSocket, Debug)] +use base::Error as SysError; + +#[derive(Serialize, Deserialize, Debug)] pub enum VhostDevRequest { /// Mask or unmask all the MSI entries for a Virtio Vhost device. MsixChanged, @@ -13,24 +14,8 @@ pub enum VhostDevRequest { MsixEntryChanged(usize), } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VhostDevResponse { Ok, Err(SysError), } - -pub type VhostDevRequestSocket = MsgSocket<VhostDevRequest, VhostDevResponse>; -pub type VhostDevResponseSocket = MsgSocket<VhostDevResponse, VhostDevRequest>; - -/// Create control socket pair. This pair is used to communicate with the -/// virtio device process. -/// Mainly between the virtio and activate thread. 
-pub fn create_control_sockets() -> ( - Option<VhostDevRequestSocket>, - Option<VhostDevResponseSocket>, -) { - match msg_socket::pair::<VhostDevRequest, VhostDevResponse>() { - Ok((request, response)) => (Some(request), Some(response)), - _ => (None, None), - } -} diff --git a/devices/src/virtio/vhost/mod.rs b/devices/src/virtio/vhost/mod.rs index e1cd06766..14bd8cad8 100644 --- a/devices/src/virtio/vhost/mod.rs +++ b/devices/src/virtio/vhost/mod.rs @@ -6,13 +6,14 @@ use std::fmt::{self, Display}; -use base::Error as SysError; +use base::{Error as SysError, TubeError}; use net_util::Error as TapError; use remain::sorted; use vhost::Error as VhostError; mod control_socket; mod net; +pub mod user; mod vsock; mod worker; @@ -27,6 +28,8 @@ pub enum Error { CloneKillEvent(SysError), /// Creating kill event failed. CreateKillEvent(SysError), + /// Creating tube failed. + CreateTube(TubeError), /// Creating wait context failed. CreateWaitContext(SysError), /// Enabling tap interface failed. @@ -88,6 +91,7 @@ impl Display for Error { match self { CloneKillEvent(e) => write!(f, "failed to clone kill event: {}", e), CreateKillEvent(e) => write!(f, "failed to create kill event: {}", e), + CreateTube(e) => write!(f, "failed to create tube: {}", e), CreateWaitContext(e) => write!(f, "failed to create poll context: {}", e), TapEnable(e) => write!(f, "failed to enable tap interface: {}", e), TapOpen(e) => write!(f, "failed to open tap device: {}", e), diff --git a/devices/src/virtio/vhost/net.rs b/devices/src/virtio/vhost/net.rs index 41dc6da4b..d68febe19 100644 --- a/devices/src/virtio/vhost/net.rs +++ b/devices/src/virtio/vhost/net.rs @@ -4,11 +4,12 @@ use std::mem; use std::net::Ipv4Addr; +use std::path::PathBuf; use std::thread; use net_util::{MacAddress, TapT}; -use base::{error, warn, AsRawDescriptor, Event, RawDescriptor}; +use base::{error, warn, AsRawDescriptor, Event, RawDescriptor, Tube}; use vhost::NetT as VhostNetT; use virtio_sys::virtio_net; use 
vm_memory::GuestMemory; @@ -18,7 +19,6 @@ use super::worker::Worker; use super::{Error, Result}; use crate::pci::MsixStatus; use crate::virtio::{Interrupt, Queue, VirtioDevice, TYPE_NET}; -use msg_socket::{MsgReceiver, MsgSender}; const QUEUE_SIZE: u16 = 256; const NUM_QUEUES: usize = 2; @@ -33,8 +33,8 @@ pub struct Net<T: TapT, U: VhostNetT<T>> { vhost_interrupt: Option<Vec<Event>>, avail_features: u64, acked_features: u64, - request_socket: Option<VhostDevRequestSocket>, - response_socket: Option<VhostDevResponseSocket>, + request_tube: Tube, + response_tube: Option<Tube>, } impl<T, U> Net<T, U> @@ -45,6 +45,7 @@ where /// Create a new virtio network device with the given IP address and /// netmask. pub fn new( + vhost_net_device_path: &PathBuf, base_features: u64, ip_addr: Ipv4Addr, netmask: Ipv4Addr, @@ -71,7 +72,7 @@ where .map_err(Error::TapSetVnetHdrSize)?; tap.enable().map_err(Error::TapEnable)?; - let vhost_net_handle = U::new(mem).map_err(Error::VhostOpen)?; + let vhost_net_handle = U::new(vhost_net_device_path, mem).map_err(Error::VhostOpen)?; let avail_features = base_features | 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM @@ -90,7 +91,7 @@ where vhost_interrupt.push(Event::new().map_err(Error::VhostIrqCreate)?); } - let (request_socket, response_socket) = create_control_sockets(); + let (request_tube, response_tube) = Tube::pair().map_err(Error::CreateTube)?; Ok(Net { workers_kill_evt: Some(kill_evt.try_clone().map_err(Error::CloneKillEvent)?), @@ -101,8 +102,8 @@ where vhost_interrupt: Some(vhost_interrupt), avail_features, acked_features: 0u64, - request_socket, - response_socket, + request_tube, + response_tube: Some(response_tube), }) } } @@ -152,12 +153,10 @@ where } keep_rds.push(self.kill_evt.as_raw_descriptor()); - if let Some(request_socket) = &self.request_socket { - keep_rds.push(request_socket.as_raw_descriptor()); - } + keep_rds.push(self.request_tube.as_raw_descriptor()); - if let Some(response_socket) = &self.response_socket { - 
keep_rds.push(response_socket.as_raw_descriptor()); + if let Some(response_tube) = &self.response_tube { + keep_rds.push(response_tube.as_raw_descriptor()); } keep_rds @@ -206,8 +205,8 @@ where if let Some(vhost_interrupt) = self.vhost_interrupt.take() { if let Some(kill_evt) = self.workers_kill_evt.take() { let acked_features = self.acked_features; - let socket = if self.response_socket.is_some() { - self.response_socket.take() + let socket = if self.response_tube.is_some() { + self.response_tube.take() } else { None }; @@ -275,44 +274,50 @@ where } fn control_notify(&self, behavior: MsixStatus) { - if self.worker_thread.is_none() || self.request_socket.is_none() { + if self.worker_thread.is_none() { return; } - if let Some(socket) = &self.request_socket { - match behavior { - MsixStatus::EntryChanged(index) => { - if let Err(e) = socket.send(&VhostDevRequest::MsixEntryChanged(index)) { - error!( - "{} failed to send VhostMsixEntryChanged request for entry {}: {:?}", - self.debug_label(), - index, - e - ); - return; - } - if let Err(e) = socket.recv() { - error!("{} failed to receive VhostMsixEntryChanged response for entry {}: {:?}", self.debug_label(), index, e); - } + match behavior { + MsixStatus::EntryChanged(index) => { + if let Err(e) = self + .request_tube + .send(&VhostDevRequest::MsixEntryChanged(index)) + { + error!( + "{} failed to send VhostMsixEntryChanged request for entry {}: {:?}", + self.debug_label(), + index, + e + ); + return; } - MsixStatus::Changed => { - if let Err(e) = socket.send(&VhostDevRequest::MsixChanged) { - error!( - "{} failed to send VhostMsixChanged request: {:?}", - self.debug_label(), - e - ); - return; - } - if let Err(e) = socket.recv() { - error!( - "{} failed to receive VhostMsixChanged response {:?}", - self.debug_label(), - e - ); - } + if let Err(e) = self.request_tube.recv::<VhostDevResponse>() { + error!( + "{} failed to receive VhostMsixEntryChanged response for entry {}: {:?}", + self.debug_label(), + index, + e + 
); + } + } + MsixStatus::Changed => { + if let Err(e) = self.request_tube.send(&VhostDevRequest::MsixChanged) { + error!( + "{} failed to send VhostMsixChanged request: {:?}", + self.debug_label(), + e + ); + return; + } + if let Err(e) = self.request_tube.recv::<VhostDevResponse>() { + error!( + "{} failed to receive VhostMsixChanged response {:?}", + self.debug_label(), + e + ); } - _ => {} } + _ => {} } } @@ -334,7 +339,7 @@ where self.tap = Some(tap); self.vhost_interrupt = Some(worker.vhost_interrupt); self.workers_kill_evt = Some(worker.kill_evt); - self.response_socket = worker.response_socket; + self.response_tube = worker.response_tube; return true; } } @@ -366,6 +371,7 @@ pub mod tests { let guest_memory = create_guest_memory().unwrap(); let features = base_features(ProtectionType::Unprotected); Net::<FakeTap, FakeNet<FakeTap>>::new( + &PathBuf::from(""), features, Ipv4Addr::new(127, 0, 0, 1), Ipv4Addr::new(255, 255, 255, 0), diff --git a/devices/src/virtio/vhost/user/block.rs b/devices/src/virtio/vhost/user/block.rs new file mode 100644 index 000000000..f3b6abfbe --- /dev/null +++ b/devices/src/virtio/vhost/user/block.rs @@ -0,0 +1,180 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::cell::RefCell; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::thread; +use std::u32; + +use base::{error, Event, RawDescriptor}; +use cros_async::Executor; +use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX; +use vm_memory::GuestMemory; +use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; +use vmm_vhost::vhost_user::Master; + +use crate::virtio::vhost::user::handler::VhostUserHandler; +use crate::virtio::vhost::user::worker::Worker; +use crate::virtio::vhost::user::{Error, Result}; +use crate::virtio::{virtio_blk_config, Interrupt, Queue, VirtioDevice, TYPE_BLOCK}; + +const VIRTIO_BLK_F_SEG_MAX: u32 = 2; +const VIRTIO_BLK_F_RO: u32 = 5; +const VIRTIO_BLK_F_BLK_SIZE: u32 = 6; +const VIRTIO_BLK_F_FLUSH: u32 = 9; +const VIRTIO_BLK_F_DISCARD: u32 = 13; +const VIRTIO_BLK_F_WRITE_ZEROES: u32 = 14; + +const QUEUE_SIZE: u16 = 256; + +pub struct Block { + kill_evt: Option<Event>, + worker_thread: Option<thread::JoinHandle<Worker>>, + handler: RefCell<VhostUserHandler>, + queue_sizes: Vec<u16>, +} + +impl Block { + pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P) -> Result<Block> { + let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?; + // TODO(b/181753022): Support multiple queues. 
+ let vhost_user_blk = Master::from_stream(socket, 1 /* queues_num */); + + let allow_features = 1u64 << crate::virtio::VIRTIO_F_VERSION_1 + | 1 << VIRTIO_BLK_F_SEG_MAX + | 1 << VIRTIO_BLK_F_RO + | 1 << VIRTIO_BLK_F_BLK_SIZE + | 1 << VIRTIO_BLK_F_FLUSH + | 1 << VIRTIO_BLK_F_DISCARD + | 1 << VIRTIO_BLK_F_WRITE_ZEROES + | 1 << VIRTIO_RING_F_EVENT_IDX + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let allow_protocol_features = VhostUserProtocolFeatures::CONFIG; + + let mut handler = VhostUserHandler::new( + vhost_user_blk, + allow_features, + init_features, + allow_protocol_features, + )?; + let queue_sizes = handler.queue_sizes(QUEUE_SIZE, 1)?; + + Ok(Block { + kill_evt: None, + worker_thread: None, + handler: RefCell::new(handler), + queue_sizes, + }) + } +} + +impl Drop for Block { + fn drop(&mut self) { + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. 
+ let _ = kill_evt.write(1); + } + + if let Some(worker_thread) = self.worker_thread.take() { + let _ = worker_thread.join(); + } + } +} + +impl VirtioDevice for Block { + fn keep_rds(&self) -> Vec<RawDescriptor> { + Vec::new() + } + + fn features(&self) -> u64 { + self.handler.borrow().avail_features + } + + fn ack_features(&mut self, features: u64) { + if let Err(e) = self.handler.borrow_mut().ack_features(features) { + error!("failed to enable features 0x{:x}: {}", features, e); + } + } + + fn device_type(&self) -> u32 { + TYPE_BLOCK + } + + fn queue_max_sizes(&self) -> &[u16] { + self.queue_sizes.as_slice() + } + + fn read_config(&self, offset: u64, data: &mut [u8]) { + if let Err(e) = self + .handler + .borrow_mut() + .read_config::<virtio_blk_config>(offset, data) + { + error!("failed to read config: {}", e); + } + } + + fn activate( + &mut self, + mem: GuestMemory, + interrupt: Interrupt, + queues: Vec<Queue>, + queue_evts: Vec<Event>, + ) { + if let Err(e) = self + .handler + .borrow_mut() + .activate(&mem, &interrupt, &queues, &queue_evts) + { + error!("failed to activate queues: {}", e); + return; + } + + let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) { + Ok(v) => v, + Err(e) => { + error!("failed creating kill Event pair: {}", e); + return; + } + }; + self.kill_evt = Some(self_kill_evt); + + let worker_result = thread::Builder::new() + .name("vhost_user_virtio_blk".to_string()) + .spawn(move || { + let ex = Executor::new().expect("failed to create an executor"); + let mut worker = Worker { + queues, + mem, + kill_evt, + }; + + if let Err(e) = worker.run(&ex, interrupt) { + error!("failed to start a worker: {}", e); + } + worker + }); + + match worker_result { + Err(e) => { + error!("failed to spawn vhost-user virtio_blk worker: {}", e); + return; + } + Ok(join_handle) => { + self.worker_thread = Some(join_handle); + } + } + } + + fn reset(&mut self) -> bool { + if let Err(e) = 
self.handler.borrow_mut().reset(self.queue_sizes.len()) { + error!("Failed to reset block device: {}", e); + false + } else { + true + } + } +} diff --git a/devices/src/virtio/vhost/user/fs.rs b/devices/src/virtio/vhost/user/fs.rs new file mode 100644 index 000000000..2e212b97c --- /dev/null +++ b/devices/src/virtio/vhost/user/fs.rs @@ -0,0 +1,192 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::RefCell; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::thread; + +use base::{error, Event, RawDescriptor}; +use cros_async::Executor; +use data_model::{DataInit, Le32}; +use vm_memory::GuestMemory; +use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; +use vmm_vhost::vhost_user::{Error as VhostUserError, Master}; +use vmm_vhost::Error as VhostError; + +use crate::virtio::fs::{virtio_fs_config, FS_MAX_TAG_LEN, QUEUE_SIZE}; +use crate::virtio::vhost::user::handler::VhostUserHandler; +use crate::virtio::vhost::user::worker::Worker; +use crate::virtio::vhost::user::{Error, Result}; +use crate::virtio::{copy_config, TYPE_FS}; +use crate::virtio::{Interrupt, Queue, VirtioDevice}; + +pub struct Fs { + cfg: virtio_fs_config, + kill_evt: Option<Event>, + worker_thread: Option<thread::JoinHandle<Worker>>, + handler: RefCell<VhostUserHandler>, + queue_sizes: Vec<u16>, +} +impl Fs { + pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P, tag: &str) -> Result<Fs> { + if tag.len() > FS_MAX_TAG_LEN { + return Err(Error::TagTooLong { + len: tag.len(), + max: FS_MAX_TAG_LEN, + }); + } + + // The spec requires a minimum of 2 queues: one worker queue and one high priority queue + let default_queue_size = 2; + + let mut cfg_tag = [0u8; FS_MAX_TAG_LEN]; + cfg_tag[..tag.len()].copy_from_slice(tag.as_bytes()); + + let cfg = virtio_fs_config { + tag: cfg_tag, + // Only count the worker 
queues, exclude the high prio queue + num_request_queues: Le32::from(default_queue_size - 1), + }; + + let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?; + let master = Master::from_stream(socket, default_queue_size as u64); + + let allow_features = 1u64 << crate::virtio::VIRTIO_F_VERSION_1 + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let allow_protocol_features = + VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIG; + + let mut handler = VhostUserHandler::new( + master, + allow_features, + init_features, + allow_protocol_features, + )?; + let queue_sizes = handler.queue_sizes(QUEUE_SIZE, default_queue_size as usize)?; + + Ok(Fs { + cfg, + kill_evt: None, + worker_thread: None, + handler: RefCell::new(handler), + queue_sizes, + }) + } +} + +impl VirtioDevice for Fs { + fn keep_rds(&self) -> Vec<RawDescriptor> { + Vec::new() + } + + fn device_type(&self) -> u32 { + TYPE_FS + } + + fn queue_max_sizes(&self) -> &[u16] { + &self.queue_sizes + } + + fn features(&self) -> u64 { + self.handler.borrow().avail_features + } + + fn ack_features(&mut self, features: u64) { + if let Err(e) = self.handler.borrow_mut().ack_features(features) { + error!("failed to enable features 0x{:x}: {}", features, e); + } + } + + fn read_config(&self, offset: u64, data: &mut [u8]) { + match self + .handler + .borrow_mut() + .read_config::<virtio_fs_config>(offset, data) + { + Ok(()) => {} + // copy local config when VhostUserProtocolFeatures::CONFIG is not supported by the + // device + Err(Error::GetConfig(VhostError::VhostUserProtocol( + VhostUserError::InvalidOperation, + ))) => copy_config(data, 0, self.cfg.as_slice(), offset), + Err(e) => error!("Failed to fetch device config: {}", e), + } + } + + fn activate( + &mut self, + mem: GuestMemory, + interrupt: Interrupt, + queues: Vec<Queue>, + queue_evts: Vec<Event>, + ) { + if let Err(e) = self + 
.handler + .borrow_mut() + .activate(&mem, &interrupt, &queues, &queue_evts) + { + error!("failed to activate queues: {}", e); + return; + } + + let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) { + Ok(v) => v, + Err(e) => { + error!("failed creating kill Event pair: {}", e); + return; + } + }; + self.kill_evt = Some(self_kill_evt); + + let worker_result = thread::Builder::new() + .name("vhost_user_virtio_fs".to_string()) + .spawn(move || { + let ex = Executor::new().expect("failed to create an executor"); + let mut worker = Worker { + queues, + mem, + kill_evt, + }; + + if let Err(e) = worker.run(&ex, interrupt) { + error!("failed to start a worker: {}", e); + } + worker + }); + + match worker_result { + Err(e) => { + error!("failed to spawn vhost-user virtio_fs worker: {}", e); + return; + } + Ok(join_handle) => { + self.worker_thread = Some(join_handle); + } + } + } + + fn reset(&mut self) -> bool { + if let Err(e) = self.handler.borrow_mut().reset(self.queue_sizes.len()) { + error!("Failed to reset fs device: {}", e); + false + } else { + true + } + } +} + +impl Drop for Fs { + fn drop(&mut self) { + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. + let _ = kill_evt.write(1); + } + + if let Some(worker_thread) = self.worker_thread.take() { + let _ = worker_thread.join(); + } + } +} diff --git a/devices/src/virtio/vhost/user/handler.rs b/devices/src/virtio/vhost/user/handler.rs new file mode 100644 index 000000000..f3797ae8e --- /dev/null +++ b/devices/src/virtio/vhost/user/handler.rs @@ -0,0 +1,228 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use std::io::Write; + +use base::{AsRawDescriptor, Event}; +use vm_memory::GuestMemory; +use vmm_vhost::vhost_user::message::{ + VhostUserConfigFlags, VhostUserProtocolFeatures, VhostUserVirtioFeatures, + VHOST_USER_CONFIG_OFFSET, +}; +use vmm_vhost::vhost_user::{Master, VhostUserMaster}; +use vmm_vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; + +use crate::virtio::vhost::user::{Error, Result}; +use crate::virtio::{Interrupt, Queue}; + +fn set_features(vu: &mut Master, avail_features: u64, ack_features: u64) -> Result<u64> { + let features = avail_features & ack_features; + vu.set_features(features).map_err(Error::SetFeatures)?; + Ok(features) +} + +pub struct VhostUserHandler { + vu: Master, + pub avail_features: u64, + acked_features: u64, + protocol_features: VhostUserProtocolFeatures, +} + +impl VhostUserHandler { + /// Creates a `VhostUserHandler` instance with features and protocol features initialized. + pub fn new( + mut vu: Master, + allow_features: u64, + init_features: u64, + allow_protocol_features: VhostUserProtocolFeatures, + ) -> Result<Self> { + vu.set_owner().map_err(Error::SetOwner)?; + + let avail_features = allow_features & vu.get_features().map_err(Error::GetFeatures)?; + let acked_features = set_features(&mut vu, avail_features, init_features)?; + + let mut protocol_features = VhostUserProtocolFeatures::empty(); + if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { + let avail_protocol_features = vu + .get_protocol_features() + .map_err(Error::GetProtocolFeatures)?; + protocol_features = allow_protocol_features & avail_protocol_features; + vu.set_protocol_features(protocol_features) + .map_err(Error::SetProtocolFeatures)?; + } + + Ok(VhostUserHandler { + vu, + avail_features, + acked_features, + protocol_features, + }) + } + + /// Returns a vector of sizes of each queue. 
+ pub fn queue_sizes(&mut self, queue_size: u16, default_queues_num: usize) -> Result<Vec<u16>> { + let queues_num = if self + .protocol_features + .contains(VhostUserProtocolFeatures::MQ) + { + self.vu.get_queue_num().map_err(Error::GetQueueNum)? as usize + } else { + default_queues_num + }; + Ok(vec![queue_size; queues_num]) + } + + /// Enables a set of features. + pub fn ack_features(&mut self, ack_features: u64) -> Result<()> { + let features = set_features( + &mut self.vu, + self.avail_features, + self.acked_features | ack_features, + )?; + self.acked_features = features; + Ok(()) + } + + /// Gets the device configuration space at `offset` and writes it into `data`. + pub fn read_config<T>(&mut self, offset: u64, mut data: &mut [u8]) -> Result<()> { + let config_len = std::mem::size_of::<T>() as u64; + let data_len = data.len() as u64; + offset + .checked_add(data_len) + .and_then(|l| if l <= config_len { Some(()) } else { None }) + .ok_or(Error::InvalidConfigOffset { + data_len, + offset, + config_len, + })?; + + let buf = vec![0u8; config_len as usize]; + let (_, config) = self + .vu + .get_config( + VHOST_USER_CONFIG_OFFSET, + config_len as u32, + VhostUserConfigFlags::WRITABLE, + &buf, + ) + .map_err(Error::GetConfig)?; + + data.write_all( + &config[offset as usize..std::cmp::min(data_len + offset, config_len) as usize], + ) + .map_err(Error::CopyConfig) + } + + fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> { + let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new(); + mem.with_regions::<_, ()>( + |_idx, guest_phys_addr, memory_size, userspace_addr, mmap, mmap_offset| { + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: guest_phys_addr.0, + memory_size: memory_size as u64, + userspace_addr: userspace_addr as u64, + mmap_offset, + mmap_handle: mmap.as_raw_descriptor(), + }; + regions.push(region); + Ok(()) + }, + ) + .unwrap(); // never fail + + self.vu + .set_mem_table(regions.as_slice()) + .map_err(Error::SetMemTable)?; + + 
Ok(()) + } + + fn activate_vring( + &mut self, + mem: &GuestMemory, + queue_index: usize, + queue: &Queue, + queue_evt: &Event, + interrupt: &Interrupt, + ) -> Result<()> { + self.vu + .set_vring_num(queue_index, queue.actual_size()) + .map_err(Error::SetVringNum)?; + + let config_data = VringConfigData { + queue_max_size: queue.max_size, + queue_size: queue.actual_size(), + flags: 0u32, + desc_table_addr: mem + .get_host_address(queue.desc_table) + .map_err(Error::GetHostAddress)? as u64, + used_ring_addr: mem + .get_host_address(queue.used_ring) + .map_err(Error::GetHostAddress)? as u64, + avail_ring_addr: mem + .get_host_address(queue.avail_ring) + .map_err(Error::GetHostAddress)? as u64, + log_addr: None, + }; + self.vu + .set_vring_addr(queue_index, &config_data) + .map_err(Error::SetVringAddr)?; + + self.vu + .set_vring_base(queue_index, 0) + .map_err(Error::SetVringBase)?; + + let msix_config_opt = interrupt + .get_msix_config() + .as_ref() + .ok_or(Error::MsixConfigUnavailable)?; + let msix_config = msix_config_opt.lock(); + let irqfd = msix_config + .get_irqfd(queue.vector as usize) + .ok_or(Error::MsixIrqfdUnavailable)?; + self.vu + .set_vring_call(queue_index, &irqfd.0) + .map_err(Error::SetVringCall)?; + + self.vu + .set_vring_kick(queue_index, &queue_evt.0) + .map_err(Error::SetVringKick)?; + self.vu + .set_vring_enable(queue_index, true) + .map_err(Error::SetVringEnable)?; + + Ok(()) + } + + /// Activates vrings. + pub fn activate( + &mut self, + mem: &GuestMemory, + interrupt: &Interrupt, + queues: &[Queue], + queue_evts: &[Event], + ) -> Result<()> { + self.set_mem_table(&mem)?; + + for (queue_index, queue) in queues.iter().enumerate() { + let queue_evt = &queue_evts[queue_index]; + self.activate_vring(&mem, queue_index, queue, queue_evt, &interrupt)?; + } + + Ok(()) + } + + /// Deactivates all vrings. 
+ pub fn reset(&mut self, queues_num: usize) -> Result<()> { + for queue_index in 0..queues_num { + self.vu + .set_vring_enable(queue_index, false) + .map_err(Error::SetVringEnable)?; + self.vu + .get_vring_base(queue_index) + .map_err(Error::GetVringBase)?; + } + Ok(()) + } +} diff --git a/devices/src/virtio/vhost/user/mod.rs b/devices/src/virtio/vhost/user/mod.rs new file mode 100644 index 000000000..b408e1c18 --- /dev/null +++ b/devices/src/virtio/vhost/user/mod.rs @@ -0,0 +1,101 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +mod block; +mod fs; +mod handler; +mod net; +mod worker; + +pub use self::block::*; +pub use self::fs::*; +pub use self::net::*; + +use remain::sorted; +use thiserror::Error as ThisError; +use vm_memory::GuestMemoryError; +use vmm_vhost::Error as VhostError; + +#[sorted] +#[derive(ThisError, Debug)] +pub enum Error { + /// Failed to copy config to a buffer. + #[error("failed to copy config to a buffer: {0}")] + CopyConfig(std::io::Error), + /// Failed to create `base::Event`. + #[error("failed to create Event: {0}")] + CreateEvent(base::Error), + /// Failed to get config. + #[error("failed to get config: {0}")] + GetConfig(VhostError), + /// Failed to get features. + #[error("failed to get features: {0}")] + GetFeatures(VhostError), + /// Failed to get host address. + #[error("failed to get host address: {0}")] + GetHostAddress(GuestMemoryError), + /// Failed to get protocol features. + #[error("failed to get protocol features: {0}")] + GetProtocolFeatures(VhostError), + /// Failed to get number of queues. + #[error("failed to get number of queues: {0}")] + GetQueueNum(VhostError), + /// Failed to get vring base offset. + #[error("failed to get vring base offset: {0}")] + GetVringBase(VhostError), + /// Invalid config offset is given. 
+ #[error("invalid config offset is given: {data_len} + {offset} > {config_len}")] + InvalidConfigOffset { + data_len: u64, + offset: u64, + config_len: u64, + }, + /// MSI-X config is unavailable. + #[error("MSI-X config is unavailable")] + MsixConfigUnavailable, + /// MSI-X irqfd is unavailable. + #[error("MSI-X irqfd is unavailable")] + MsixIrqfdUnavailable, + /// Failed to reset owner. + #[error("failed to reset owner: {0}")] + ResetOwner(VhostError), + /// Failed to set features. + #[error("failed to set features: {0}")] + SetFeatures(VhostError), + /// Failed to set memory map regions. + #[error("failed to set memory map regions: {0}")] + SetMemTable(VhostError), + /// Failed to set owner. + #[error("failed to set owner: {0}")] + SetOwner(VhostError), + /// Failed to set protocol features. + #[error("failed to set protocol features: {0}")] + SetProtocolFeatures(VhostError), + /// Failed to set vring address. + #[error("failed to set vring address: {0}")] + SetVringAddr(VhostError), + /// Failed to set vring base offset. + #[error("failed to set vring base offset: {0}")] + SetVringBase(VhostError), + /// Failed to set eventfd to signal used vring buffers. + #[error("failed to set eventfd to signal used vring buffers: {0}")] + SetVringCall(VhostError), + /// Failed to enable or disable vring. + #[error("failed to enable or disable vring: {0}")] + SetVringEnable(VhostError), + /// Failed to set eventfd for adding buffers to vring. + #[error("failed to set eventfd for adding buffers to vring: {0}")] + SetVringKick(VhostError), + /// Failed to set the size of the queue. + #[error("failed to set the size of the queue: {0}")] + SetVringNum(VhostError), + /// Failed to connect socket. + #[error("failed to connect socket: {0}")] + SocketConnect(std::io::Error), + /// The tag for the Fs device was too long to fit in the config space. 
+ #[error("tag is too long: {len} > {max}")] + TagTooLong { len: usize, max: usize }, +} + +pub type Result<T> = std::result::Result<T, Error>; diff --git a/devices/src/virtio/vhost/user/net.rs b/devices/src/virtio/vhost/user/net.rs new file mode 100644 index 000000000..fe3deccf4 --- /dev/null +++ b/devices/src/virtio/vhost/user/net.rs @@ -0,0 +1,185 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::cell::RefCell; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::thread; +use std::u32; + +use base::{error, Event, RawDescriptor}; +use cros_async::Executor; +use virtio_sys::virtio_net; +use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX; +use vm_memory::GuestMemory; +use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; +use vmm_vhost::vhost_user::Master; + +use crate::virtio::vhost::user::handler::VhostUserHandler; +use crate::virtio::vhost::user::worker::Worker; +use crate::virtio::vhost::user::Error; +use crate::virtio::{Interrupt, Queue, VirtioDevice, VirtioNetConfig, TYPE_NET}; + +type Result<T> = std::result::Result<T, Error>; + +const QUEUE_SIZE: u16 = 256; + +pub struct Net { + kill_evt: Option<Event>, + worker_thread: Option<thread::JoinHandle<Worker>>, + handler: RefCell<VhostUserHandler>, + queue_sizes: Vec<u16>, +} + +impl Net { + pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P) -> Result<Net> { + let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?; + let vhost_user_net = Master::from_stream(socket, 16 /* # of queues */); + + // TODO(b/182430355): Support VIRTIO_NET_F_CTRL_VQ and VIRTIO_NET_F_CTRL_GUEST_OFFLOADS. 
+ let allow_features = 1 << crate::virtio::VIRTIO_F_VERSION_1 + | 1 << virtio_net::VIRTIO_NET_F_CSUM + | 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM + | 1 << virtio_net::VIRTIO_NET_F_GUEST_TSO4 + | 1 << virtio_net::VIRTIO_NET_F_GUEST_UFO + | 1 << virtio_net::VIRTIO_NET_F_HOST_TSO4 + | 1 << virtio_net::VIRTIO_NET_F_HOST_UFO + | 1 << virtio_net::VIRTIO_NET_F_MQ + | 1 << VIRTIO_RING_F_EVENT_IDX + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let allow_protocol_features = + VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIG; + + let mut handler = VhostUserHandler::new( + vhost_user_net, + allow_features, + init_features, + allow_protocol_features, + )?; + let queue_sizes = handler.queue_sizes(QUEUE_SIZE, 2 /* 1 rx + 1 tx */)?; + + Ok(Net { + kill_evt: None, + worker_thread: None, + handler: RefCell::new(handler), + queue_sizes, + }) + } +} + +impl Drop for Net { + fn drop(&mut self) { + if let Some(kill_evt) = self.kill_evt.take() { + // Ignore the result because there is nothing we can do about it. 
+ let _ = kill_evt.write(1); + } + + if let Some(worker_thread) = self.worker_thread.take() { + let _ = worker_thread.join(); + } + } +} + +impl VirtioDevice for Net { + fn keep_rds(&self) -> Vec<RawDescriptor> { + Vec::new() + } + + fn features(&self) -> u64 { + self.handler.borrow().avail_features + } + + fn ack_features(&mut self, features: u64) { + if let Err(e) = self.handler.borrow_mut().ack_features(features) { + error!("failed to enable features 0x{:x}: {}", features, e); + } + } + + fn device_type(&self) -> u32 { + TYPE_NET + } + + fn queue_max_sizes(&self) -> &[u16] { + self.queue_sizes.as_slice() + } + + fn read_config(&self, offset: u64, data: &mut [u8]) { + if let Err(e) = self + .handler + .borrow_mut() + .read_config::<VirtioNetConfig>(offset, data) + { + error!("failed to read config: {}", e); + } + } + + fn activate( + &mut self, + mem: GuestMemory, + interrupt: Interrupt, + queues: Vec<Queue>, + queue_evts: Vec<Event>, + ) { + // TODO(b/182430355): Remove this check once ctrlq is supported. 
+ if queues.len() % 2 != 0 { + error!( + "The number of queues must be an even number but {}", + queues.len() + ); + } + + if let Err(e) = self + .handler + .borrow_mut() + .activate(&mem, &interrupt, &queues, &queue_evts) + { + error!("failed to activate queues: {}", e); + return; + } + + let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) { + Ok(v) => v, + Err(e) => { + error!("failed creating kill Event pair: {}", e); + return; + } + }; + self.kill_evt = Some(self_kill_evt); + + let worker_result = thread::Builder::new() + .name("vhost_user_virtio_net".to_string()) + .spawn(move || { + let ex = Executor::new().expect("failed to create an executor"); + let mut worker = Worker { + queues, + mem, + kill_evt, + }; + if let Err(e) = worker.run(&ex, interrupt) { + error!("failed to start a worker: {}", e); + } + worker + }); + + match worker_result { + Err(e) => { + error!("failed to spawn virtio_net worker: {}", e); + return; + } + Ok(join_handle) => { + self.worker_thread = Some(join_handle); + } + } + } + + fn reset(&mut self) -> bool { + if let Err(e) = self.handler.borrow_mut().reset(self.queue_sizes.len()) { + error!("Failed to reset net device: {}", e); + false + } else { + true + } + } +} diff --git a/devices/src/virtio/vhost/user/worker.rs b/devices/src/virtio/vhost/user/worker.rs new file mode 100644 index 000000000..0015a158c --- /dev/null +++ b/devices/src/virtio/vhost/user/worker.rs @@ -0,0 +1,82 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. 
+ +use base::{error, Event}; +use cros_async::{select2, AsyncError, EventAsync, Executor, SelectResult}; +use futures::pin_mut; +use thiserror::Error as ThisError; +use vm_memory::GuestMemory; + +use crate::virtio::interrupt::SignalableInterrupt; +use crate::virtio::{Interrupt, Queue}; + +#[derive(ThisError, Debug)] +enum Error { + /// Failed to read the resample event. + #[error("failed to read the resample event: {0}")] + ReadResampleEvent(AsyncError), +} + +pub struct Worker { + pub queues: Vec<Queue>, + pub mem: GuestMemory, + pub kill_evt: Event, +} + +impl Worker { + // Processes any requests to resample the irq value. + async fn handle_irq_resample( + resample_evt: EventAsync, + interrupt: Interrupt, + ) -> Result<(), Error> { + loop { + let _ = resample_evt + .next_val() + .await + .map_err(Error::ReadResampleEvent)?; + interrupt.do_interrupt_resample(); + } + } + + // Waits until the kill event is triggered. + async fn wait_kill(kill_evt: EventAsync) { + // Once this event is readable, exit. Exiting this future will cause the main loop to + // break and the device process to exit. + let _ = kill_evt.next_val().await; + } + + // Runs asynchronous tasks. 
+ pub fn run(&mut self, ex: &Executor, interrupt: Interrupt) -> Result<(), String> { + let resample_evt = interrupt + .get_resample_evt() + .expect("resample event required") + .try_clone() + .expect("failed to clone resample event"); + let async_resample_evt = + EventAsync::new(resample_evt.0, ex).expect("failed to create async resample event"); + let resample = Self::handle_irq_resample(async_resample_evt, interrupt); + pin_mut!(resample); + + let kill_evt = EventAsync::new( + self.kill_evt + .try_clone() + .expect("failed to clone kill_evt") + .0, + &ex, + ) + .expect("failed to create async kill event fd"); + let kill = Self::wait_kill(kill_evt); + pin_mut!(kill); + + match ex.run_until(select2(resample, kill)) { + Ok((resample_res, _)) => { + if let SelectResult::Finished(Err(e)) = resample_res { + return Err(format!("failed to resample a irq value: {:?}", e)); + } + Ok(()) + } + Err(e) => Err(e.to_string()), + } + } +} diff --git a/devices/src/virtio/vhost/vsock.rs b/devices/src/virtio/vhost/vsock.rs index 1885131cb..719aea48b 100644 --- a/devices/src/virtio/vhost/vsock.rs +++ b/devices/src/virtio/vhost/vsock.rs @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use std::thread; +use std::{path::PathBuf, thread}; use data_model::{DataInit, Le64}; @@ -31,9 +31,15 @@ pub struct Vsock { impl Vsock { /// Create a new virtio-vsock device with the given VM cid. 
- pub fn new(base_features: u64, cid: u64, mem: &GuestMemory) -> Result<Vsock> { + pub fn new( + vhost_vsock_device_path: &PathBuf, + base_features: u64, + cid: u64, + mem: &GuestMemory, + ) -> Result<Vsock> { let kill_evt = Event::new().map_err(Error::CreateKillEvent)?; - let handle = VhostVsockHandle::new(mem).map_err(Error::VhostOpen)?; + let handle = + VhostVsockHandle::new(vhost_vsock_device_path, mem).map_err(Error::VhostOpen)?; let avail_features = base_features | 1 << virtio_sys::vhost::VIRTIO_F_NOTIFY_ON_EMPTY diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs index 04c933a81..9d7822ebf 100644 --- a/devices/src/virtio/vhost/worker.rs +++ b/devices/src/virtio/vhost/worker.rs @@ -4,14 +4,13 @@ use std::os::raw::c_ulonglong; -use base::{error, Error as SysError, Event, PollToken, WaitContext}; +use base::{error, Error as SysError, Event, PollToken, Tube, WaitContext}; use vhost::Vhost; -use super::control_socket::{VhostDevRequest, VhostDevResponse, VhostDevResponseSocket}; +use super::control_socket::{VhostDevRequest, VhostDevResponse}; use super::{Error, Result}; -use crate::virtio::{Interrupt, Queue}; +use crate::virtio::{Interrupt, Queue, SignalableInterrupt}; use libc::EIO; -use msg_socket::{MsgReceiver, MsgSender}; /// Worker that takes care of running the vhost device. 
pub struct Worker<T: Vhost> { @@ -21,7 +20,7 @@ pub struct Worker<T: Vhost> { pub vhost_interrupt: Vec<Event>, acked_features: u64, pub kill_evt: Event, - pub response_socket: Option<VhostDevResponseSocket>, + pub response_tube: Option<Tube>, } impl<T: Vhost> Worker<T> { @@ -32,7 +31,7 @@ impl<T: Vhost> Worker<T> { interrupt: Interrupt, acked_features: u64, kill_evt: Event, - response_socket: Option<VhostDevResponseSocket>, + response_tube: Option<Tube>, ) -> Worker<T> { Worker { interrupt, @@ -41,7 +40,7 @@ impl<T: Vhost> Worker<T> { vhost_interrupt, acked_features, kill_evt, - response_socket, + response_tube, } } @@ -106,22 +105,25 @@ impl<T: Vhost> Worker<T> { ControlNotify, } - let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[ - (self.interrupt.get_resample_evt(), Token::InterruptResample), - (&self.kill_evt, Token::Kill), - ]) - .map_err(Error::CreateWaitContext)?; + let wait_ctx: WaitContext<Token> = + WaitContext::build_with(&[(&self.kill_evt, Token::Kill)]) + .map_err(Error::CreateWaitContext)?; for (index, vhost_int) in self.vhost_interrupt.iter().enumerate() { wait_ctx .add(vhost_int, Token::VhostIrqi { index }) .map_err(Error::CreateWaitContext)?; } - if let Some(socket) = &self.response_socket { + if let Some(socket) = &self.response_tube { wait_ctx .add(socket, Token::ControlNotify) .map_err(Error::CreateWaitContext)?; } + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wait_ctx + .add(resample_evt, Token::InterruptResample) + .map_err(Error::CreateWaitContext)?; + } 'wait: loop { let events = wait_ctx.wait().map_err(Error::WaitError)?; @@ -142,7 +144,7 @@ impl<T: Vhost> Worker<T> { break 'wait; } Token::ControlNotify => { - if let Some(socket) = &self.response_socket { + if let Some(socket) = &self.response_tube { match socket.recv() { Ok(VhostDevRequest::MsixEntryChanged(index)) => { let mut qindex = 0; @@ -199,8 +201,8 @@ impl<T: Vhost> Worker<T> { // No response_socket means it doesn't have any control related // with 
the msix. Due to this, cannot use the direct irq fd but // should fall back to indirect irq fd. - if self.response_socket.is_some() { - if let Some(msix_config) = &self.interrupt.msix_config { + if self.response_tube.is_some() { + if let Some(msix_config) = self.interrupt.get_msix_config() { let msix_config = msix_config.lock(); let msix_masked = msix_config.masked(); if msix_masked { @@ -228,7 +230,7 @@ impl<T: Vhost> Worker<T> { } fn set_vring_calls(&self) -> Result<()> { - if let Some(msix_config) = &self.interrupt.msix_config { + if let Some(msix_config) = self.interrupt.get_msix_config() { let msix_config = msix_config.lock(); if msix_config.masked() { for (queue_index, _) in self.queues.iter().enumerate() { diff --git a/devices/src/virtio/video/decoder/backend/mod.rs b/devices/src/virtio/video/decoder/backend/mod.rs index a9c1411a1..0bd6b6bf0 100644 --- a/devices/src/virtio/video/decoder/backend/mod.rs +++ b/devices/src/virtio/video/decoder/backend/mod.rs @@ -82,6 +82,7 @@ pub trait DecoderSession { format: Format, output_buffer: RawDescriptor, planes: &[FramePlane], + modifier: u64, ) -> VideoResult<()>; /// Ask the device to reuse an output buffer previously passed to diff --git a/devices/src/virtio/video/decoder/backend/vda.rs b/devices/src/virtio/video/decoder/backend/vda.rs index 716347cd6..99ce7cbdf 100644 --- a/devices/src/virtio/video/decoder/backend/vda.rs +++ b/devices/src/virtio/video/decoder/backend/vda.rs @@ -155,6 +155,7 @@ impl<'a> DecoderSession for LibvdaSession<'a> { format: Format, output_buffer: RawDescriptor, planes: &[FramePlane], + modifier: u64, ) -> VideoResult<()> { let vda_planes: Vec<libvda::FramePlane> = planes.into_iter().map(Into::into).collect(); Ok(self.session.use_output_buffer( @@ -162,6 +163,7 @@ impl<'a> DecoderSession for LibvdaSession<'a> { libvda::PixelFormat::try_from(format)?, output_buffer, &vda_planes, + modifier, )?) 
} diff --git a/devices/src/virtio/video/decoder/mod.rs b/devices/src/virtio/video/decoder/mod.rs index 7314583e9..9545bb9cb 100644 --- a/devices/src/virtio/video/decoder/mod.rs +++ b/devices/src/virtio/video/decoder/mod.rs @@ -9,11 +9,9 @@ use std::collections::{BTreeMap, BTreeSet, VecDeque}; use std::convert::TryInto; use backend::*; -use base::{error, IntoRawDescriptor, WaitContext}; +use base::{error, IntoRawDescriptor, Tube, WaitContext}; -use crate::virtio::resource_bridge::{ - self, BufferInfo, ResourceInfo, ResourceRequest, ResourceRequestSocket, -}; +use crate::virtio::resource_bridge::{self, BufferInfo, ResourceInfo, ResourceRequest}; use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap; use crate::virtio::video::command::{QueueType, VideoCmd}; use crate::virtio::video::control::{CtrlType, CtrlVal, QueryCtrlType}; @@ -179,9 +177,6 @@ struct Context<S: DecoderSession> { in_res: InputResources, out_res: OutputResources, - // Set the flag if we need to clear output resource when the output queue is cleared next time. - is_clear_out_res_needed: bool, - // Set the flag when we ask the decoder reset, and unset when the reset is done. is_resetting: bool, @@ -204,7 +199,6 @@ impl<S: DecoderSession> Context<S> { out_params: Default::default(), in_res: Default::default(), out_res: Default::default(), - is_clear_out_res_needed: false, is_resetting: false, pending_ready_pictures: Default::default(), session: None, @@ -262,7 +256,7 @@ impl<S: DecoderSession> Context<S> { fn get_resource_info( &self, queue_type: QueueType, - res_bridge: &ResourceRequestSocket, + res_bridge: &Tube, resource_id: u32, ) -> VideoResult<BufferInfo> { let res_id_to_res_handle = match queue_type { @@ -340,12 +334,6 @@ impl<S: DecoderSession> Context<S> { // No need to set `frame_rate`, as it's only for the encoder. ..Default::default() }; - - // That eos_resource_id has value means there are previous output resources. 
- // Clear the output resources when the output queue is cleared next time. - if self.out_res.eos_resource_id.is_some() { - self.is_clear_out_res_needed = true; - } } fn handle_notify_end_of_bitstream_buffer(&mut self, bitstream_id: i32) -> Option<ResourceId> { @@ -540,7 +528,7 @@ impl<'a, D: DecoderBackend> Decoder<D> { fn queue_input_resource( &mut self, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, stream_id: StreamId, resource_id: ResourceId, timestamp: u64, @@ -604,7 +592,7 @@ impl<'a, D: DecoderBackend> Decoder<D> { fn queue_output_resource( &mut self, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, stream_id: StreamId, resource_id: ResourceId, ) -> VideoResult<VideoCmdResponseType> { @@ -661,7 +649,13 @@ impl<'a, D: DecoderBackend> Decoder<D> { // Take ownership of this file by `into_raw_descriptor()` as this // file will be closed by libvda. let fd = resource_info.file.into_raw_descriptor(); - session.use_output_buffer(buffer_id as i32, Format::NV12, fd, &planes) + session.use_output_buffer( + buffer_id as i32, + Format::NV12, + fd, + &planes, + resource_info.modifier, + ) } }?; Ok(VideoCmdResponseType::Async(AsyncCmdTag::Queue { @@ -807,11 +801,7 @@ impl<'a, D: DecoderBackend> Decoder<D> { })) } QueueType::Output => { - if std::mem::replace(&mut ctx.is_clear_out_res_needed, false) { - ctx.out_res = Default::default(); - } else { - ctx.out_res.queued_res_ids.clear(); - } + ctx.out_res.queued_res_ids.clear(); Ok(VideoCmdResponseType::Sync(CmdResponse::NoData)) } } @@ -823,7 +813,7 @@ impl<D: DecoderBackend> Device for Decoder<D> { &mut self, cmd: VideoCmd, wait_ctx: &WaitContext<Token>, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, ) -> ( VideoCmdResponseType, Option<(u32, Vec<VideoEvtResponseType>)>, diff --git a/devices/src/virtio/video/device.rs b/devices/src/virtio/video/device.rs index e5f0a0f6e..d56700cba 100644 --- a/devices/src/virtio/video/device.rs +++ 
b/devices/src/virtio/video/device.rs @@ -4,9 +4,8 @@ //! Definition of the trait `Device` that each backend video device must implement. -use base::{PollToken, WaitContext}; +use base::{PollToken, Tube, WaitContext}; -use crate::virtio::resource_bridge::ResourceRequestSocket; use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap; use crate::virtio::video::command::{QueueType, VideoCmd}; use crate::virtio::video::error::*; @@ -102,7 +101,7 @@ pub trait Device { &mut self, cmd: VideoCmd, wait_ctx: &WaitContext<Token>, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, ) -> ( VideoCmdResponseType, Option<(u32, Vec<VideoEvtResponseType>)>, diff --git a/devices/src/virtio/video/encoder/mod.rs b/devices/src/virtio/video/encoder/mod.rs index 34eec5d4b..7d32d378c 100644 --- a/devices/src/virtio/video/encoder/mod.rs +++ b/devices/src/virtio/video/encoder/mod.rs @@ -11,12 +11,10 @@ mod libvda_encoder; pub use encoder::EncoderError; pub use libvda_encoder::LibvdaEncoder; -use base::{error, warn, WaitContext}; +use base::{error, warn, Tube, WaitContext}; use std::collections::{BTreeMap, BTreeSet}; -use crate::virtio::resource_bridge::{ - self, BufferInfo, ResourceInfo, ResourceRequest, ResourceRequestSocket, -}; +use crate::virtio::resource_bridge::{self, BufferInfo, ResourceInfo, ResourceRequest}; use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap; use crate::virtio::video::command::{QueueType, VideoCmd}; use crate::virtio::video::control::*; @@ -481,6 +479,7 @@ impl<T: EncoderSession> Stream<T> { } } + #[allow(clippy::unnecessary_wraps)] fn notify_error(&self, error: EncoderError) -> Option<Vec<VideoEvtResponseType>> { error!( "Received encoder error event for stream {}: {}", @@ -499,7 +498,7 @@ pub struct EncoderDevice<T: Encoder> { streams: BTreeMap<u32, Stream<T::Session>>, } -fn get_resource_info(res_bridge: &ResourceRequestSocket, uuid: u128) -> VideoResult<BufferInfo> { +fn get_resource_info(res_bridge: &Tube, uuid: u128) -> 
VideoResult<BufferInfo> { match resource_bridge::get_resource_info( res_bridge, ResourceRequest::GetBuffer { id: uuid as u32 }, @@ -519,6 +518,7 @@ impl<T: Encoder> EncoderDevice<T> { }) } + #[allow(clippy::unnecessary_wraps)] fn query_capabilities(&self, queue_type: QueueType) -> VideoResult<VideoCmdResponseType> { let descs = match queue_type { QueueType::Input => self.cros_capabilities.input_format_descs.clone(), @@ -597,7 +597,7 @@ impl<T: Encoder> EncoderDevice<T> { fn resource_create( &mut self, wait_ctx: &WaitContext<Token>, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, stream_id: u32, queue_type: QueueType, resource_id: u32, @@ -672,7 +672,7 @@ impl<T: Encoder> EncoderDevice<T> { fn resource_queue( &mut self, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, stream_id: u32, queue_type: QueueType, resource_id: u32, @@ -1200,7 +1200,7 @@ impl<T: Encoder> Device for EncoderDevice<T> { &mut self, req: VideoCmd, wait_ctx: &WaitContext<Token>, - resource_bridge: &ResourceRequestSocket, + resource_bridge: &Tube, ) -> ( VideoCmdResponseType, Option<(u32, Vec<VideoEvtResponseType>)>, diff --git a/devices/src/virtio/video/format.rs b/devices/src/virtio/video/format.rs index a7d61cebf..e52695b93 100644 --- a/devices/src/virtio/video/format.rs +++ b/devices/src/virtio/video/format.rs @@ -227,7 +227,7 @@ impl Response for FormatDesc { plane_align: Le32::from(0), num_frames: Le32::from(self.frame_formats.len() as u32), })?; - self.frame_formats.iter().map(|ff| ff.write(w)).collect() + self.frame_formats.iter().try_for_each(|ff| ff.write(w)) } } diff --git a/devices/src/virtio/video/mod.rs b/devices/src/virtio/video/mod.rs index 4a64ec1c3..b67c7adf2 100644 --- a/devices/src/virtio/video/mod.rs +++ b/devices/src/virtio/video/mod.rs @@ -10,11 +10,10 @@ use std::fmt::{self, Display}; use std::thread; -use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor}; +use base::{error, AsRawDescriptor, Error as SysError, 
Event, RawDescriptor, Tube}; use data_model::{DataInit, Le32}; use vm_memory::GuestMemory; -use crate::virtio::resource_bridge::ResourceRequestSocket; use crate::virtio::virtio_device::VirtioDevice; use crate::virtio::{self, copy_config, DescriptorError, Interrupt}; @@ -92,7 +91,7 @@ pub enum VideoDeviceType { pub struct VideoDevice { device_type: VideoDeviceType, kill_evt: Option<Event>, - resource_bridge: Option<ResourceRequestSocket>, + resource_bridge: Option<Tube>, base_features: u64, } @@ -100,7 +99,7 @@ impl VideoDevice { pub fn new( base_features: u64, device_type: VideoDeviceType, - resource_bridge: Option<ResourceRequestSocket>, + resource_bridge: Option<Tube>, ) -> VideoDevice { VideoDevice { device_type, diff --git a/devices/src/virtio/video/protocol.rs b/devices/src/virtio/video/protocol.rs index fb574a4f0..ec03f8c12 100644 --- a/devices/src/virtio/video/protocol.rs +++ b/devices/src/virtio/video/protocol.rs @@ -4,7 +4,7 @@ //! This file was generated by the following commands and modified manually. //! -//! ``` +//! ```shell //! $ bindgen virtio_video.h \ //! --whitelist-type "virtio_video.*" \ //! 
--whitelist-var "VIRTIO_VIDEO_.*" \ diff --git a/devices/src/virtio/video/response.rs b/devices/src/virtio/video/response.rs index bbc71ddf6..a32d3fc1e 100644 --- a/devices/src/virtio/video/response.rs +++ b/devices/src/virtio/video/response.rs @@ -104,7 +104,7 @@ impl Response for CmdResponse { num_descs: Le32::from(descs.len() as u32), ..Default::default() })?; - descs.iter().map(|d| d.write(w)).collect() + descs.iter().try_for_each(|d| d.write(w)) } ResourceQueue { timestamp, diff --git a/devices/src/virtio/video/worker.rs b/devices/src/virtio/video/worker.rs index 15da5971e..f7a3a3214 100644 --- a/devices/src/virtio/video/worker.rs +++ b/devices/src/virtio/video/worker.rs @@ -6,11 +6,10 @@ use std::collections::VecDeque; -use base::{error, info, Event, WaitContext}; +use base::{error, info, Event, Tube, WaitContext}; use vm_memory::GuestMemory; use crate::virtio::queue::{DescriptorChain, Queue}; -use crate::virtio::resource_bridge::ResourceRequestSocket; use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap; use crate::virtio::video::command::{QueueType, VideoCmd}; use crate::virtio::video::device::{ @@ -19,7 +18,7 @@ use crate::virtio::video::device::{ use crate::virtio::video::event::{self, EvtType, VideoEvt}; use crate::virtio::video::response::{self, Response}; use crate::virtio::video::{Error, Result}; -use crate::virtio::{Interrupt, Reader, Writer}; +use crate::virtio::{Interrupt, Reader, SignalableInterrupt, Writer}; pub struct Worker { pub interrupt: Interrupt, @@ -27,7 +26,7 @@ pub struct Worker { pub cmd_evt: Event, pub event_evt: Event, pub kill_evt: Event, - pub resource_bridge: ResourceRequestSocket, + pub resource_bridge: Tube, } /// Pair of a descriptor chain and a response to be written. 
@@ -276,8 +275,13 @@ impl Worker { (&self.cmd_evt, Token::CmdQueue), (&self.event_evt, Token::EventQueue), (&self.kill_evt, Token::Kill), - (self.interrupt.get_resample_evt(), Token::InterruptResample), ]) + .and_then(|wc| { + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + wc.add(resample_evt, Token::InterruptResample)?; + } + Ok(wc) + }) .map_err(Error::WaitContextCreationFailed)?; // Stores descriptors in which responses for asynchronous commands will be written. diff --git a/devices/src/virtio/virtio_pci_device.rs b/devices/src/virtio/virtio_pci_device.rs index e1d592991..b0f63d1d1 100644 --- a/devices/src/virtio/virtio_pci_device.rs +++ b/devices/src/virtio/virtio_pci_device.rs @@ -6,7 +6,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use sync::Mutex; -use base::{warn, AsRawDescriptor, Event, RawDescriptor, Result}; +use base::{warn, AsRawDescriptor, Event, RawDescriptor, Result, Tube}; use data_model::{DataInit, Le32}; use hypervisor::Datamatch; use libc::ERANGE; @@ -19,7 +19,6 @@ use crate::pci::{ PciClassCode, PciConfiguration, PciDevice, PciDeviceError, PciDisplaySubclass, PciHeaderType, PciInterruptPin, PciSubclass, }; -use vm_control::VmIrqRequestSocket; use self::virtio_pci_common_config::VirtioPciCommonConfig; @@ -193,6 +192,7 @@ const NOTIFY_OFF_MULTIPLIER: u32 = 4; // A dword per notification address. const VIRTIO_PCI_VENDOR_ID: u16 = 0x1af4; const VIRTIO_PCI_DEVICE_ID_BASE: u16 = 0x1040; // Add to device type to get device ID. 
+const VIRTIO_PCI_REVISION_ID: u8 = 1; /// Implements the /// [PCI](http://docs.oasis-open.org/virtio/virtio/v1.0/cs04/virtio-v1.0-cs04.html#x1-650001) @@ -221,7 +221,7 @@ impl VirtioPciDevice { pub fn new( mem: GuestMemory, device: Box<dyn VirtioDevice>, - msi_device_socket: VmIrqRequestSocket, + msi_device_tube: Tube, ) -> Result<Self> { let mut queue_evts = Vec::new(); for _ in device.queue_max_sizes() { @@ -250,7 +250,7 @@ impl VirtioPciDevice { // One MSI-X vector per queue plus one for configuration changes. let msix_num = u16::try_from(num_queues + 1).map_err(|_| base::Error::new(ERANGE))?; - let msix_config = Arc::new(Mutex::new(MsixConfig::new(msix_num, msi_device_socket))); + let msix_config = Arc::new(Mutex::new(MsixConfig::new(msix_num, msi_device_tube))); let config_regs = PciConfiguration::new( VIRTIO_PCI_VENDOR_ID, @@ -261,6 +261,7 @@ impl VirtioPciDevice { PciHeaderType::Device, VIRTIO_PCI_VENDOR_ID, pci_device_id, + VIRTIO_PCI_REVISION_ID, ); Ok(VirtioPciDevice { @@ -389,7 +390,7 @@ impl VirtioPciDevice { impl PciDevice for VirtioPciDevice { fn debug_label(&self) -> String { - format!("virtio-pci ({})", self.device.debug_label()) + format!("pci{}", self.device.debug_label()) } fn allocate_address( diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index 8359b442d..887ca2d19 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -28,7 +28,6 @@ //! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill //! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket. 
-use std::cell::RefCell; use std::collections::btree_map::Entry; use std::collections::{BTreeMap as Map, BTreeSet as Set, VecDeque}; use std::convert::From; @@ -49,29 +48,29 @@ use std::time::Duration; #[cfg(feature = "minigbm")] use libc::{EBADF, EINVAL}; -use data_model::VolatileMemoryError; use data_model::*; use base::{ error, pipe, round_up_to_page_size, warn, AsRawDescriptor, Error, Event, FileFlags, FromRawDescriptor, PollToken, RawDescriptor, Result, ScmSocket, SharedMemory, SharedMemoryUnix, - WaitContext, + Tube, TubeError, WaitContext, }; #[cfg(feature = "minigbm")] use base::{ioctl_iow_nr, ioctl_with_ref}; #[cfg(feature = "gpu")] use base::{IntoRawDescriptor, SafeDescriptor}; -use msg_socket::{MsgError, MsgReceiver, MsgSender}; use vm_memory::{GuestMemory, GuestMemoryError}; #[cfg(feature = "minigbm")] use vm_control::GpuMemoryDesc; -use super::resource_bridge::*; -use super::{DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_WL}; -use vm_control::{ - MaybeOwnedDescriptor, MemSlot, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse, +use super::resource_bridge::{ + get_resource_info, BufferInfo, ResourceBridgeError, ResourceInfo, ResourceRequest, }; +use super::{ + DescriptorChain, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, Writer, TYPE_WL, +}; +use vm_control::{MemSlot, VmMemoryRequest, VmMemoryResponse}; const VIRTWL_SEND_MAX_ALLOCS: usize = 28; const VIRTIO_WL_CMD_VFD_NEW: u32 = 256; @@ -266,7 +265,7 @@ enum WlError { NewPipe(Error), SocketConnect(io::Error), SocketNonBlock(io::Error), - VmControl(MsgError), + VmControl(TubeError), VmBadResponse, CheckedOffset, ParseDesc(io::Error), @@ -333,21 +332,28 @@ impl From<VolatileMemoryError> for WlError { #[derive(Clone)] struct VmRequester { - inner: Rc<RefCell<VmMemoryControlRequestSocket>>, + inner: Rc<Tube>, } impl VmRequester { - fn new(vm_socket: VmMemoryControlRequestSocket) -> VmRequester { + fn new(vm_socket: Tube) -> VmRequester { VmRequester { - 
inner: Rc::new(RefCell::new(vm_socket)), + inner: Rc::new(vm_socket), } } - fn request(&self, request: VmMemoryRequest) -> WlResult<VmMemoryResponse> { - let mut inner = self.inner.borrow_mut(); - let vm_socket = &mut *inner; - vm_socket.send(&request).map_err(WlError::VmControl)?; - vm_socket.recv().map_err(WlError::VmControl) + fn request(&self, request: &VmMemoryRequest) -> WlResult<VmMemoryResponse> { + self.inner.send(&request).map_err(WlError::VmControl)?; + self.inner.recv().map_err(WlError::VmControl) + } + + fn register_memory(&self, shm: SharedMemory) -> WlResult<(SharedMemory, VmMemoryResponse)> { + let request = VmMemoryRequest::RegisterMemory(shm); + let response = self.request(&request)?; + match request { + VmMemoryRequest::RegisterMemory(shm) => Ok((shm, response)), + _ => unreachable!(), + } } } @@ -554,7 +560,7 @@ impl<'a> WlResp<'a> { #[derive(Default)] struct WlVfd { socket: Option<UnixStream>, - guest_shared_memory: Option<(u64 /* size */, SharedMemory)>, + guest_shared_memory: Option<SharedMemory>, remote_pipe: Option<File>, local_pipe: Option<(u32 /* flags */, File)>, slot: Option<(MemSlot, u64 /* pfn */, VmRequester)>, @@ -594,14 +600,16 @@ impl WlVfd { let vfd_shm = SharedMemory::named("virtwl_alloc", size_page_aligned).map_err(WlError::NewAlloc)?; - let register_response = vm.request(VmMemoryRequest::RegisterMemory( - MaybeOwnedDescriptor::Borrowed(vfd_shm.as_raw_descriptor()), - vfd_shm.size() as usize, - ))?; + let register_request = VmMemoryRequest::RegisterMemory(vfd_shm); + let register_response = vm.request(®ister_request)?; match register_response { VmMemoryResponse::RegisterMemory { pfn, slot } => { let mut vfd = WlVfd::default(); - vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm)); + let vfd_shm = match register_request { + VmMemoryRequest::RegisterMemory(shm) => shm, + _ => unreachable!(), + }; + vfd.guest_shared_memory = Some(vfd_shm); vfd.slot = Some((slot, pfn, vm)); Ok(vfd) } @@ -617,22 +625,22 @@ impl WlVfd { format: 
u32, ) -> WlResult<(WlVfd, GpuMemoryDesc)> { let allocate_and_register_gpu_memory_response = - vm.request(VmMemoryRequest::AllocateAndRegisterGpuMemory { + vm.request(&VmMemoryRequest::AllocateAndRegisterGpuMemory { width, height, format, })?; match allocate_and_register_gpu_memory_response { VmMemoryResponse::AllocateAndRegisterGpuMemory { - descriptor: MaybeOwnedDescriptor::Owned(file), + descriptor, pfn, slot, desc, } => { let mut vfd = WlVfd::default(); let vfd_shm = - SharedMemory::from_safe_descriptor(file).map_err(WlError::NewAlloc)?; - vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm)); + SharedMemory::from_safe_descriptor(descriptor).map_err(WlError::NewAlloc)?; + vfd.guest_shared_memory = Some(vfd_shm); vfd.slot = Some((slot, pfn, vm)); vfd.is_dmabuf = true; Ok((vfd, desc)) @@ -648,7 +656,7 @@ impl WlVfd { } match &self.guest_shared_memory { - Some((_, descriptor)) => { + Some(descriptor) => { let sync = dma_buf_sync { flags: flags as u64, }; @@ -686,21 +694,14 @@ impl WlVfd { // for how big the shared memory chunk to map into guest memory is. If seeking to the end // fails, we assume it's a socket or pipe with read/write semantics. 
match descriptor.seek(SeekFrom::End(0)) { - Ok(fd_size) => { - let size = round_up_to_page_size(fd_size as usize) as u64; - let register_response = vm.request(VmMemoryRequest::RegisterMemory( - MaybeOwnedDescriptor::Borrowed(descriptor.as_raw_descriptor()), - size as usize, - ))?; + Ok(_) => { + let shm = SharedMemory::from_file(descriptor).map_err(WlError::FromSharedMemory)?; + let (shm, register_response) = vm.register_memory(shm)?; match register_response { VmMemoryResponse::RegisterMemory { pfn, slot } => { let mut vfd = WlVfd::default(); - vfd.guest_shared_memory = Some(( - size, - SharedMemory::from_file(descriptor) - .map_err(WlError::FromSharedMemory)?, - )); + vfd.guest_shared_memory = Some(shm); vfd.slot = Some((slot, pfn, vm)); Ok(vfd) } @@ -748,16 +749,16 @@ impl WlVfd { // Size in bytes of the shared memory VFD. fn size(&self) -> Option<u64> { - self.guest_shared_memory.as_ref().map(|&(size, _)| size) + self.guest_shared_memory.as_ref().map(|shm| shm.size()) } // The FD that gets sent if this VFD is sent over a socket. fn send_descriptor(&self) -> Option<RawDescriptor> { self.guest_shared_memory .as_ref() - .map(|(_, shm)| shm.as_raw_descriptor()) - .or_else(|| self.socket.as_ref().map(|s| s.as_raw_descriptor())) - .or_else(|| self.remote_pipe.as_ref().map(|p| p.as_raw_descriptor())) + .map(|shm| shm.as_raw_descriptor()) + .or(self.socket.as_ref().map(|s| s.as_raw_descriptor())) + .or(self.remote_pipe.as_ref().map(|p| p.as_raw_descriptor())) } // The FD that is used for polling for events on this VFD. 
@@ -842,7 +843,7 @@ impl WlVfd { fn close(&mut self) -> WlResult<()> { if let Some((slot, _, vm)) = self.slot.take() { - vm.request(VmMemoryRequest::UnregisterMemory(slot))?; + vm.request(&VmMemoryRequest::UnregisterMemory(slot))?; } self.socket = None; self.remote_pipe = None; @@ -867,7 +868,7 @@ enum WlRecv { struct WlState { wayland_paths: Map<String, PathBuf>, vm: VmRequester, - resource_bridge: Option<ResourceRequestSocket>, + resource_bridge: Option<Tube>, use_transition_flags: bool, wait_ctx: WaitContext<u32>, vfds: Map<u32, WlVfd>, @@ -884,14 +885,14 @@ struct WlState { impl WlState { fn new( wayland_paths: Map<String, PathBuf>, - vm_socket: VmMemoryControlRequestSocket, + vm_tube: Tube, use_transition_flags: bool, use_send_vfd_v2: bool, - resource_bridge: Option<ResourceRequestSocket>, + resource_bridge: Option<Tube>, ) -> WlState { WlState { wayland_paths, - vm: VmRequester::new(vm_socket), + vm: VmRequester::new(vm_tube), resource_bridge, wait_ctx: WaitContext::new().expect("failed to create WaitContext"), use_transition_flags, @@ -1105,7 +1106,7 @@ impl WlState { fn get_info(&mut self, request: ResourceRequest) -> Option<File> { let sock = self.resource_bridge.as_ref().unwrap(); match get_resource_info(sock, request) { - Ok(ResourceInfo::Buffer(BufferInfo { file, planes: _ })) => Some(file), + Ok(ResourceInfo::Buffer(BufferInfo { file, .. 
})) => Some(file), Ok(ResourceInfo::Fence { file }) => Some(file), Err(ResourceBridgeError::InvalidResource(req)) => { warn!("attempt to send non-existent gpu resource {}", req); @@ -1476,10 +1477,10 @@ impl Worker { in_queue: Queue, out_queue: Queue, wayland_paths: Map<String, PathBuf>, - vm_socket: VmMemoryControlRequestSocket, + vm_tube: Tube, use_transition_flags: bool, use_send_vfd_v2: bool, - resource_bridge: Option<ResourceRequestSocket>, + resource_bridge: Option<Tube>, ) -> Worker { Worker { interrupt, @@ -1488,7 +1489,7 @@ impl Worker { out_queue, state: WlState::new( wayland_paths, - vm_socket, + vm_tube, use_transition_flags, use_send_vfd_v2, resource_bridge, @@ -1515,7 +1516,6 @@ impl Worker { (&out_queue_evt, Token::OutQueue), (&kill_evt, Token::Kill), (&self.state.wait_ctx, Token::State), - (self.interrupt.get_resample_evt(), Token::InterruptResample), ]) { Ok(pc) => pc, Err(e) => { @@ -1523,6 +1523,15 @@ impl Worker { return; } }; + if let Some(resample_evt) = self.interrupt.get_resample_evt() { + if wait_ctx + .add(resample_evt, Token::InterruptResample) + .is_err() + { + error!("failed adding resample event to WaitContext."); + return; + } + } 'wait: loop { let mut signal_used_in = false; @@ -1660,8 +1669,8 @@ pub struct Wl { kill_evt: Option<Event>, worker_thread: Option<thread::JoinHandle<()>>, wayland_paths: Map<String, PathBuf>, - vm_socket: Option<VmMemoryControlRequestSocket>, - resource_bridge: Option<ResourceRequestSocket>, + vm_socket: Option<Tube>, + resource_bridge: Option<Tube>, use_transition_flags: bool, use_send_vfd_v2: bool, base_features: u64, @@ -1671,14 +1680,14 @@ impl Wl { pub fn new( base_features: u64, wayland_paths: Map<String, PathBuf>, - vm_socket: VmMemoryControlRequestSocket, - resource_bridge: Option<ResourceRequestSocket>, + vm_tube: Tube, + resource_bridge: Option<Tube>, ) -> Result<Wl> { Ok(Wl { kill_evt: None, worker_thread: None, wayland_paths, - vm_socket: Some(vm_socket), + vm_socket: Some(vm_tube), 
resource_bridge, use_transition_flags: false, use_send_vfd_v2: false, diff --git a/disk/src/composite.rs b/disk/src/composite.rs index a23eca0ca..efa5e1de8 100644 --- a/disk/src/composite.rs +++ b/disk/src/composite.rs @@ -194,7 +194,7 @@ impl CompositeDiskFile { } } - fn disk_at_offset<'a>(&'a mut self, offset: u64) -> io::Result<&'a mut ComponentDiskPart> { + fn disk_at_offset(&mut self, offset: u64) -> io::Result<&mut ComponentDiskPart> { self.component_disks .iter_mut() .find(|disk| disk.range().contains(&offset)) @@ -342,13 +342,14 @@ impl AsRawDescriptors for CompositeDiskFile { #[cfg(test)] mod tests { use super::*; - use base::{AsRawDescriptor, SharedMemory}; + use base::AsRawDescriptor; use data_model::VolatileMemory; + use tempfile::tempfile; #[test] fn block_duplicate_offset_disks() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); let disk_part1 = ComponentDiskPart { file: Box::new(file1), offset: 0, @@ -364,8 +365,8 @@ mod tests { #[test] fn get_len() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); let disk_part1 = ComponentDiskPart { file: Box::new(file1), offset: 0, @@ -383,7 +384,7 @@ mod tests { #[test] fn single_file_passthrough() { - let file: File = SharedMemory::new(None).unwrap().into(); + let file = tempfile().unwrap(); let disk_part = ComponentDiskPart { file: Box::new(file), offset: 0, @@ -405,9 +406,9 @@ mod tests { #[test] fn triple_file_fds() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); - let file3: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); + let file3 = tempfile().unwrap(); let mut in_fds = vec![ 
file1.as_raw_descriptor(), file2.as_raw_descriptor(), @@ -437,9 +438,9 @@ mod tests { #[test] fn triple_file_passthrough() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); - let file3: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); + let file3 = tempfile().unwrap(); let disk_part1 = ComponentDiskPart { file: Box::new(file1), offset: 0, @@ -467,14 +468,14 @@ mod tests { composite .read_exact_at_volatile(output_volatile_memory.get_slice(0, 200).unwrap(), 50) .unwrap(); - assert!(input_memory.into_iter().eq(output_memory.into_iter())); + assert!(input_memory.iter().eq(output_memory.iter())); } #[test] fn triple_file_punch_hole() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); - let file3: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); + let file3 = tempfile().unwrap(); let disk_part1 = ComponentDiskPart { file: Box::new(file1), offset: 0, @@ -507,14 +508,14 @@ mod tests { for i in 50..250 { input_memory[i] = 0; } - assert!(input_memory.into_iter().eq(output_memory.into_iter())); + assert!(input_memory.iter().eq(output_memory.iter())); } #[test] fn triple_file_write_zeroes() { - let file1: File = SharedMemory::new(None).unwrap().into(); - let file2: File = SharedMemory::new(None).unwrap().into(); - let file3: File = SharedMemory::new(None).unwrap().into(); + let file1 = tempfile().unwrap(); + let file2 = tempfile().unwrap(); + let file3 = tempfile().unwrap(); let disk_part1 = ComponentDiskPart { file: Box::new(file1), offset: 0, @@ -558,6 +559,6 @@ mod tests { i, input_memory[i], output_memory[i] ); } - assert!(input_memory.into_iter().eq(output_memory.into_iter())); + assert!(input_memory.iter().eq(output_memory.iter())); } } diff --git a/disk/src/disk.rs b/disk/src/disk.rs index 
9f934135a..bd1945d30 100644 --- a/disk/src/disk.rs +++ b/disk/src/disk.rs @@ -264,36 +264,45 @@ pub fn convert(src_file: File, dst_file: File, dst_type: ImageType) -> Result<() } } -/// Detect the type of an image file by checking for a valid qcow2 header. +/// Detect the type of an image file by checking for a valid header of the supported formats. pub fn detect_image_type(file: &File) -> Result<ImageType> { let mut f = file; + let disk_size = f.get_len().map_err(Error::SeekingFile)?; let orig_seek = f.seek(SeekFrom::Current(0)).map_err(Error::SeekingFile)?; f.seek(SeekFrom::Start(0)).map_err(Error::SeekingFile)?; - let mut magic = [0u8; 4]; - f.read_exact(&mut magic).map_err(Error::ReadingHeader)?; - let magic = u32::from_be_bytes(magic); + + // Try to read the disk in a nicely-aligned block size unless the whole file is smaller. + const MAGIC_BLOCK_SIZE: usize = 4096; + let mut magic = [0u8; MAGIC_BLOCK_SIZE]; + let magic_read_len = if disk_size > MAGIC_BLOCK_SIZE as u64 { + MAGIC_BLOCK_SIZE + } else { + // This cast is safe since we know disk_size is less than MAGIC_BLOCK_SIZE (4096) and + // therefore is representable in usize. 
+ disk_size as usize + }; + + f.read_exact(&mut magic[0..magic_read_len]) + .map_err(Error::ReadingHeader)?; + f.seek(SeekFrom::Start(orig_seek)) + .map_err(Error::SeekingFile)?; + #[cfg(feature = "composite-disk")] - { - f.seek(SeekFrom::Start(0)).map_err(Error::SeekingFile)?; - let mut cdisk_magic = [0u8; CDISK_MAGIC_LEN]; - f.read_exact(&mut cdisk_magic[..]) - .map_err(Error::ReadingHeader)?; + if let Some(cdisk_magic) = magic.get(0..CDISK_MAGIC_LEN) { if cdisk_magic == CDISK_MAGIC.as_bytes() { - f.seek(SeekFrom::Start(orig_seek)) - .map_err(Error::SeekingFile)?; return Ok(ImageType::CompositeDisk); } } - let image_type = if magic == QCOW_MAGIC { - ImageType::Qcow2 - } else if magic == SPARSE_HEADER_MAGIC.to_be() { - ImageType::AndroidSparse - } else { - ImageType::Raw - }; - f.seek(SeekFrom::Start(orig_seek)) - .map_err(Error::SeekingFile)?; - Ok(image_type) + + if let Some(magic4) = magic.get(0..4) { + if magic4 == QCOW_MAGIC.to_be_bytes() { + return Ok(ImageType::Qcow2); + } else if magic4 == SPARSE_HEADER_MAGIC.to_le_bytes() { + return Ok(ImageType::AndroidSparse); + } + } + + Ok(ImageType::Raw) } /// Check if the image file type can be used for async disk access. @@ -530,4 +539,62 @@ mod tests { let ex = Executor::new().unwrap(); ex.run_until(write_zeros_async(&ex)).unwrap(); } + + #[test] + fn detect_image_type_raw() { + let mut t = tempfile::tempfile().unwrap(); + // Fill the first block of the file with "random" data. + let buf = "ABCD".as_bytes().repeat(1024); + t.write_all(&buf).unwrap(); + let image_type = detect_image_type(&t).expect("failed to detect image type"); + assert_eq!(image_type, ImageType::Raw); + } + + #[test] + fn detect_image_type_qcow2() { + let mut t = tempfile::tempfile().unwrap(); + // Write the qcow2 magic signature. The rest of the header is not filled in, so if + // detect_image_type is ever updated to validate more of the header, this test would need + // to be updated. 
+ let buf: &[u8] = &[0x51, 0x46, 0x49, 0xfb]; + t.write_all(&buf).unwrap(); + let image_type = detect_image_type(&t).expect("failed to detect image type"); + assert_eq!(image_type, ImageType::Qcow2); + } + + #[test] + fn detect_image_type_android_sparse() { + let mut t = tempfile::tempfile().unwrap(); + // Write the Android sparse magic signature. The rest of the header is not filled in, so if + // detect_image_type is ever updated to validate more of the header, this test would need + // to be updated. + let buf: &[u8] = &[0x3a, 0xff, 0x26, 0xed]; + t.write_all(&buf).unwrap(); + let image_type = detect_image_type(&t).expect("failed to detect image type"); + assert_eq!(image_type, ImageType::AndroidSparse); + } + + #[test] + #[cfg(feature = "composite-disk")] + fn detect_image_type_composite() { + let mut t = tempfile::tempfile().unwrap(); + // Write the composite disk magic signature. The rest of the header is not filled in, so if + // detect_image_type is ever updated to validate more of the header, this test would need + // to be updated. + let buf = "composite_disk\x1d".as_bytes(); + t.write_all(&buf).unwrap(); + let image_type = detect_image_type(&t).expect("failed to detect image type"); + assert_eq!(image_type, ImageType::CompositeDisk); + } + + #[test] + fn detect_image_type_small_file() { + let mut t = tempfile::tempfile().unwrap(); + // Write a file smaller than the four-byte qcow2/sparse magic to ensure the small file logic + // works correctly and handles it as a raw file. 
+ let buf: &[u8] = &[0xAA, 0xBB]; + t.write_all(&buf).unwrap(); + let image_type = detect_image_type(&t).expect("failed to detect image type"); + assert_eq!(image_type, ImageType::Raw); + } } diff --git a/disk/src/qcow/mod.rs b/disk/src/qcow/mod.rs index a4a52349c..3a3ef9764 100644 --- a/disk/src/qcow/mod.rs +++ b/disk/src/qcow/mod.rs @@ -433,7 +433,7 @@ impl QcowFile { } let cluster_bits: u32 = header.cluster_bits; - if cluster_bits < MIN_CLUSTER_BITS || cluster_bits > MAX_CLUSTER_BITS { + if !(MIN_CLUSTER_BITS..=MAX_CLUSTER_BITS).contains(&cluster_bits) { return Err(Error::InvalidClusterSize); } let cluster_size = 0x01u64 << cluster_bits; diff --git a/docs/architecture.md b/docs/architecture.md index 5ff0567b2..064826c68 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -75,8 +75,8 @@ Most threads in crosvm will have a wait loop using a `PollContext`, which is a w Note that the limitations of `PollContext` are the same as the limitations of `epoll`. The same FD can not be inserted more than once, and the FD will be automatically removed if the process runs out of references to that FD. A `dup`/`fork` call will increment that reference count, so closing the original FD will not actually remove it from the `PollContext`. It is possible to receive tokens from `PollContext` for an FD that was closed because of a race condition in which an event was registered in the background before the `close` happened. Best practice is to remove an FD before closing it so that events associated with it can be reliably eliminated. -### MsgSocket +### `serde` with Descriptors. -Using raw sockets and pipes to communicate is very inconvenient for rich data types. To help make this easier and less error prone, crosvm has the `msg_socket` crate. 
Included is a trait for messages encodable on a Unix socket (`MsgOnSocket`), a set of traits for sending and receiving (`MsgSender`/`MsgReceiver`), and implementations of those traits over `UnixSeqpacket` (`MsgSocket`/`Sender`/`Receiver`). To make implementing `MsgOnSocket` very easy, a custom derive for that trait can be utilized with `#[derive(MsgOnSocket)]`. The custom derive will work for enums and structs with nested data, primitive types, and anything that implements `AsRawFd`. However, structures with no fixed upper limit in size, such as `Vec` or `BTreeMap`, are not supported. +Using raw sockets and pipes to communicate is very inconvenient for rich data types. To help make this easier and less error prone, crosvm uses the `serde` crate. To allow transmitting types with embedded descriptors (FDs on Linux or HANDLEs on Windows), a module is provided for sending and receiving descriptors alongside the plain old bytes that serde consumes. [minijail]: https://android.googlesource.com/platform/external/minijail diff --git a/fuse/src/server.rs b/fuse/src/server.rs index 0c62896b6..f28e9488a 100644 --- a/fuse/src/server.rs +++ b/fuse/src/server.rs @@ -1298,6 +1298,7 @@ impl<F: FileSystem + Sync> Server<F> { } } + #[allow(clippy::unnecessary_wraps)] fn interrupt(&self, _in_header: InHeader) -> Result<usize> { Ok(0) } @@ -1310,6 +1311,7 @@ impl<F: FileSystem + Sync> Server<F> { } } + #[allow(clippy::unnecessary_wraps)] fn destroy(&self) -> Result<usize> { // No reply to this function. 
self.fs.destroy(); diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index ab2e06558..d4c7a1e26 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -18,10 +18,6 @@ tempfile = { path = "../tempfile" } usb_util = { path = "../usb_util" } vm_memory = { path = "../vm_memory" } -# Prevent this from interfering with workspaces -[workspace] -members = ["."] - [[bin]] name = "crosvm_block_fuzzer" path = "block_fuzzer.rs" @@ -46,5 +42,3 @@ path = "virtqueue_fuzzer.rs" name = "crosvm_zimage_fuzzer" path = "zimage_fuzzer.rs" -[patch.crates-io] -base = { path = "../base" } diff --git a/gpu_display/build.rs b/gpu_display/build.rs index 8f134f942..5cdba2f7e 100644 --- a/gpu_display/build.rs +++ b/gpu_display/build.rs @@ -62,13 +62,13 @@ fn compile_protocol<P: AsRef<Path>>(name: &str, out: P) -> PathBuf { .arg(&in_protocol) .arg(&out_code) .output() - .unwrap(); + .expect("wayland-scanner code failed"); Command::new("wayland-scanner") .arg("client-header") .arg(&in_protocol) .arg(&out_header) .output() - .unwrap(); + .expect("wayland-scanner client-header failed"); out_code } diff --git a/gpu_display/src/gpu_display_stub.rs b/gpu_display/src/gpu_display_stub.rs index 63cc81104..11b96c59c 100644 --- a/gpu_display/src/gpu_display_stub.rs +++ b/gpu_display/src/gpu_display_stub.rs @@ -45,12 +45,12 @@ struct Surface { } impl Surface { - fn create(width: u32, height: u32) -> Result<Surface, GpuDisplayError> { - Ok(Surface { + fn create(width: u32, height: u32) -> Surface { + Surface { width, height, buffer: None, - }) + } } /// Gets the buffer at buffer_index, allocating it if necessary. 
@@ -103,14 +103,14 @@ impl SurfacesHelper { } } - fn create_surface(&mut self, width: u32, height: u32) -> Result<u32, GpuDisplayError> { - let new_surface = Surface::create(width, height)?; + fn create_surface(&mut self, width: u32, height: u32) -> u32 { + let new_surface = Surface::create(width, height); let new_surface_id = self.next_surface_id; self.surfaces.insert(new_surface_id, new_surface); self.next_surface_id = SurfaceId::new(self.next_surface_id.get() + 1).unwrap(); - Ok(new_surface_id.get()) + new_surface_id.get() } fn get_surface(&mut self, surface_id: u32) -> Option<&mut Surface> { @@ -157,7 +157,7 @@ impl DisplayT for DisplayStub { if parent_surface_id.is_some() { return Err(GpuDisplayError::Unsupported); } - self.surfaces.create_surface(width, height) + Ok(self.surfaces.create_surface(width, height)) } fn release_surface(&mut self, surface_id: u32) { diff --git a/gpu_display/src/gpu_display_wl.rs b/gpu_display/src/gpu_display_wl.rs index d6365ac99..d242bbbd1 100644 --- a/gpu_display/src/gpu_display_wl.rs +++ b/gpu_display/src/gpu_display_wl.rs @@ -209,7 +209,7 @@ impl DisplayT for DisplayWl { let buffer_shm = SharedMemory::named("GpuDisplaySurface", buffer_size as u64) .map_err(GpuDisplayError::CreateShm)?; let buffer_mem = MemoryMappingBuilder::new(buffer_size) - .from_descriptor(&buffer_shm) + .from_shared_memory(&buffer_shm) .build() .unwrap(); diff --git a/gpu_display/src/gpu_display_x.rs b/gpu_display/src/gpu_display_x.rs index 5d6a9b609..16bdd1096 100644 --- a/gpu_display/src/gpu_display_x.rs +++ b/gpu_display/src/gpu_display_x.rs @@ -243,7 +243,7 @@ impl Surface { visual: *mut xlib::Visual, width: u32, height: u32, - ) -> Result<Surface, GpuDisplayError> { + ) -> Surface { let keycode_translator = KeycodeTranslator::new(KeycodeTypes::XkbScancode); unsafe { let depth = xlib::XDefaultDepthOfScreen(screen.as_ptr()) as u32; @@ -305,7 +305,7 @@ impl Surface { // Flush everything so that the window is visible immediately. 
display.flush(); - Ok(Surface { + Surface { display, visual, depth, @@ -320,7 +320,7 @@ impl Surface { buffer_completion_type, delete_window_atom, close_requested: false, - }) + } } } @@ -367,8 +367,8 @@ impl Surface { // The touch event *must* be first per the Linux input subsystem's guidance. let events = &[ virtio_input_event::touch(pressed), - virtio_input_event::absolute_x(max(0, button_event.x) as u32), - virtio_input_event::absolute_y(max(0, button_event.y) as u32), + virtio_input_event::absolute_x(max(0, button_event.x)), + virtio_input_event::absolute_y(max(0, button_event.y)), ]; self.dispatch_to_event_devices(events, EventDeviceKind::Touchscreen); } @@ -377,8 +377,8 @@ impl Surface { if motion.state & xlib::Button1Mask != 0 { let events = &[ virtio_input_event::touch(true), - virtio_input_event::absolute_x(max(0, motion.x) as u32), - virtio_input_event::absolute_y(max(0, motion.y) as u32), + virtio_input_event::absolute_x(max(0, motion.x)), + virtio_input_event::absolute_y(max(0, motion.y)), ]; self.dispatch_to_event_devices(events, EventDeviceKind::Touchscreen); } @@ -730,7 +730,7 @@ impl DisplayT for DisplayX { self.visual, width, height, - )?; + ); let new_surface_id = self.next_id; self.surfaces.insert(new_surface_id, new_surface); self.next_id = ObjectId::new(self.next_id.get() + 1).unwrap(); diff --git a/gpu_display/src/lib.rs b/gpu_display/src/lib.rs index 0ffae8792..1f8ea0f13 100644 --- a/gpu_display/src/lib.rs +++ b/gpu_display/src/lib.rs @@ -15,6 +15,7 @@ mod gpu_display_stub; mod gpu_display_wl; #[cfg(feature = "x")] mod gpu_display_x; +#[cfg(feature = "x")] mod keycode_converter; pub use event_device::{EventDevice, EventDeviceKind}; diff --git a/hypervisor/Cargo.toml b/hypervisor/Cargo.toml index aef0ce4a8..70c0ec41c 100644 --- a/hypervisor/Cargo.toml +++ b/hypervisor/Cargo.toml @@ -12,7 +12,7 @@ enumn = { path = "../enumn" } kvm = { path = "../kvm" } kvm_sys = { path = "../kvm_sys" } libc = "*" -msg_socket = { path = "../msg_socket" } +serde 
= { version = "1", features = [ "derive" ] } sync = { path = "../sync" } base = { path = "../base" } vm_memory = { path = "../vm_memory" } diff --git a/hypervisor/src/kvm/mod.rs b/hypervisor/src/kvm/mod.rs index 24955c287..78708beca 100644 --- a/hypervisor/src/kvm/mod.rs +++ b/hypervisor/src/kvm/mod.rs @@ -16,9 +16,11 @@ use std::cell::RefCell; use std::cmp::{min, Reverse}; use std::collections::{BTreeMap, BinaryHeap}; use std::convert::TryFrom; +use std::ffi::CString; use std::mem::{size_of, ManuallyDrop}; -use std::os::raw::{c_char, c_int, c_ulong, c_void}; -use std::os::unix::io::AsRawFd; +use std::os::raw::{c_int, c_ulong, c_void}; +use std::os::unix::{io::AsRawFd, prelude::OsStrExt}; +use std::path::{Path, PathBuf}; use std::ptr::copy_nonoverlapping; use std::sync::atomic::AtomicU64; use std::sync::Arc; @@ -30,8 +32,8 @@ use libc::{ use base::{ block_signal, errno_result, error, ioctl, ioctl_with_mut_ref, ioctl_with_ref, ioctl_with_val, pagesize, signal, unblock_signal, AsRawDescriptor, Error, Event, FromRawDescriptor, - MappedRegion, MemoryMapping, MemoryMappingBuilder, MmapError, Protection, RawDescriptor, - Result, SafeDescriptor, + MappedRegion, MemoryMapping, MemoryMappingBuilder, MemoryMappingBuilderUnix, MmapError, + Protection, RawDescriptor, Result, SafeDescriptor, }; use data_model::vec_with_array_field; use kvm_sys::*; @@ -94,11 +96,10 @@ pub struct Kvm { type KvmCap = kvm::Cap; impl Kvm { - /// Opens `/dev/kvm/` and returns a Kvm object on success. - pub fn new() -> Result<Kvm> { - // Open calls are safe because we give a constant nul-terminated string and verify the - // result. - let ret = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, O_RDWR | O_CLOEXEC) }; + pub fn new_with_path(device_path: &Path) -> Result<Kvm> { + // Open calls are safe because we give a nul-terminated string and verify the result. 
+ let c_path = CString::new(device_path.as_os_str().as_bytes()).unwrap(); + let ret = unsafe { open(c_path.as_ptr(), O_RDWR | O_CLOEXEC) }; if ret < 0 { return errno_result(); } @@ -108,6 +109,11 @@ impl Kvm { }) } + /// Opens `/dev/kvm/` and returns a Kvm object on success. + pub fn new() -> Result<Kvm> { + Kvm::new_with_path(&PathBuf::from("/dev/kvm")) + } + /// Gets the size of the mmap required to use vcpu's `kvm_run` structure. pub fn get_vcpu_mmap_size(&self) -> Result<usize> { // Safe because we know that our file is a KVM fd and we verify the return result. @@ -166,7 +172,7 @@ impl KvmVm { } // Safe because we verify that ret is valid and we own the fd. let vm_descriptor = unsafe { SafeDescriptor::from_raw_descriptor(ret) }; - guest_mem.with_regions(|index, guest_addr, size, host_addr, _| { + guest_mem.with_regions(|index, guest_addr, size, host_addr, _, _| { unsafe { // Safe because the guest regions are guaranteed not to overlap. set_user_memory_region( @@ -448,7 +454,11 @@ impl Vm for KvmVm { read_only: bool, log_dirty_pages: bool, ) -> Result<MemSlot> { - let size = mem.size() as u64; + let pgsz = pagesize() as u64; + // KVM require to set the user memory region with page size aligned size. Safe to extend + // the mem.size() to be page size aligned because the mmap will round up the size to be + // page size aligned if it is not. + let size = (mem.size() as u64 + pgsz - 1) / pgsz * pgsz; let end_addr = guest_addr .checked_add(size) .ok_or_else(|| Error::new(EOVERFLOW))?; @@ -691,11 +701,15 @@ impl Vcpu for KvmVcpu { // AcqRel ordering is sufficient to ensure only one thread gets to set its fingerprint to // this Vcpu and subsequent `run` calls will see the fingerprint. 
- if self.vcpu_run_handle_fingerprint.compare_and_swap( - 0, - vcpu_run_handle.fingerprint().as_u64(), - std::sync::atomic::Ordering::AcqRel, - ) != 0 + if self + .vcpu_run_handle_fingerprint + .compare_exchange( + 0, + vcpu_run_handle.fingerprint().as_u64(), + std::sync::atomic::Ordering::AcqRel, + std::sync::atomic::Ordering::Acquire, + ) + .is_err() { return Err(Error::new(EBUSY)); } @@ -867,7 +881,7 @@ impl Vcpu for KvmVcpu { // The pointer is page aligned so casting to a different type is well defined, hence the clippy // allow attribute. fn run(&self, run_handle: &VcpuRunHandle) -> Result<VcpuExit> { - // Acquire is used to ensure this check is ordered after the `compare_and_swap` in `run`. + // Acquire is used to ensure this check is ordered after the `compare_exchange` in `run`. if self .vcpu_run_handle_fingerprint .load(std::sync::atomic::Ordering::Acquire) diff --git a/hypervisor/src/kvm/x86_64.rs b/hypervisor/src/kvm/x86_64.rs index b43a02d78..6cfd394a6 100644 --- a/hypervisor/src/kvm/x86_64.rs +++ b/hypervisor/src/kvm/x86_64.rs @@ -20,7 +20,6 @@ use crate::{ ClockState, CpuId, CpuIdEntry, DebugRegs, DescriptorTable, DeviceKind, Fpu, HypervisorX86_64, IoapicRedirectionTableEntry, IoapicState, IrqSourceChip, LapicState, PicSelect, PicState, PitChannelState, PitState, Register, Regs, Segment, Sregs, VcpuX86_64, VmCap, VmX86_64, - NUM_IOAPIC_PINS, }; type KvmCpuId = kvm::CpuId; @@ -280,12 +279,12 @@ impl KvmVm { } /// Enable support for split-irqchip. 
- pub fn enable_split_irqchip(&self) -> Result<()> { + pub fn enable_split_irqchip(&self, ioapic_pins: usize) -> Result<()> { let mut cap = kvm_enable_cap { cap: KVM_CAP_SPLIT_IRQCHIP, ..Default::default() }; - cap.args[0] = NUM_IOAPIC_PINS as u64; + cap.args[0] = ioapic_pins as u64; // safe becuase we allocated the struct and we know the kernel will read // exactly the size of the struct let ret = unsafe { ioctl_with_ref(self, KVM_ENABLE_CAP(), &cap) }; diff --git a/hypervisor/src/lib.rs b/hypervisor/src/lib.rs index cb534114d..ac3642b79 100644 --- a/hypervisor/src/lib.rs +++ b/hypervisor/src/lib.rs @@ -13,8 +13,9 @@ pub mod x86_64; use std::os::raw::c_int; use std::os::unix::io::AsRawFd; -use base::{Event, MappedRegion, Protection, RawDescriptor, Result, SafeDescriptor}; -use msg_socket::MsgOnSocket; +use serde::{Deserialize, Serialize}; + +use base::{Event, MappedRegion, Protection, Result, SafeDescriptor}; use vm_memory::{GuestAddress, GuestMemory}; #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] @@ -281,7 +282,7 @@ pub trait Vcpu: downcast_rs::DowncastSync { downcast_rs::impl_downcast!(sync Vcpu); /// An address either in programmable I/O space or in memory mapped I/O space. -#[derive(Copy, Clone, Debug, MsgOnSocket, PartialEq, Eq, std::hash::Hash)] +#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, std::hash::Hash)] pub enum IoEventAddress { Pio(u64), Mmio(u64), diff --git a/hypervisor/src/x86_64.rs b/hypervisor/src/x86_64.rs index ae4cb4433..b880689b2 100644 --- a/hypervisor/src/x86_64.rs +++ b/hypervisor/src/x86_64.rs @@ -2,10 +2,12 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. 
-use base::{error, RawDescriptor, Result}; +use serde::{Deserialize, Serialize}; + +use base::{error, Result}; use bit_field::*; use downcast_rs::impl_downcast; -use msg_socket::MsgOnSocket; + use vm_memory::GuestAddress; use crate::{Hypervisor, IrqRoute, IrqSource, IrqSourceChip, Vcpu, Vm}; @@ -566,7 +568,7 @@ pub struct DebugRegs { } /// State of one VCPU register. Currently used for MSRs and XCRs. -#[derive(Debug, Default, Copy, Clone, MsgOnSocket)] +#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)] pub struct Register { pub id: u32, pub value: u64, diff --git a/integration_tests/guest_under_test/Dockerfile b/integration_tests/guest_under_test/Dockerfile index aa0523df5..d0813b257 100644 --- a/integration_tests/guest_under_test/Dockerfile +++ b/integration_tests/guest_under_test/Dockerfile @@ -1,8 +1,8 @@ # Copyright 2020 The Chromium OS Authors. All rights reserved. # Use of this source code is governed by a BSD-style license that can be # found in the LICENSE file. - -FROM amd64/alpine:3.12 +ARG ARCH +FROM ${ARCH}/alpine:3.12 RUN apk add --no-cache pciutils diff --git a/integration_tests/guest_under_test/Makefile b/integration_tests/guest_under_test/Makefile index fa34edb4f..50ca767bd 100644 --- a/integration_tests/guest_under_test/Makefile +++ b/integration_tests/guest_under_test/Makefile @@ -8,24 +8,40 @@ # target/guest_under_test/bzImage # target/guest_under_test/rootfs +ARCH ?= $(shell arch) +ifeq ($(ARCH), x86_64) + KERNEL_ARCH=x86 + KERNEL_CONFIG=arch/x86/configs/chromiumos-container-vm-x86_64_defconfig + KERNEL_BINARY=bzImage + DOCKER_ARCH=amd64 + CROSS_COMPILE= + RUSTFLAGS= +else ifeq ($(ARCH), aarch64) + KERNEL_ARCH=arm64 + KERNEL_CONFIG=arch/arm64/configs/chromiumos-container-vm-arm64_defconfig + KERNEL_BINARY=Image + DOCKER_ARCH=arm64v8 + CROSS_COMPILE=aarch64-linux-gnu- + RUSTFLAGS="-Clinker=aarch64-linux-gnu-ld" +else + $(error Only x86_64 or aarch64 are supported) +endif + +# Build against the musl toolchain, which will produce a 
statically linked, +# portable binary that we can run on the alpine linux guest without needing +# libc at runtime +RUST_TARGET ?= $(ARCH)-unknown-linux-musl + # We are building everything in target/guest_under_test CARGO_TARGET ?= $(shell cargo metadata --no-deps --format-version 1 | \ jq -r ".target_directory") -TARGET ?= $(CARGO_TARGET)/guest_under_test +TARGET ?= $(CARGO_TARGET)/guest_under_test/$(ARCH) $(shell mkdir -p $(TARGET)) -# Currently only x86_64 is tested and supported. -ARCH = $(shell arch) - # Parameteters for building the kernel locally KERNEL_REPO ?= https://chromium.googlesource.com/chromiumos/third_party/kernel KERNEL_BRANCH ?= chromeos-4.19 -# Build against the musl toolchain, which will produce a statically linked, -# portable binary that we can run on the alpine linux guest without needing -# libc at runtime -RUST_TARGET ?= $(ARCH)-unknown-linux-musl - ################################################################################ # Main targets @@ -42,8 +58,8 @@ dockerfile := $(shell pwd)/Dockerfile # Build rootfs from Dockerfile and export into squashfs $(TARGET)/rootfs: $(TARGET)/rootfs-build/delegate # Build image from Dockerfile - cd $(TARGET)/rootfs-build && docker build -t crosvm_integration_test . \ - -f $(dockerfile) + docker build -t crosvm_integration_test $(TARGET)/rootfs-build \ + -f $(dockerfile) --build-arg ARCH=$(DOCKER_ARCH) # Create container and export into squashfs, and don't forget to clean up # the container afterwards. 
@@ -55,21 +71,19 @@ $(TARGET)/rootfs: $(TARGET)/rootfs-build/delegate # Build and copy delegate binary into rootfs build directory $(TARGET)/rootfs-build/delegate: delegate.rs rustup target add $(RUST_TARGET) - rustc --edition=2018 delegate.rs --out-dir $(@D) --target $(RUST_TARGET) + rustc --edition=2018 delegate.rs --out-dir $(@D) \ + $(RUSTFLAGS) --target $(RUST_TARGET) ################################################################################ # Build kernel -ifeq ($(ARCH), x86_64) - KERNEL_CONFIG ?= arch/x86/configs/chromiumos-container-vm-x86_64_defconfig -else - $(error Only x86_64 is supported) -endif - $(TARGET)/bzImage: $(TARGET)/kernel-source $(TARGET)/kernel-build cd $(TARGET)/kernel-source && \ - yes "" | make O=$(TARGET)/kernel-build -j$(shell nproc) bzImage - cp $(TARGET)/kernel-build/arch/x86/boot/bzImage $@ + yes "" | make \ + O=$(TARGET)/kernel-build \ + ARCH=$(KERNEL_ARCH) CROSS_COMPILE=$(CROSS_COMPILE) \ + -j$(shell nproc) $(KERNEL_BINARY) + cp $(TARGET)/kernel-build/arch/${KERNEL_ARCH}/boot/$(KERNEL_BINARY) $@ $(TARGET)/kernel-build: $(TARGET)/kernel-source mkdir -p $@ diff --git a/integration_tests/guest_under_test/PREBUILT_VERSION b/integration_tests/guest_under_test/PREBUILT_VERSION index 26dc4f9ed..75d30fb53 100644 --- a/integration_tests/guest_under_test/PREBUILT_VERSION +++ b/integration_tests/guest_under_test/PREBUILT_VERSION @@ -1 +1 @@ -r0000 +r0001 diff --git a/integration_tests/guest_under_test/upload_prebuilts.sh b/integration_tests/guest_under_test/upload_prebuilts.sh index 4b362e2d1..1383badc5 100755 --- a/integration_tests/guest_under_test/upload_prebuilts.sh +++ b/integration_tests/guest_under_test/upload_prebuilts.sh @@ -14,31 +14,36 @@ set -e cd "${0%/*}" readonly PREBUILT_VERSION="$(cat ./PREBUILT_VERSION)" - -# Cloud storage files readonly GS_BUCKET="gs://chromeos-localmirror/distfiles" readonly GS_PREFIX="${GS_BUCKET}/crosvm-testing" -readonly REMOTE_BZIMAGE="${GS_PREFIX}-bzimage-$(arch)-${PREBUILT_VERSION}" -readonly 
REMOTE_ROOTFS="${GS_PREFIX}-rootfs-$(arch)-${PREBUILT_VERSION}" - -# Local files -CARGO_TARGET=$(cargo metadata --no-deps --format-version 1 | - jq -r ".target_directory") -LOCAL_BZIMAGE=${CARGO_TARGET}/guest_under_test/bzImage -LOCAL_ROOTFS=${CARGO_TARGET}/guest_under_test/rootfs function prebuilts_exist_error() { echo "Prebuilts of version ${PREBUILT_VERSION} already exist. See README.md" exit 1 } -echo "Checking if prebuilts already exist." -gsutil stat "${REMOTE_BZIMAGE}" && prebuilts_exist_error -gsutil stat "${REMOTE_ROOTFS}" && prebuilts_exist_error +function upload() { + local arch=$1 + local remote_bzimage="${GS_PREFIX}-bzimage-${arch}-${PREBUILT_VERSION}" + local remote_rootfs="${GS_PREFIX}-rootfs-${arch}-${PREBUILT_VERSION}" + + # Local files + local cargo_target=$(cargo metadata --no-deps --format-version 1 | + jq -r ".target_directory") + local local_bzimage=${cargo_target}/guest_under_test/${arch}/bzImage + local local_rootfs=${cargo_target}/guest_under_test/${arch}/rootfs -echo "Building rootfs and kernel." -make "${LOCAL_BZIMAGE}" "${LOCAL_ROOTFS}" + echo "Checking if prebuilts already exist." + gsutil stat "${remote_bzimage}" && prebuilts_exist_error + gsutil stat "${remote_rootfs}" && prebuilts_exist_error + + echo "Building rootfs and kernel." + make ARCH=${arch} "${local_bzimage}" "${local_rootfs}" + + echo "Uploading files." + gsutil cp -n -a public-read "${local_bzimage}" "${remote_bzimage}" + gsutil cp -n -a public-read "${local_rootfs}" "${remote_rootfs}" +} -echo "Uploading files." -gsutil cp -n -a public-read "${LOCAL_BZIMAGE}" "${REMOTE_BZIMAGE}" -gsutil cp -n -a public-read "${LOCAL_ROOTFS}" "${REMOTE_ROOTFS}" +upload x86_64 +upload aarch64 diff --git a/integration_tests/run b/integration_tests/run new file mode 100755 index 000000000..1d2cd15ce --- /dev/null +++ b/integration_tests/run @@ -0,0 +1,11 @@ +#!/bin/bash +# Copyright 2021 The Chromium OS Authors. All rights reserved. 
+# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file + +# We require the crosvm binary to build before running the integration tests. +# There is an RFC for cargo to allow for this kind of dependency: +# https://github.com/rust-lang/cargo/issues/9096 +cd $(dirname $0) +(cd .. && cargo build $@) +cargo test $@ diff --git a/integration_tests/tests/boot.rs b/integration_tests/tests/boot.rs index bebd11935..877610c2c 100644 --- a/integration_tests/tests/boot.rs +++ b/integration_tests/tests/boot.rs @@ -2,11 +2,20 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. mod fixture; -use crosvm::Config; -use fixture::{TestVm, TestVmOptions}; +use fixture::TestVm; #[test] fn boot_test_vm() { - let mut vm = TestVm::new(Config::default(), TestVmOptions::default()).unwrap(); + let mut vm = TestVm::new(&[], false).unwrap(); + assert_eq!(vm.exec_in_guest("echo 42").unwrap().trim(), "42"); +} + +#[test] +fn boot_test_suspend_resume() { + // There is no easy way for us to check if the VM is actually suspended. But at + // least exercise the code-path. + let mut vm = TestVm::new(&[], false).unwrap(); + vm.suspend().unwrap(); + vm.resume().unwrap(); assert_eq!(vm.exec_in_guest("echo 42").unwrap().trim(), "42"); } diff --git a/integration_tests/tests/fixture.rs b/integration_tests/tests/fixture.rs index 57867eb5a..42a1604f6 100644 --- a/integration_tests/tests/fixture.rs +++ b/integration_tests/tests/fixture.rs @@ -2,9 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. 
-use std::env; use std::ffi::CString; -use std::fs::File; use std::io::{self, BufRead, BufReader, Write}; use std::path::{Path, PathBuf}; use std::process::Command; @@ -12,11 +10,11 @@ use std::sync::mpsc::sync_channel; use std::sync::Once; use std::thread; use std::time::Duration; +use std::{env, process::Child}; +use std::{fs::File, process::Stdio}; use anyhow::{anyhow, Result}; -use arch::{set_default_serial_parameters, SerialHardware, SerialParameters, SerialType}; use base::syslog; -use crosvm::{platform, Config, DiskOption, Executable}; use tempfile::TempDir; const PREBUILT_URL: &str = "https://storage.googleapis.com/chromeos-localmirror/distfiles"; @@ -30,7 +28,7 @@ const ARCH: &str = "aarch64"; /// Timeout for communicating with the VM. If we do not hear back, panic so we /// do not block the tests. -const VM_COMMUNICATION_TIMEOUT: Duration = Duration::from_millis(1000); +const VM_COMMUNICATION_TIMEOUT: Duration = Duration::from_secs(10); fn prebuilt_version() -> &'static str { include_str!("../guest_under_test/PREBUILT_VERSION").trim() @@ -76,6 +74,22 @@ fn rootfs_path() -> PathBuf { } } +/// The crosvm binary is expected to be alongside to the integration tests +/// binary. Alternatively in the parent directory (cargo will put the +/// test binary in target/debug/deps/ but the crosvm binary in target/debug). 
+fn find_crosvm_binary() -> PathBuf { + let exe_dir = env::current_exe().unwrap().parent().unwrap().to_path_buf(); + let first = exe_dir.join("crosvm"); + if first.exists() { + return first; + } + let second = exe_dir.parent().unwrap().join("crosvm"); + if second.exists() { + return second; + } + panic!("Cannot find ./crosvm or ../crosvm alongside test binary."); +} + /// Safe wrapper for libc::mkfifo fn mkfifo(path: &Path) -> io::Result<()> { let cpath = CString::new(path.to_str().unwrap()).unwrap(); @@ -124,9 +138,18 @@ fn download_file(url: &str, destination: &Path) -> Result<()> { } } -#[derive(Default)] -pub struct TestVmOptions { - pub debug: bool, +fn crosvm_command(command: &str, args: &[&str]) -> Result<()> { + println!("$ crosvm {} {:?}", command, &args.join(" ")); + let status = Command::new(find_crosvm_binary()) + .arg(command) + .args(args) + .status()?; + + if !status.success() { + Err(anyhow!("Command failed with exit code {}", status)) + } else { + Ok(()) + } } /// Test fixture to spin up a VM running a guest that can be communicated with. @@ -139,8 +162,9 @@ pub struct TestVm { test_dir: TempDir, from_guest_reader: BufReader<File>, to_guest: File, - vm_thread: Option<thread::JoinHandle<()>>, - options: TestVmOptions, + control_socket_path: PathBuf, + process: Child, + debug: bool, } impl TestVm { @@ -188,85 +212,35 @@ impl TestVm { // delegate binary. // - ttyS1: Serial device attached to the named pipes. 
fn configure_serial_devices( - config: &mut Config, + command: &mut Command, from_guest_pipe: &Path, to_guest_pipe: &Path, - debug: bool, - ) -> Result<()> { - for ((_, index), _) in &config.serial_parameters { - if *index == 1 || *index == 2 { - return Err(anyhow!("Do not specify serial device 1 or 2.")); - } - } - - config.serial_parameters.insert( - (SerialHardware::Serial, 1), - SerialParameters { - type_: if debug { - SerialType::Stdout - } else { - SerialType::Sink - }, - hardware: SerialHardware::Serial, - path: None, - input: None, - num: 1, - console: true, - earlycon: false, - stdin: false, - }, + ) { + command.args(&["--serial", "type=syslog"]); + + // Setup channel for communication with the delegate. + let serial_params = format!( + "type=file,path={},input={},num=2", + from_guest_pipe.display(), + to_guest_pipe.display() ); - config.serial_parameters.insert( - (SerialHardware::Serial, 2), - SerialParameters { - type_: SerialType::File, - hardware: SerialHardware::Serial, - path: Some(PathBuf::from(from_guest_pipe)), - input: Some(PathBuf::from(to_guest_pipe.clone())), - num: 2, - console: false, - earlycon: false, - stdin: false, - }, - ); - set_default_serial_parameters(&mut config.serial_parameters); - return Ok(()); + command.args(&["--serial", &serial_params]); } /// Configures the VM kernel and rootfs to load from the guest_under_test assets. 
- fn configure_kernel(config: &mut Config) -> Result<()> { - for param in &config.params { - if param.starts_with("root") || param.starts_with("init") { - return Err(anyhow!("Do not set the root or init parameters.")); - } - } - config.executable_path = Some(Executable::Kernel(kernel_path())); - config.params.push("root=/dev/vda ro".to_string()); - config.params.push("init=/bin/delegate".to_string()); - config.disks.insert( - 0, - DiskOption { - id: None, - path: rootfs_path(), - read_only: true, - sparse: true, - block_size: 512, - }, - ); - - return Ok(()); + fn configure_kernel(command: &mut Command) { + command + .args(&["--root", rootfs_path().to_str().unwrap()]) + .args(&["--params", "init=/bin/delegate"]) + .arg(kernel_path()); } /// Instanciate a new crosvm instance. The first call will trigger the download of prebuilt /// files if necessary. - pub fn new(mut config: Config, options: TestVmOptions) -> Result<TestVm> { + pub fn new(additional_arguments: &[&str], debug: bool) -> Result<TestVm> { static PREP_ONCE: Once = Once::new(); PREP_ONCE.call_once(|| TestVm::initialize_once()); - // TODO(b/173233134): Running sandboxed tests is going to require a lot of configuration - // on the host. - config.sandbox = false; - // Create two named pipes to communicate with the guest. 
let test_dir = TempDir::new()?; let from_guest_pipe = test_dir.path().join("from_guest"); @@ -274,18 +248,22 @@ impl TestVm { mkfifo(&from_guest_pipe)?; mkfifo(&to_guest_pipe)?; - TestVm::configure_serial_devices( - &mut config, - &from_guest_pipe, - &to_guest_pipe, - options.debug, - )?; - TestVm::configure_kernel(&mut config)?; + let control_socket_path = test_dir.path().join("control"); + + let mut command = Command::new(find_crosvm_binary()); + command.args(&["run", "--disable-sandbox"]); + TestVm::configure_serial_devices(&mut command, &from_guest_pipe, &to_guest_pipe); + command.args(&["--socket", &control_socket_path.to_str().unwrap()]); + command.args(additional_arguments); - // Run VM in a separate thread. - let vm_thread = thread::spawn(move || { - platform::run_config(config).expect("Cannot run VM."); - }); + TestVm::configure_kernel(&mut command); + + println!("$ {:?}", command); + if !debug { + command.stdout(Stdio::null()); + command.stderr(Stdio::null()); + } + let process = command.spawn()?; // Open pipes. Panic if we cannot connect after a timeout. let (to_guest, from_guest) = panic_on_timeout( @@ -303,8 +281,9 @@ impl TestVm { test_dir, from_guest_reader, to_guest: to_guest?, - vm_thread: Some(vm_thread), - options, + control_socket_path, + process, + debug, }) } @@ -329,25 +308,28 @@ impl TestVm { output.push_str(&line); } let trimmed = output.trim(); - if self.options.debug { + if self.debug { println!("<- {:?}", trimmed); } Ok(trimmed.to_string()) } + + pub fn stop(&self) -> Result<()> { + crosvm_command("stop", &[self.control_socket_path.to_str().unwrap()]) + } + + pub fn suspend(&self) -> Result<()> { + crosvm_command("suspend", &[self.control_socket_path.to_str().unwrap()]) + } + + pub fn resume(&self) -> Result<()> { + crosvm_command("resume", &[self.control_socket_path.to_str().unwrap()]) + } } impl Drop for TestVm { fn drop(&mut self) { - if let Some(handle) = self.vm_thread.take() { - // Run exit command to shut down the VM. 
- writeln!(&mut self.to_guest, "exit").expect("Cannot send exit command."); - // Wait for the VM to exit, but don't wait forever. - panic_on_timeout( - move || { - handle.join().expect("Cannot join VM thread."); - }, - VM_COMMUNICATION_TIMEOUT, - ); - } + self.stop().unwrap(); + self.process.wait().unwrap(); } } diff --git a/io_uring/Cargo.toml b/io_uring/Cargo.toml index cbe5aa1e6..3b7cc7c36 100644 --- a/io_uring/Cargo.toml +++ b/io_uring/Cargo.toml @@ -5,13 +5,12 @@ authors = ["The Chromium OS Authors"] edition = "2018" [dependencies] -data_model = { path = "../data_model" } +data_model = { path = "../data_model" } # provided by ebuild libc = "*" -syscall_defines = { path = "../syscall_defines" } -sync = { path = "../sync" } -sys_util = { path = "../sys_util" } +sync = { path = "../sync" } # provided by ebuild +sys_util = { path = "../sys_util" } # provided by ebuild [dev-dependencies] -tempfile = { path = "../tempfile" } +tempfile = { path = "../tempfile" } # provided by ebuild [workspace] diff --git a/io_uring/src/syscalls.rs b/io_uring/src/syscalls.rs index e3c9f8630..f3a3b41a0 100644 --- a/io_uring/src/syscalls.rs +++ b/io_uring/src/syscalls.rs @@ -6,8 +6,7 @@ use std::io::Error; use std::os::unix::io::RawFd; use std::ptr::null_mut; -use libc::{c_int, c_long, c_void}; -use syscall_defines::linux::LinuxSyscall::*; +use libc::{c_int, c_long, c_void, syscall, SYS_io_uring_enter, SYS_io_uring_setup}; use crate::bindings::*; @@ -15,7 +14,7 @@ use crate::bindings::*; pub type Result<T> = std::result::Result<T, c_int>; pub unsafe fn io_uring_setup(num_entries: usize, params: &io_uring_params) -> Result<RawFd> { - let ret = libc::syscall( + let ret = syscall( SYS_io_uring_setup as c_long, num_entries as c_int, params as *const _, @@ -27,7 +26,7 @@ pub unsafe fn io_uring_setup(num_entries: usize, params: &io_uring_params) -> Re } pub unsafe fn io_uring_enter(fd: RawFd, to_submit: u64, to_wait: u64, flags: u32) -> Result<()> { - let ret = libc::syscall( + let ret = 
syscall( SYS_io_uring_enter as c_long, fd, to_submit as c_int, diff --git a/io_uring/src/uring.rs b/io_uring/src/uring.rs index 6586a71b4..d67447379 100644 --- a/io_uring/src/uring.rs +++ b/io_uring/src/uring.rs @@ -1338,15 +1338,7 @@ mod tests { } mem::drop(c); - // Now add NOPs to wake up any threads blocked on the syscall. - for i in 0..NUM_THREADS { - uring.add_nop((num_entries * 3 + i) as UserData).unwrap(); - } - uring.submit().unwrap(); - - for t in threads { - t.join().unwrap(); - } + // Let the OS clean up the still-waiting threads after the test run. } #[test] @@ -1431,7 +1423,9 @@ mod tests { ); } + // TODO(b/183722981): Fix and re-enable test #[test] + #[ignore] fn multi_thread_submit_and_complete() { const NUM_SUBMITTERS: usize = 7; const NUM_COMPLETERS: usize = 3; diff --git a/kvm/Cargo.toml b/kvm/Cargo.toml index d2e94cad7..82cb31b6e 100644 --- a/kvm/Cargo.toml +++ b/kvm/Cargo.toml @@ -8,7 +8,6 @@ edition = "2018" data_model = { path = "../data_model" } kvm_sys = { path = "../kvm_sys" } libc = "*" -msg_socket = { path = "../msg_socket" } base = { path = "../base" } sync = { path = "../sync" } vm_memory = { path = "../vm_memory" } diff --git a/kvm/src/lib.rs b/kvm/src/lib.rs index 357020af2..8945fb275 100644 --- a/kvm/src/lib.rs +++ b/kvm/src/lib.rs @@ -9,10 +9,13 @@ mod cap; use std::cell::RefCell; use std::cmp::{min, Ordering}; use std::collections::{BTreeMap, BinaryHeap}; +use std::ffi::CString; use std::fs::File; use std::mem::size_of; use std::ops::{Deref, DerefMut}; use std::os::raw::*; +use std::os::unix::prelude::OsStrExt; +use std::path::{Path, PathBuf}; use std::ptr::copy_nonoverlapping; use std::sync::Arc; use sync::Mutex; @@ -34,7 +37,6 @@ use base::{ ioctl_with_val, pagesize, signal, unblock_signal, warn, Error, Event, IoctlNr, MappedRegion, MemoryMapping, MemoryMappingBuilder, MmapError, Result, SIGRTMIN, }; -use msg_socket::MsgOnSocket; use vm_memory::{GuestAddress, GuestMemory}; pub use crate::cap::*; @@ -94,9 +96,14 @@ pub struct Kvm 
{ impl Kvm { /// Opens `/dev/kvm/` and returns a Kvm object on success. pub fn new() -> Result<Kvm> { - // Open calls are safe because we give a constant nul-terminated string and verify the - // result. - let ret = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, O_RDWR | O_CLOEXEC) }; + Kvm::new_with_path(&PathBuf::from("/dev/kvm")) + } + + /// Opens a KVM device at `device_path` and returns a Kvm object on success. + pub fn new_with_path(device_path: &Path) -> Result<Kvm> { + // Open calls are safe because we give a nul-terminated string and verify the result. + let c_path = CString::new(device_path.as_os_str().as_bytes()).unwrap(); + let ret = unsafe { open(c_path.as_ptr(), O_RDWR | O_CLOEXEC) }; if ret < 0 { return errno_result(); } @@ -200,7 +207,7 @@ impl AsRawDescriptor for Kvm { } /// An address either in programmable I/O space or in memory mapped I/O space. -#[derive(Copy, Clone, Debug, MsgOnSocket)] +#[derive(Copy, Clone, Debug)] pub enum IoeventAddress { Pio(u64), Mmio(u64), @@ -271,7 +278,7 @@ impl Vm { if ret >= 0 { // Safe because we verify the value of ret and we are the owners of the fd. let vm_file = unsafe { File::from_raw_descriptor(ret) }; - guest_mem.with_regions(|index, guest_addr, size, host_addr, _| { + guest_mem.with_regions(|index, guest_addr, size, host_addr, _, _| { unsafe { // Safe because the guest regions are guaranteed not to overlap. 
set_user_memory_region( @@ -962,7 +969,7 @@ impl Vcpu { let vcpu = unsafe { File::from_raw_descriptor(vcpu_fd) }; let run_mmap = MemoryMappingBuilder::new(run_mmap_size) - .from_descriptor(&vcpu) + .from_file(&vcpu) .build() .map_err(|_| Error::new(ENOSPC))?; diff --git a/kvm/tests/dirty_log.rs b/kvm/tests/dirty_log.rs index 9fb5e5919..fb848df31 100644 --- a/kvm/tests/dirty_log.rs +++ b/kvm/tests/dirty_log.rs @@ -21,7 +21,7 @@ fn test_run() { let guest_mem = GuestMemory::new(&[]).unwrap(); let mem = SharedMemory::anon(mem_size).expect("failed to create shared memory"); let mmap = MemoryMappingBuilder::new(mem_size as usize) - .from_descriptor(&mem) + .from_shared_memory(&mem) .build() .expect("failed to create memory mapping"); @@ -48,7 +48,7 @@ fn test_run() { GuestAddress(0), Box::new( MemoryMappingBuilder::new(mem_size as usize) - .from_descriptor(&mem) + .from_shared_memory(&mem) .build() .expect("failed to create memory mapping"), ), diff --git a/kvm/tests/read_only_memory.rs b/kvm/tests/read_only_memory.rs index a2e5105df..2a2111063 100644 --- a/kvm/tests/read_only_memory.rs +++ b/kvm/tests/read_only_memory.rs @@ -23,7 +23,7 @@ fn test_run() { let guest_mem = GuestMemory::new(&[]).unwrap(); let mem = SharedMemory::anon(mem_size).expect("failed to create shared memory"); let mmap = MemoryMappingBuilder::new(mem_size as usize) - .from_descriptor(&mem) + .from_shared_memory(&mem) .build() .expect("failed to create memory mapping"); @@ -50,7 +50,7 @@ fn test_run() { GuestAddress(0), Box::new( MemoryMappingBuilder::new(mem_size as usize) - .from_descriptor(&mem) + .from_shared_memory(&mem) .build() .expect("failed to create memory mapping"), ), @@ -63,7 +63,7 @@ fn test_run() { // from it. 
let mem_ro = SharedMemory::anon(0x1000).expect("failed to create shared memory"); let mmap_ro = MemoryMappingBuilder::new(0x1000) - .from_descriptor(&mem_ro) + .from_shared_memory(&mem_ro) .build() .expect("failed to create memory mapping"); mmap_ro @@ -73,7 +73,7 @@ fn test_run() { GuestAddress(vcpu_sregs.es.base), Box::new( MemoryMappingBuilder::new(0x1000) - .from_descriptor(&mem_ro) + .from_shared_memory(&mem_ro) .build() .expect("failed to create memory mapping"), ), diff --git a/libcrosvm_control/Cargo.toml b/libcrosvm_control/Cargo.toml new file mode 100644 index 000000000..2ed68b8a9 --- /dev/null +++ b/libcrosvm_control/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "libcrosvm_control" +version = "0.1.0" +authors = ["The Chromium OS Authors"] +edition = "2018" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +base = { path = "../base" } +vm_control = { path = "../vm_control" } +libc = "0.2.65" diff --git a/libcrosvm_control/src/lib.rs b/libcrosvm_control/src/lib.rs new file mode 100644 index 000000000..973d64e0a --- /dev/null +++ b/libcrosvm_control/src/lib.rs @@ -0,0 +1,358 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +// +// Provides parts of crosvm as a library to communicate with running crosvm instances. +// Usually you would need to invoke crosvm with subcommands and you'd get the result on +// stdout. 
+use std::convert::{TryFrom, TryInto}; +use std::ffi::CStr; +use std::panic::catch_unwind; +use std::path::{Path, PathBuf}; + +use libc::{c_char, ssize_t}; + +use vm_control::{ + client::*, BalloonControlCommand, BalloonStats, DiskControlCommand, UsbControlAttachedDevice, + UsbControlResult, VmRequest, VmResponse, +}; + +fn validate_socket_path(socket_path: *const c_char) -> Option<PathBuf> { + if !socket_path.is_null() { + let socket_path = unsafe { CStr::from_ptr(socket_path) }; + Some(PathBuf::from(socket_path.to_str().ok()?)) + } else { + None + } +} + +/// Stops the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_stop_vm(socket_path: *const c_char) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + vms_request(&VmRequest::Exit, &socket_path).is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Suspends the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_suspend_vm(socket_path: *const c_char) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + vms_request(&VmRequest::Suspend, &socket_path).is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Resumes the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns true on success or false if an error occured. 
+#[no_mangle] +pub extern "C" fn crosvm_client_resume_vm(socket_path: *const c_char) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + vms_request(&VmRequest::Resume, &socket_path).is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Adjusts the balloon size of the crosvm instance whose control socket is +/// listening on `socket_path`. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_balloon_vms(socket_path: *const c_char, num_bytes: u64) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + let command = BalloonControlCommand::Adjust { num_bytes }; + vms_request(&VmRequest::BalloonCommand(command), &socket_path).is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Represents an individual attached USB device. +#[repr(C)] +pub struct UsbDeviceEntry { + /// Internal port index used for identifying this individual device. + port: u8, + /// USB vendor ID + vendor_id: u16, + /// USB product ID + product_id: u16, +} + +impl From<&UsbControlAttachedDevice> for UsbDeviceEntry { + fn from(other: &UsbControlAttachedDevice) -> Self { + Self { + port: other.port, + vendor_id: other.vendor_id, + product_id: other.product_id, + } + } +} + +/// Returns all USB devices passed through the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns the amount of entries written. +/// # Arguments +/// +/// * `socket_path` - Path to the crosvm control socket +/// * `entries` - Pointer to an array of `UsbDeviceEntry` where the details about the attached +/// devices will be written to +/// * `entries_length` - Amount of entries in the array specified by `entries` +/// +/// Crosvm supports passing through up to 255 devices, so pasing an array with 255 entries will +/// guarantee to return all entries. 
+#[no_mangle] +pub extern "C" fn crosvm_client_usb_list( + socket_path: *const c_char, + entries: *mut UsbDeviceEntry, + entries_length: ssize_t, +) -> ssize_t { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + if let Ok(UsbControlResult::Devices(res)) = do_usb_list(&socket_path) { + let mut i = 0; + for entry in res.iter().filter(|x| x.valid()) { + if i >= entries_length { + break; + } + unsafe { + *entries.offset(i) = entry.into(); + i += 1; + } + } + i + } else { + -1 + } + } else { + -1 + } + }) + .unwrap_or(-1) +} + +/// Attaches an USB device to crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns the amount of entries written. +/// # Arguments +/// +/// * `socket_path` - Path to the crosvm control socket +/// * `bus` - USB device bus ID +/// * `addr` - USB device address +/// * `vid` - USB device vendor ID +/// * `pid` - USB device product ID +/// * `dev_path` - Path to the USB device (Most likely `/dev/bus/usb/<bus>/<addr>`). +/// * `out_port` - (optional) internal port will be written here if provided. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_usb_attach( + socket_path: *const c_char, + bus: u8, + addr: u8, + vid: u16, + pid: u16, + dev_path: *const c_char, + out_port: *mut u8, +) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + if dev_path.is_null() { + return false; + } + let dev_path = Path::new(unsafe { CStr::from_ptr(dev_path) }.to_str().unwrap_or("")); + + if let Ok(UsbControlResult::Ok { port }) = + do_usb_attach(&socket_path, bus, addr, vid, pid, dev_path) + { + if !out_port.is_null() { + unsafe { *out_port = port }; + } + true + } else { + false + } + } else { + false + } + }) + .unwrap_or(false) +} + +/// Detaches an USB device from crosvm instance whose control socket is listening on `socket_path`. 
+/// `port` determines device to be detached. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_usb_detach(socket_path: *const c_char, port: u8) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + do_usb_detach(&socket_path, port).is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Modifies the battery status of crosvm instance whose control socket is listening on +/// `socket_path`. +/// +/// The function returns true on success or false if an error occured. +#[no_mangle] +pub extern "C" fn crosvm_client_modify_battery( + socket_path: *const c_char, + battery_type: *const c_char, + property: *const c_char, + target: *const c_char, +) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + if battery_type.is_null() || property.is_null() || target.is_null() { + return false; + } + let battery_type = unsafe { CStr::from_ptr(battery_type) }; + let property = unsafe { CStr::from_ptr(property) }; + let target = unsafe { CStr::from_ptr(target) }; + + do_modify_battery( + &socket_path, + &battery_type.to_str().unwrap(), + &property.to_str().unwrap(), + &target.to_str().unwrap(), + ) + .is_ok() + } else { + false + } + }) + .unwrap_or(false) +} + +/// Resizes the disk of the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The function returns true on success or false if an error occured. 
+#[no_mangle] +pub extern "C" fn crosvm_client_resize_disk( + socket_path: *const c_char, + disk_index: u64, + new_size: u64, +) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + if let Ok(disk_index) = usize::try_from(disk_index) { + let request = VmRequest::DiskCommand { + disk_index, + command: DiskControlCommand::Resize { new_size }, + }; + vms_request(&request, &socket_path).is_ok() + } else { + false + } + } else { + false + } + }) + .unwrap_or(false) +} + +/// Similar to internally used `BalloonStats` but using i64 instead of +/// Option<u64>. `None` (or values bigger than i64::max) will be encoded as -1. +#[repr(C)] +pub struct BalloonStatsFfi { + swap_in: i64, + swap_out: i64, + major_faults: i64, + minor_faults: i64, + free_memory: i64, + total_memory: i64, + available_memory: i64, + disk_caches: i64, + hugetlb_allocations: i64, + hugetlb_failures: i64, +} + +impl From<&BalloonStats> for BalloonStatsFfi { + fn from(other: &BalloonStats) -> Self { + let convert = + |x: Option<u64>| -> i64 { x.map(|y| y.try_into().ok()).flatten().unwrap_or(-1) }; + Self { + swap_in: convert(other.swap_in), + swap_out: convert(other.swap_out), + major_faults: convert(other.major_faults), + minor_faults: convert(other.minor_faults), + free_memory: convert(other.free_memory), + total_memory: convert(other.total_memory), + available_memory: convert(other.available_memory), + disk_caches: convert(other.disk_caches), + hugetlb_allocations: convert(other.hugetlb_allocations), + hugetlb_failures: convert(other.hugetlb_failures), + } + } +} + +/// Returns balloon stats of the crosvm instance whose control socket is listening on `socket_path`. +/// +/// The parameters `stats` and `actual` are optional and will only be written to if they are +/// non-null. +/// +/// The function returns true on success or false if an error occured. +/// +/// # Note +/// +/// Entries in `BalloonStatsFfi` that are not available will be set to `-1`. 
+#[no_mangle] +pub extern "C" fn crosvm_client_balloon_stats( + socket_path: *const c_char, + stats: *mut BalloonStatsFfi, + actual: *mut u64, +) -> bool { + catch_unwind(|| { + if let Some(socket_path) = validate_socket_path(socket_path) { + let request = &VmRequest::BalloonCommand(BalloonControlCommand::Stats {}); + if let Ok(VmResponse::BalloonStats { + stats: ref balloon_stats, + balloon_actual, + }) = handle_request(request, &socket_path) + { + if !stats.is_null() { + unsafe { + *stats = balloon_stats.into(); + } + } + + if !actual.is_null() { + unsafe { + *actual = balloon_actual; + } + } + true + } else { + false + } + } else { + false + } + }) + .unwrap_or(false) +} diff --git a/linux_input_sys/src/lib.rs b/linux_input_sys/src/lib.rs index 1c9e801da..b48e9370e 100644 --- a/linux_input_sys/src/lib.rs +++ b/linux_input_sys/src/lib.rs @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use data_model::{DataInit, Le16, Le32}; +use data_model::{DataInit, Le16, SLe32}; use std::mem::size_of; const EV_SYN: u16 = 0x00; @@ -37,7 +37,7 @@ pub struct input_event { pub timestamp_fields: [u64; 2], pub type_: u16, pub code: u16, - pub value: u32, + pub value: i32, } // Safe because it only has data and has no implicit padding. unsafe impl DataInit for input_event {} @@ -64,7 +64,7 @@ impl InputEventDecoder for input_event { virtio_input_event { type_: Le16::from(e.type_), code: Le16::from(e.code), - value: Le32::from(e.value), + value: SLe32::from(e.value), } } } @@ -74,7 +74,7 @@ impl InputEventDecoder for input_event { pub struct virtio_input_event { pub type_: Le16, pub code: Le16, - pub value: Le32, + pub value: SLe32, } // Safe because it only has data and has no implicit padding. 
@@ -97,46 +97,46 @@ impl virtio_input_event { virtio_input_event { type_: Le16::from(EV_SYN), code: Le16::from(SYN_REPORT), - value: Le32::from(0), + value: SLe32::from(0), } } #[inline] - pub fn absolute(code: u16, value: u32) -> virtio_input_event { + pub fn absolute(code: u16, value: i32) -> virtio_input_event { virtio_input_event { type_: Le16::from(EV_ABS), code: Le16::from(code), - value: Le32::from(value), + value: SLe32::from(value), } } #[inline] - pub fn multitouch_tracking_id(id: u32) -> virtio_input_event { + pub fn multitouch_tracking_id(id: i32) -> virtio_input_event { Self::absolute(ABS_MT_TRACKING_ID, id) } #[inline] - pub fn multitouch_slot(slot: u32) -> virtio_input_event { + pub fn multitouch_slot(slot: i32) -> virtio_input_event { Self::absolute(ABS_MT_SLOT, slot) } #[inline] - pub fn multitouch_absolute_x(x: u32) -> virtio_input_event { + pub fn multitouch_absolute_x(x: i32) -> virtio_input_event { Self::absolute(ABS_MT_POSITION_X, x) } #[inline] - pub fn multitouch_absolute_y(y: u32) -> virtio_input_event { + pub fn multitouch_absolute_y(y: i32) -> virtio_input_event { Self::absolute(ABS_MT_POSITION_Y, y) } #[inline] - pub fn absolute_x(x: u32) -> virtio_input_event { + pub fn absolute_x(x: i32) -> virtio_input_event { Self::absolute(ABS_X, x) } #[inline] - pub fn absolute_y(y: u32) -> virtio_input_event { + pub fn absolute_y(y: i32) -> virtio_input_event { Self::absolute(ABS_Y, y) } @@ -155,7 +155,7 @@ impl virtio_input_event { virtio_input_event { type_: Le16::from(EV_KEY), code: Le16::from(code), - value: Le32::from(if pressed { 1 } else { 0 }), + value: SLe32::from(if pressed { 1 } else { 0 }), } } } diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml deleted file mode 100644 index 418b31132..000000000 --- a/msg_socket/Cargo.toml +++ /dev/null @@ -1,14 +0,0 @@ -[package] -name = "msg_socket" -version = "0.1.0" -authors = ["The Chromium OS Authors"] -edition = "2018" - -[dependencies] -cros_async = { path = "../cros_async" } 
-data_model = { path = "../data_model" } -futures = "*" -libc = "*" -msg_on_socket_derive = { path = "msg_on_socket_derive" } -base = { path = "../base" } -sync = { path = "../sync" } diff --git a/msg_socket/msg_on_socket_derive/Cargo.toml b/msg_socket/msg_on_socket_derive/Cargo.toml deleted file mode 100644 index 206c03a78..000000000 --- a/msg_socket/msg_on_socket_derive/Cargo.toml +++ /dev/null @@ -1,15 +0,0 @@ -[package] -name = "msg_on_socket_derive" -version = "0.1.0" -authors = ["The Chromium OS Authors"] -edition = "2018" - -[dependencies] -base = "*" -proc-macro2 = "^1" -quote = "^1" -syn = "^1" - -[lib] -proc-macro = true -path = "msg_on_socket_derive.rs" diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs deleted file mode 100644 index 98b354900..000000000 --- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs +++ /dev/null @@ -1,911 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#![recursion_limit = "256"] -extern crate proc_macro; - -use std::vec::Vec; - -use proc_macro2::{Span, TokenStream}; -use quote::{format_ident, quote}; -use syn::{ - parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta, - NestedMeta, Type, -}; - -/// The function that derives the recursive implementation for struct or enum. -#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))] -pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as DeriveInput); - let impl_for_input = socket_msg_impl(input); - impl_for_input.into() -} - -fn socket_msg_impl(input: DeriveInput) -> TokenStream { - if !input.generics.params.is_empty() { - return quote! 
{ - compile_error!("derive(SocketMsg) does not support generic parameters"); - }; - } - match input.data { - Data::Struct(ds) => { - if is_named_struct(&ds) { - impl_for_named_struct(input.ident, ds) - } else { - impl_for_tuple_struct(input.ident, ds) - } - } - Data::Enum(de) => impl_for_enum(input.ident, de), - _ => quote! { - compile_error!("derive(SocketMsg) only support struct and enum"); - }, - } -} - -fn is_named_struct(ds: &DataStruct) -> bool { - matches!(&ds.fields, Fields::Named(_f)) -} - -/************************** Named Struct Impls ********************************************/ - -struct StructField { - member: Member, - ty: Type, - skipped: bool, -} - -fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream { - let fields = get_struct_fields(ds); - let uses_fd_impl = define_uses_fd_for_struct(&fields); - let buffer_sizes_impls = define_buffer_size_for_struct(&fields); - - let read_buffer = define_read_buffer_for_struct(&name, &fields); - let write_buffer = define_write_buffer_for_struct(&name, &fields); - quote! { - impl msg_socket::MsgOnSocket for #name { - #uses_fd_impl - #buffer_sizes_impls - #read_buffer - #write_buffer - } - } -} - -// Flatten struct fields. -fn get_struct_fields(ds: DataStruct) -> Vec<StructField> { - let fields = match ds.fields { - Fields::Named(fields_named) => fields_named.named, - _ => { - panic!("Struct must have named fields"); - } - }; - let mut vec = Vec::new(); - for field in fields { - let member = match field.ident { - Some(ident) => Member::Named(ident), - None => panic!("Unknown Error."), - }; - let ty = field.ty; - let mut skipped = false; - for attr in field - .attrs - .iter() - .filter(|attr| attr.path.is_ident("msg_on_socket")) - { - match attr.parse_meta().unwrap() { - Meta::List(meta) => { - for nested in meta.nested { - match nested { - NestedMeta::Meta(Meta::Path(ref meta_path)) - if meta_path.is_ident("skip") => - { - skipped = true; - } - _ => panic!("unrecognized attribute meta `{}`", quote! 
{ #nested }), - } - } - } - _ => panic!("unrecognized attribute `{}`", quote! { #attr }), - } - } - vec.push(StructField { - member, - ty, - skipped, - }); - } - vec -} - -fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream { - let field_types: Vec<_> = fields - .iter() - .filter(|f| !f.skipped) - .map(|f| &f.ty) - .collect(); - - if field_types.is_empty() { - return quote!(); - } - - quote! { - fn uses_descriptor() -> bool { - #(<#field_types>::uses_descriptor())||* - } - } -} - -fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream { - let (msg_size, fd_count) = get_fields_buffer_size_sum(fields); - quote! { - fn msg_size(&self) -> usize { - #msg_size - } - fn descriptor_count(&self) -> usize { - #fd_count - } - } -} - -fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { - let mut read_fields = Vec::new(); - let mut init_fields = Vec::new(); - for field in fields { - let ident = match &field.member { - Member::Named(ident) => ident, - Member::Unnamed(_) => unreachable!(), - }; - let name = ident.clone(); - if field.skipped { - let ty = &field.ty; - init_fields.push(quote! { - #name: <#ty>::default() - }); - continue; - } - let read_field = read_from_buffer_and_move_offset(&ident, &field.ty); - read_fields.push(read_field); - init_fields.push(quote!(#name)); - } - quote! 
{ - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - #(#read_fields)* - Ok(( - Self { - #(#init_fields),* - }, - __fd_offset - )) - } - } -} - -fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream { - let mut write_fields = Vec::new(); - for field in fields { - if field.skipped { - continue; - } - let ident = match &field.member { - Member::Named(ident) => ident, - Member::Unnamed(_) => unreachable!(), - }; - let write_field = write_to_buffer_and_move_offset(&ident); - write_fields.push(write_field); - } - quote! { - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - #(#write_fields)* - Ok(__fd_offset) - } - } -} - -/************************** Enum Impls ********************************************/ -fn impl_for_enum(name: Ident, de: DataEnum) -> TokenStream { - let uses_fd_impl = define_uses_fd_for_enum(&de); - let buffer_sizes_impls = define_buffer_size_for_enum(&name, &de); - let read_buffer = define_read_buffer_for_enum(&name, &de); - let write_buffer = define_write_buffer_for_enum(&name, &de); - quote! { - impl msg_socket::MsgOnSocket for #name { - #uses_fd_impl - #buffer_sizes_impls - #read_buffer - #write_buffer - } - } -} - -fn define_uses_fd_for_enum(de: &DataEnum) -> TokenStream { - let mut variant_field_types = Vec::new(); - for variant in &de.variants { - for variant_field_ty in variant.fields.iter().map(|f| &f.ty) { - variant_field_types.push(variant_field_ty); - } - } - - if variant_field_types.is_empty() { - return quote!(); - } - - quote! 
{ - fn uses_descriptor() -> bool { - #(<#variant_field_types>::uses_descriptor())||* - } - } -} - -fn define_buffer_size_for_enum(name: &Ident, de: &DataEnum) -> TokenStream { - let mut msg_size_match_variants = Vec::new(); - let mut fd_count_match_variants = Vec::new(); - - for variant in &de.variants { - let variant_name = &variant.ident; - match &variant.fields { - Fields::Named(fields) => { - let mut tmp_names = Vec::new(); - for field in &fields.named { - tmp_names.push(field.ident.clone().unwrap()); - } - - let v = quote! { - #name::#variant_name { #(#tmp_names),* } => #(#tmp_names.msg_size())+*, - }; - msg_size_match_variants.push(v); - - let v = quote! { - #name::#variant_name { #(#tmp_names),* } => #(#tmp_names.descriptor_count())+*, - }; - fd_count_match_variants.push(v); - } - Fields::Unnamed(fields) => { - let mut tmp_names = Vec::new(); - for idx in 0..fields.unnamed.len() { - let tmp_name = format!("enum_field{}", idx); - let tmp_name = Ident::new(&tmp_name, Span::call_site()); - tmp_names.push(tmp_name.clone()); - } - - let v = quote! { - #name::#variant_name(#(#tmp_names),*) => #(#tmp_names.msg_size())+*, - }; - msg_size_match_variants.push(v); - - let v = quote! { - #name::#variant_name(#(#tmp_names),*) => #(#tmp_names.descriptor_count())+*, - }; - fd_count_match_variants.push(v); - } - Fields::Unit => { - let v = quote! { - #name::#variant_name => 0, - }; - msg_size_match_variants.push(v.clone()); - fd_count_match_variants.push(v); - } - } - } - - quote! 
{ - fn msg_size(&self) -> usize { - 1 + match self { - #(#msg_size_match_variants)* - } - } - fn descriptor_count(&self) -> usize { - match self { - #(#fd_count_match_variants)* - } - } - } -} - -fn define_read_buffer_for_enum(name: &Ident, de: &DataEnum) -> TokenStream { - let mut match_variants = Vec::new(); - let de = de.clone(); - for (idx, variant) in de.variants.iter().enumerate() { - let idx = idx as u8; - let variant_name = &variant.ident; - match &variant.fields { - Fields::Named(fields) => { - let mut tmp_names = Vec::new(); - let mut read_tmps = Vec::new(); - for f in &fields.named { - tmp_names.push(f.ident.clone()); - let read_tmp = - read_from_buffer_and_move_offset(f.ident.as_ref().unwrap(), &f.ty); - read_tmps.push(read_tmp); - } - let v = quote! { - #idx => { - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - #(#read_tmps)* - Ok((#name::#variant_name { #(#tmp_names),* }, __fd_offset)) - } - }; - match_variants.push(v); - } - Fields::Unnamed(fields) => { - let mut tmp_names = Vec::new(); - let mut read_tmps = Vec::new(); - for (idx, field) in fields.unnamed.iter().enumerate() { - let tmp_name = format_ident!("enum_field{}", idx); - tmp_names.push(tmp_name.clone()); - let read_tmp = read_from_buffer_and_move_offset(&tmp_name, &field.ty); - read_tmps.push(read_tmp); - } - - let v = quote! { - #idx => { - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - #(#read_tmps)* - Ok((#name::#variant_name( #(#tmp_names),*), __fd_offset)) - } - }; - match_variants.push(v); - } - Fields::Unit => { - let v = quote! { - #idx => Ok((#name::#variant_name, 0)), - }; - match_variants.push(v); - } - } - } - quote! 
{ - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let v = buffer.get(0).ok_or(msg_socket::MsgError::WrongMsgBufferSize)?; - match v { - #(#match_variants)* - _ => Err(msg_socket::MsgError::InvalidType), - } - } - } -} - -fn define_write_buffer_for_enum(name: &Ident, de: &DataEnum) -> TokenStream { - let mut match_variants = Vec::new(); - let de = de.clone(); - for (idx, variant) in de.variants.iter().enumerate() { - let idx = idx as u8; - let variant_name = &variant.ident; - match &variant.fields { - Fields::Named(fields) => { - let mut tmp_names = Vec::new(); - let mut write_tmps = Vec::new(); - for f in &fields.named { - tmp_names.push(f.ident.clone().unwrap()); - let write_tmp = - enum_write_to_buffer_and_move_offset(&f.ident.as_ref().unwrap()); - write_tmps.push(write_tmp); - } - - let v = quote! { - #name::#variant_name { #(#tmp_names),* } => { - buffer[0] = #idx; - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - #(#write_tmps)* - Ok(__fd_offset) - } - }; - match_variants.push(v); - } - Fields::Unnamed(fields) => { - let mut tmp_names = Vec::new(); - let mut write_tmps = Vec::new(); - for idx in 0..fields.unnamed.len() { - let tmp_name = format_ident!("enum_field{}", idx); - tmp_names.push(tmp_name.clone()); - let write_tmp = enum_write_to_buffer_and_move_offset(&tmp_name); - write_tmps.push(write_tmp); - } - - let v = quote! { - #name::#variant_name(#(#tmp_names),*) => { - buffer[0] = #idx; - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - #(#write_tmps)* - Ok(__fd_offset) - } - }; - match_variants.push(v); - } - Fields::Unit => { - let v = quote! { - #name::#variant_name => { - buffer[0] = #idx; - Ok(0) - } - }; - match_variants.push(v); - } - } - } - - quote! 
{ - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - if buffer.is_empty() { - return Err(msg_socket::MsgError::WrongMsgBufferSize) - } - match self { - #(#match_variants)* - } - } - } -} - -fn enum_write_to_buffer_and_move_offset(name: &Ident) -> TokenStream { - quote! { - let o = #name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += #name.msg_size(); - __fd_offset += o; - } -} - -/************************** Tuple Impls ********************************************/ -fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream { - let fields = get_tuple_fields(ds); - - let uses_fd_impl = define_uses_fd_for_tuples(&fields); - let buffer_sizes_impls = define_buffer_size_for_struct(&fields); - let read_buffer = define_read_buffer_for_tuples(&name, &fields); - let write_buffer = define_write_buffer_for_tuples(&name, &fields); - quote! { - impl msg_socket::MsgOnSocket for #name { - #uses_fd_impl - #buffer_sizes_impls - #read_buffer - #write_buffer - } - } -} - -fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> { - let mut field_idents = Vec::new(); - let fields = match ds.fields { - Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed, - _ => { - panic!("Tuple struct must have unnamed fields."); - } - }; - for (idx, field) in fields.iter().enumerate() { - let member = Member::Unnamed(Index::from(idx)); - let ty = field.ty.clone(); - field_idents.push(StructField { - member, - ty, - skipped: false, - }); - } - field_idents -} - -fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream { - if fields.is_empty() { - return quote!(); - } - - let field_types = fields.iter().map(|f| &f.ty); - quote! 
{ - fn uses_descriptor() -> bool { - #(<#field_types>::uses_descriptor())||* - } - } -} - -fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { - let mut read_fields = Vec::new(); - let mut init_fields = Vec::new(); - for (idx, field) in fields.iter().enumerate() { - let tmp_name = format!("tuple_tmp{}", idx); - let tmp_name = Ident::new(&tmp_name, Span::call_site()); - let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty); - read_fields.push(read_field); - init_fields.push(quote!(#tmp_name)); - } - - quote! { - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - #(#read_fields)* - Ok(( - #name ( - #(#init_fields),* - ), - __fd_offset - )) - } - } -} - -fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream { - let mut write_fields = Vec::new(); - let mut tmp_names = Vec::new(); - for idx in 0..fields.len() { - let tmp_name = format_ident!("tuple_tmp{}", idx); - let write_field = enum_write_to_buffer_and_move_offset(&tmp_name); - write_fields.push(write_field); - tmp_names.push(tmp_name); - } - quote! { - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - let #name( #(#tmp_names),* ) = self; - #(#write_fields)* - Ok(__fd_offset) - } - } -} -/************************** Helpers ********************************************/ -fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) { - let fields: Vec<_> = fields - .iter() - .filter(|f| !f.skipped) - .map(|f| &f.member) - .collect(); - if !fields.is_empty() { - ( - quote! { - #( self.#fields.msg_size() as usize )+* - }, - quote! 
{ - #( self.#fields.descriptor_count() as usize )+* - }, - ) - } else { - (quote!(0), quote!(0)) - } -} - -fn read_from_buffer_and_move_offset(name: &Ident, ty: &Type) -> TokenStream { - quote! { - let t = <#ty>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let #name = t.0; - } -} - -fn write_to_buffer_and_move_offset(name: &Ident) -> TokenStream { - quote! { - let o = self.#name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += self.#name.msg_size(); - __fd_offset += o; - } -} - -#[cfg(test)] -mod tests { - use crate::socket_msg_impl; - use quote::quote; - use syn::{parse_quote, DeriveInput}; - - #[test] - fn end_to_end_struct_test() { - let input: DeriveInput = parse_quote! { - struct MyMsg { - a: u8, - b: RawDescriptor, - c: u32, - } - }; - - let expected = quote! { - impl msg_socket::MsgOnSocket for MyMsg { - fn uses_descriptor() -> bool { - <u8>::uses_descriptor() - || <RawDescriptor>::uses_descriptor() - || <u32>::uses_descriptor() - } - fn msg_size(&self) -> usize { - self.a.msg_size() as usize - + self.b.msg_size() as usize - + self.c.msg_size() as usize - } - fn descriptor_count(&self) -> usize { - self.a.descriptor_count() as usize - + self.b.descriptor_count() as usize - + self.c.descriptor_count() as usize - } - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let a = t.0; - let t = <RawDescriptor>::read_from_buffer( - &buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let b = t.0; - let t = <u32>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let c = t.0; - Ok((Self { a, b, c }, 
__fd_offset)) - } - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - let o = self - .a - .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += self.a.msg_size(); - __fd_offset += o; - let o = self - .b - .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += self.b.msg_size(); - __fd_offset += o; - let o = self - .c - .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += self.c.msg_size(); - __fd_offset += o; - Ok(__fd_offset) - } - } - - }; - - assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); - } - - #[test] - fn end_to_end_tuple_struct_test() { - let input: DeriveInput = parse_quote! { - struct MyMsg(u8, u32, File); - }; - - let expected = quote! { - impl msg_socket::MsgOnSocket for MyMsg { - fn uses_descriptor() -> bool { - <u8>::uses_descriptor() || <u32>::uses_descriptor() || <File>::uses_descriptor() - } - fn msg_size(&self) -> usize { - self.0.msg_size() as usize - + self.1.msg_size() as usize + self.2.msg_size() as usize - } - fn descriptor_count(&self) -> usize { - self.0.descriptor_count() as usize - + self.1.descriptor_count() as usize - + self.2.descriptor_count() as usize - } - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let tuple_tmp0 = t.0; - let t = <u32>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let tuple_tmp1 = t.0; - let t = <File>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let tuple_tmp2 = t.0; - 
Ok((MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2), __fd_offset)) - } - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - let MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2) = self; - let o = tuple_tmp0.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += tuple_tmp0.msg_size(); - __fd_offset += o; - let o = tuple_tmp1.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += tuple_tmp1.msg_size(); - __fd_offset += o; - let o = tuple_tmp2.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += tuple_tmp2.msg_size(); - __fd_offset += o; - Ok(__fd_offset) - } - } - }; - - assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); - } - - #[test] - fn end_to_end_enum_test() { - let input: DeriveInput = parse_quote! { - enum MyMsg { - A(u8), - B, - C { - f0: u8, - f1: RawDescriptor, - }, - } - }; - - let expected = quote! 
{ - impl msg_socket::MsgOnSocket for MyMsg { - fn uses_descriptor() -> bool { - <u8>::uses_descriptor() - || <u8>::uses_descriptor() - || <RawDescriptor>::uses_descriptor() - } - fn msg_size(&self) -> usize { - 1 + match self { - MyMsg::A(enum_field0) => enum_field0.msg_size(), - MyMsg::B => 0, - MyMsg::C { f0, f1 } => f0.msg_size() + f1.msg_size(), - } - } - fn descriptor_count(&self) -> usize { - match self { - MyMsg::A(enum_field0) => enum_field0.descriptor_count(), - MyMsg::B => 0, - MyMsg::C { f0, f1 } => f0.descriptor_count() + f1.descriptor_count(), - } - } - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let v = buffer - .get(0) - .ok_or(msg_socket::MsgError::WrongMsgBufferSize)?; - match v { - 0u8 => { - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let enum_field0 = t.0; - Ok((MyMsg::A(enum_field0), __fd_offset)) - } - 1u8 => Ok((MyMsg::B, 0)), - 2u8 => { - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let f0 = t.0; - let t = <RawDescriptor>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?; - __offset += t.0.msg_size(); - __fd_offset += t.1; - let f1 = t.0; - Ok((MyMsg::C { f0, f1 }, __fd_offset)) - } - _ => Err(msg_socket::MsgError::InvalidType), - } - } - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - if buffer.is_empty() { - return Err(msg_socket::MsgError::WrongMsgBufferSize) - } - match self { - MyMsg::A(enum_field0) => { - buffer[0] = 0u8; - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - let o = enum_field0 - .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += 
enum_field0.msg_size(); - __fd_offset += o; - Ok(__fd_offset) - } - MyMsg::B => { - buffer[0] = 1u8; - Ok(0) - } - MyMsg::C { f0, f1 } => { - buffer[0] = 2u8; - let mut __offset = 1usize; - let mut __fd_offset = 0usize; - let o = f0.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += f0.msg_size(); - __fd_offset += o; - let o = f1.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?; - __offset += f1.msg_size(); - __fd_offset += o; - Ok(__fd_offset) - } - } - } - } - }; - - assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); - } - - #[test] - fn end_to_end_struct_skip_test() { - let input: DeriveInput = parse_quote! { - struct MyMsg { - #[msg_on_socket(skip)] - a: u8, - } - }; - - let expected = quote! { - impl msg_socket::MsgOnSocket for MyMsg { - fn msg_size(&self) -> usize { - 0 - } - fn descriptor_count(&self) -> usize { - 0 - } - unsafe fn read_from_buffer( - buffer: &[u8], - fds: &[RawDescriptor], - ) -> msg_socket::MsgResult<(Self, usize)> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - Ok((Self { a: <u8>::default() }, __fd_offset)) - } - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> msg_socket::MsgResult<usize> { - let mut __offset = 0usize; - let mut __fd_offset = 0usize; - Ok(__fd_offset) - } - } - - }; - - assert_eq!(socket_msg_impl(input).to_string(), expected.to_string()); - } -} diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs deleted file mode 100644 index d6f90773b..000000000 --- a/msg_socket/src/lib.rs +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -mod msg_on_socket; -mod serializable_descriptors; - -use std::io::{IoSlice, Result}; -use std::marker::PhantomData; - -use base::{ - handle_eintr, net::UnixSeqpacket, AsRawDescriptor, Error as SysError, RawDescriptor, ScmSocket, - UnsyncMarker, -}; -use cros_async::{Executor, IoSourceExt}; - -pub use crate::msg_on_socket::*; -pub use msg_on_socket_derive::*; - -/// Create a pair of socket. Request is send in one direction while response is in the other -/// direction. -pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>( -) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> { - let (sock1, sock2) = UnixSeqpacket::pair()?; - let requester = MsgSocket::new(sock1); - let responder = MsgSocket::new(sock2); - Ok((requester, responder)) -} - -/// Bidirection sock that support both send and recv. -pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> { - sock: UnixSeqpacket, - _i: PhantomData<I>, - _o: PhantomData<O>, - _unsync_marker: UnsyncMarker, -} - -impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> { - // Create a new MsgSocket. - pub fn new(s: UnixSeqpacket) -> MsgSocket<I, O> { - MsgSocket { - sock: s, - _i: PhantomData, - _o: PhantomData, - _unsync_marker: PhantomData, - } - } - - // Creates an async receiver that implements `futures::Stream`. - pub fn async_receiver(&self, ex: &Executor) -> MsgResult<AsyncReceiver<I, O>> { - AsyncReceiver::new(self, ex) - } -} - -/// One direction socket that only supports sending. -pub struct Sender<M: MsgOnSocket> { - sock: UnixSeqpacket, - _m: PhantomData<M>, -} - -impl<M: MsgOnSocket> Sender<M> { - /// Create a new sender sock. - pub fn new(s: UnixSeqpacket) -> Sender<M> { - Sender { - sock: s, - _m: PhantomData, - } - } -} - -/// One direction socket that only supports receiving. -pub struct Receiver<M: MsgOnSocket> { - sock: UnixSeqpacket, - _m: PhantomData<M>, -} - -impl<M: MsgOnSocket> Receiver<M> { - /// Create a new receiver sock. 
- pub fn new(s: UnixSeqpacket) -> Receiver<M> { - Receiver { - sock: s, - _m: PhantomData, - } - } -} - -impl<I: MsgOnSocket, O: MsgOnSocket> AsRef<UnixSeqpacket> for MsgSocket<I, O> { - fn as_ref(&self) -> &UnixSeqpacket { - &self.sock - } -} - -impl<I: MsgOnSocket, O: MsgOnSocket> AsRawDescriptor for MsgSocket<I, O> { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.sock.as_raw_descriptor() - } -} - -impl<I: MsgOnSocket, O: MsgOnSocket> AsRawDescriptor for &MsgSocket<I, O> { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.sock.as_raw_descriptor() - } -} - -impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Sender<M> { - fn as_ref(&self) -> &UnixSeqpacket { - &self.sock - } -} - -impl<M: MsgOnSocket> AsRawDescriptor for Sender<M> { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.sock.as_raw_descriptor() - } -} - -impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Receiver<M> { - fn as_ref(&self) -> &UnixSeqpacket { - &self.sock - } -} - -impl<M: MsgOnSocket> AsRawDescriptor for Receiver<M> { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.sock.as_raw_descriptor() - } -} - -/// Types that could send a message. -pub trait MsgSender: AsRef<UnixSeqpacket> { - type M: MsgOnSocket; - fn send(&self, msg: &Self::M) -> MsgResult<()> { - let msg_size = msg.msg_size(); - let descriptor_size = msg.descriptor_count(); - let mut msg_buffer: Vec<u8> = vec![0; msg_size]; - let mut descriptor_buffer: Vec<RawDescriptor> = vec![0; descriptor_size]; - - let descriptor_size = msg.write_to_buffer(&mut msg_buffer, &mut descriptor_buffer)?; - let sock: &UnixSeqpacket = self.as_ref(); - if descriptor_size == 0 { - handle_eintr!(sock.send(&msg_buffer)) - .map_err(|e| MsgError::Send(SysError::new(e.raw_os_error().unwrap_or(0))))?; - } else { - let ioslice = IoSlice::new(&msg_buffer[..]); - sock.send_with_fds(&[ioslice], &descriptor_buffer[0..descriptor_size]) - .map_err(MsgError::Send)?; - } - Ok(()) - } -} - -/// Types that could receive a message. 
-pub trait MsgReceiver: AsRef<UnixSeqpacket> { - type M: MsgOnSocket; - fn recv(&self) -> MsgResult<Self::M> { - let sock: &UnixSeqpacket = self.as_ref(); - - let (msg_buffer, descriptor_buffer) = { - if Self::M::uses_descriptor() { - sock.recv_as_vec_with_fds() - .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))? - } else { - ( - sock.recv_as_vec().map_err(|e| { - MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))) - })?, - vec![], - ) - } - }; - - if msg_buffer.is_empty() && Self::M::fixed_size() != Some(0) { - return Err(MsgError::RecvZero); - } - - if let Some(fixed_size) = Self::M::fixed_size() { - if fixed_size != msg_buffer.len() { - return Err(MsgError::BadRecvSize { - expected: fixed_size, - actual: msg_buffer.len(), - }); - } - } - - // Safe because fd buffer is read from socket. - let (v, read_descriptor_size) = - unsafe { Self::M::read_from_buffer(&msg_buffer, &descriptor_buffer)? }; - if descriptor_buffer.len() != read_descriptor_size { - return Err(MsgError::NotExpectDescriptor); - } - Ok(v) - } -} - -impl<I: MsgOnSocket, O: MsgOnSocket> MsgSender for MsgSocket<I, O> { - type M = I; -} -impl<I: MsgOnSocket, O: MsgOnSocket> MsgReceiver for MsgSocket<I, O> { - type M = O; -} - -impl<I: MsgOnSocket> MsgSender for Sender<I> { - type M = I; -} -impl<O: MsgOnSocket> MsgReceiver for Receiver<O> { - type M = O; -} - -/// Asynchronous adaptor for `MsgSocket`. -pub struct AsyncReceiver<'m, I: MsgOnSocket, O: MsgOnSocket> { - // This weirdness is because we can't directly implement IntoAsync for &MsgSocket because there - // is no AsRawFd impl for references. 
- inner: &'m MsgSocket<I, O>, - sock: Box<dyn IoSourceExt<&'m UnixSeqpacket> + 'm>, -} - -impl<'m, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'m, I, O> { - fn new(msg_socket: &'m MsgSocket<I, O>, ex: &Executor) -> MsgResult<Self> { - let sock = ex - .async_from(&msg_socket.sock) - .map_err(MsgError::CreateAsync)?; - Ok(AsyncReceiver { - inner: msg_socket, - sock, - }) - } - - pub async fn next(&mut self) -> MsgResult<O> { - self.sock.wait_readable().await.unwrap(); - self.inner.recv() - } -} diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs deleted file mode 100644 index 5db521772..000000000 --- a/msg_socket/src/msg_on_socket.rs +++ /dev/null @@ -1,438 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -mod slice; -mod tuple; - -use std::fmt::{self, Display}; -use std::mem::{size_of, transmute_copy, MaybeUninit}; -use std::result; -use std::sync::Arc; - -use base::{Error as SysError, RawDescriptor}; -use data_model::*; -use slice::{slice_read_helper, slice_write_helper}; - -#[derive(Debug)] -/// An error during transaction or serialization/deserialization. -pub enum MsgError { - /// Error while creating an async socket. - CreateAsync(cros_async::AsyncError), - /// Error while sending a request or response. - Send(SysError), - /// Error while receiving a request or response. - Recv(SysError), - /// The type of a received request or response is unknown. - InvalidType, - /// There was not the expected amount of data when receiving a message. The inner - /// value is how much data is expected and how much data was actually received. - BadRecvSize { expected: usize, actual: usize }, - /// There was no data received when the socket `recv`-ed. - RecvZero, - /// There was no associated file descriptor received for a request that expected it. 
- ExpectDescriptor, - /// There was some associated file descriptor received but not used when deserialize. - NotExpectDescriptor, - /// Failed to set flags on the file descriptor. - SettingDescriptorFlags(SysError), - /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens - /// when max_fd_count() returns a value that is too small. - WrongDescriptorBufferSize, - /// Trying to serialize/deserialize, but msg buffer size is too small. This typically happens - /// when msg_size() returns a value that is too small. - WrongMsgBufferSize, -} - -pub type MsgResult<T> = result::Result<T, MsgError>; - -impl Display for MsgError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use self::MsgError::*; - - match self { - CreateAsync(e) => write!(f, "failed to create an async socket: {}", e), - Send(e) => write!(f, "failed to send request or response: {}", e), - Recv(e) => write!(f, "failed to receive request or response: {}", e), - InvalidType => write!(f, "invalid type"), - BadRecvSize { expected, actual } => write!( - f, - "wrong amount of data received; expected {} bytes; got {} bytes", - expected, actual - ), - RecvZero => write!(f, "received zero data"), - ExpectDescriptor => write!(f, "missing associated file descriptor for request"), - NotExpectDescriptor => write!(f, "unexpected file descriptor is unused"), - SettingDescriptorFlags(e) => { - write!(f, "failed setting flags on the message descriptor: {}", e) - } - WrongDescriptorBufferSize => write!(f, "descriptor buffer size too small"), - WrongMsgBufferSize => write!(f, "msg buffer size too small"), - } - } -} - -/// A msg that could be serialized to and deserialize from array in little endian. -/// -/// For structs, we always have fixed size of bytes and fixed count of fds. -/// For enums, the size needed might be different for each variant. -/// -/// e.g. 
-/// ``` -/// use base::RawDescriptor; -/// enum Message { -/// VariantA(u8), -/// VariantB(u32, RawDescriptor), -/// VariantC, -/// } -/// ``` -/// -/// For variant A, we need 1 byte to store its inner value. -/// For variant B, we need 4 bytes and 1 RawDescriptor to store its inner value. -/// For variant C, we need 0 bytes to store its inner value. -/// When we serialize Message to (buffer, fd_buffer), we always use fixed number of bytes in -/// the buffer. Unused buffer bytes will be padded with zero. -/// However, for fd_buffer, we could not do the same thing. Otherwise, we are essentially sending -/// fd 0 through the socket. -/// Thus, read/write functions always the return correct count of fds in this variant. There will be -/// no padding in fd_buffer. -pub trait MsgOnSocket: Sized { - // `true` if this structure can potentially serialize descriptors. - fn uses_descriptor() -> bool { - false - } - - // Returns `Some(size)` if this structure always has a fixed size. - fn fixed_size() -> Option<usize> { - None - } - - /// Size of message in bytes. - fn msg_size(&self) -> usize { - Self::fixed_size().unwrap() - } - - /// Number of FDs in this message. This must be overridden if `uses_descriptor()` returns true. - fn descriptor_count(&self) -> usize { - assert!(!Self::uses_descriptor()); - 0 - } - /// Returns (self, fd read count). - /// This function is safe only when: - /// 0. fds contains valid fds, received from socket, serialized by Self::write_to_buffer. - /// 1. For enum, fds contains correct fd layout of the particular variant. - /// 2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding, - /// return correct count). - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)>; - - /// Serialize self to buffers. 
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize>; -} - -impl MsgOnSocket for SysError { - fn fixed_size() -> Option<usize> { - Some(size_of::<u32>()) - } - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - let (v, size) = u32::read_from_buffer(buffer, fds)?; - Ok((SysError::new(v as i32), size)) - } - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> { - let v = self.errno() as u32; - v.write_to_buffer(buffer, fds) - } -} - -impl<T: MsgOnSocket> MsgOnSocket for Option<T> { - fn uses_descriptor() -> bool { - T::uses_descriptor() - } - - fn msg_size(&self) -> usize { - match self { - Some(v) => v.msg_size() + 1, - None => 1, - } - } - - fn descriptor_count(&self) -> usize { - match self { - Some(v) => v.descriptor_count(), - None => 0, - } - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - match buffer[0] { - 0 => Ok((None, 0)), - 1 => { - let (inner, len) = T::read_from_buffer(&buffer[1..], fds)?; - Ok((Some(inner), len)) - } - _ => Err(MsgError::InvalidType), - } - } - - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> { - match self { - None => { - buffer[0] = 0; - Ok(0) - } - Some(inner) => { - buffer[0] = 1; - inner.write_to_buffer(&mut buffer[1..], fds) - } - } - } -} - -impl<T: MsgOnSocket> MsgOnSocket for Arc<T> { - fn uses_descriptor() -> bool { - T::uses_descriptor() - } - - fn msg_size(&self) -> usize { - (**self).msg_size() - } - - fn descriptor_count(&self) -> usize { - (**self).descriptor_count() - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - T::read_from_buffer(buffer, fds).map(|(v, count)| (Arc::new(v), count)) - } - - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> { - (**self).write_to_buffer(buffer, fds) - } -} - -impl 
MsgOnSocket for () { - fn fixed_size() -> Option<usize> { - Some(0) - } - - unsafe fn read_from_buffer(_buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - Ok(((), 0)) - } - - fn write_to_buffer(&self, _buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> { - Ok(0) - } -} - -// usize could be different sizes on different targets. We always use u64. -impl MsgOnSocket for usize { - fn msg_size(&self) -> usize { - size_of::<u64>() - } - unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - if buffer.len() < size_of::<u64>() { - return Err(MsgError::WrongMsgBufferSize); - } - let t = u64::from_le_bytes(slice_to_array(buffer)); - Ok((t as usize, 0)) - } - - fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> { - if buffer.len() < size_of::<u64>() { - return Err(MsgError::WrongMsgBufferSize); - } - let t: Le64 = (*self as u64).into(); - buffer[0..self.msg_size()].copy_from_slice(t.as_slice()); - Ok(0) - } -} - -// Encode bool as a u8 of value 0 or 1 -impl MsgOnSocket for bool { - fn msg_size(&self) -> usize { - size_of::<u8>() - } - unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - if buffer.len() < size_of::<u8>() { - return Err(MsgError::WrongMsgBufferSize); - } - let t: u8 = buffer[0]; - match t { - 0 => Ok((false, 0)), - 1 => Ok((true, 0)), - _ => Err(MsgError::InvalidType), - } - } - fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> { - if buffer.len() < size_of::<u8>() { - return Err(MsgError::WrongMsgBufferSize); - } - buffer[0] = *self as u8; - Ok(0) - } -} - -macro_rules! 
le_impl { - ($type:ident, $native_type:ident) => { - impl MsgOnSocket for $type { - fn fixed_size() -> Option<usize> { - Some(size_of::<$native_type>()) - } - - unsafe fn read_from_buffer( - buffer: &[u8], - _fds: &[RawDescriptor], - ) -> MsgResult<(Self, usize)> { - if buffer.len() < size_of::<$native_type>() { - return Err(MsgError::WrongMsgBufferSize); - } - let t = $native_type::from_le_bytes(slice_to_array(buffer)); - Ok((t.into(), 0)) - } - - fn write_to_buffer( - &self, - buffer: &mut [u8], - _fds: &mut [RawDescriptor], - ) -> MsgResult<usize> { - if buffer.len() < size_of::<$native_type>() { - return Err(MsgError::WrongMsgBufferSize); - } - let t: $native_type = self.clone().into(); - buffer[0..self.msg_size()].copy_from_slice(&t.to_le_bytes()); - Ok(0) - } - } - }; -} - -le_impl!(u8, u8); -le_impl!(u16, u16); -le_impl!(u32, u32); -le_impl!(u64, u64); - -le_impl!(Le16, u16); -le_impl!(Le32, u32); -le_impl!(Le64, u64); - -fn simple_read<T: MsgOnSocket>(buffer: &[u8], offset: &mut usize) -> MsgResult<T> { - assert!(!T::uses_descriptor()); - // Safety for T::read_from_buffer depends on the given FDs being valid, but we pass no FDs. - let (v, _) = unsafe { T::read_from_buffer(&buffer[*offset..], &[])? }; - *offset += v.msg_size(); - Ok(v) -} - -fn simple_write<T: MsgOnSocket>(v: T, buffer: &mut [u8], offset: &mut usize) -> MsgResult<()> { - assert!(!T::uses_descriptor()); - v.write_to_buffer(&mut buffer[*offset..], &mut [])?; - *offset += v.msg_size(); - Ok(()) -} - -// Converts a slice into an array of fixed size inferred from by the return value. Panics if the -// slice is too small, but will tolerate slices that are too large. -fn slice_to_array<T, O>(s: &[T]) -> O -where - T: Copy, - O: Default + AsMut<[T]>, -{ - let mut o = O::default(); - let o_slice = o.as_mut(); - let len = o_slice.len(); - o_slice.copy_from_slice(&s[..len]); - o -} - -macro_rules! 
array_impls { - ($N:expr, $t: ident $($ts:ident)*) - => { - impl<T: MsgOnSocket + Clone> MsgOnSocket for [T; $N] { - fn uses_descriptor() -> bool { - T::uses_descriptor() - } - - fn fixed_size() -> Option<usize> { - Some(T::fixed_size()? * $N) - } - - fn msg_size(&self) -> usize { - match T::fixed_size() { - Some(s) => s * $N, - None => self.iter().map(|i| i.msg_size()).sum::<usize>() + size_of::<u64>() * $N - } - } - - fn descriptor_count(&self) -> usize { - if T::uses_descriptor() { - self.iter().map(|i| i.descriptor_count()).sum() - } else { - 0 - } - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) - -> MsgResult<(Self, usize)> { - // Taken from the canonical example of initializing an array, the `assume_init` can - // be assumed safe because the array elements (`MaybeUninit<T>` in this case) - // themselves don't require initializing. - let mut msgs: [MaybeUninit<T>; $N] = MaybeUninit::uninit().assume_init(); - - let fd_count = slice_read_helper(buffer, fds, &mut msgs)?; - - // Also taken from the canonical example, we initialized every member of the array - // in the first loop of this function, so it is safe to `transmute_copy` the array - // of `MaybeUninit` data to plain data. Although `transmute`, which checks the - // types' sizes, would have been preferred in this code, the compiler complains with - // "cannot transmute between types of different sizes, or dependently-sized types." - // Because this function operates on generic data, the type is "dependently-sized" - // and so the compiler will not check that the size of the input and output match. 
- // See this issue for details: https://github.com/rust-lang/rust/issues/61956 - Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_count)) - } - - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> MsgResult<usize> { - slice_write_helper(self, buffer, fds) - } - } - #[cfg(test)] - mod $t { - use super::MsgOnSocket; - - #[test] - fn read_write_option_array() { - type ArrayType = [Option<u32>; $N]; - let array = [Some($N); $N]; - let mut buffer = vec![0; array.msg_size()]; - array.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0; - - assert!(array.iter().eq(read_array.iter())); - } - - #[test] - fn read_write_fixed() { - type ArrayType = [u32; $N]; - let mut buffer = vec![0; <ArrayType>::fixed_size().unwrap()]; - let array = [$N as u32; $N]; - array.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0; - - assert!(array.iter().eq(read_array.iter())); - } - } - array_impls!(($N - 1), $($ts)*); - }; - {$N:expr, } => {}; -} - -array_impls! { - 64, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16 - tmp17 tmp18 tmp19 tmp20 tmp21 tmp22 tmp23 tmp24 tmp25 tmp26 tmp27 tmp28 tmp29 tmp30 tmp31 - tmp32 tmp33 tmp34 tmp35 tmp36 tmp37 tmp38 tmp39 tmp40 tmp41 tmp42 tmp43 tmp44 tmp45 tmp46 - tmp47 tmp48 tmp49 tmp50 tmp51 tmp52 tmp53 tmp54 tmp55 tmp56 tmp57 tmp58 tmp59 tmp60 tmp61 - tmp62 tmp63 tmp64 -} diff --git a/msg_socket/src/msg_on_socket/slice.rs b/msg_socket/src/msg_on_socket/slice.rs deleted file mode 100644 index 05df5e142..000000000 --- a/msg_socket/src/msg_on_socket/slice.rs +++ /dev/null @@ -1,184 +0,0 @@ -// Copyright 2020 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -use base::RawDescriptor; -use std::mem::{size_of, ManuallyDrop, MaybeUninit}; -use std::ptr::drop_in_place; - -use crate::{MsgOnSocket, MsgResult}; - -use super::{simple_read, simple_write}; - -/// Helper used by the types that read a slice of homegenously typed data. -/// -/// # Safety -/// This function has the same safety requirements as `T::read_from_buffer`, with the additional -/// requirements that the `msgs` are only used on success of this function -pub unsafe fn slice_read_helper<T: MsgOnSocket>( - buffer: &[u8], - fds: &[RawDescriptor], - msgs: &mut [MaybeUninit<T>], -) -> MsgResult<usize> { - let mut offset = 0usize; - let mut fd_offset = 0usize; - - // In case of an error, we need to keep track of how many elements got initialized. - // In order to perform the necessary drops, the below loop is executed in a closure - // to capture errors without returning. - let mut last_index = 0; - let res = (|| { - for msg in &mut msgs[..] { - let element_size = match T::fixed_size() { - Some(s) => s, - None => simple_read::<u64>(buffer, &mut offset)? as usize, - }; - // Assuming the unsafe caller gave valid FDs, this call should be safe. - let (m, fd_size) = T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?; - *msg = MaybeUninit::new(m); - offset += element_size; - fd_offset += fd_size; - last_index += 1; - } - Ok(()) - })(); - - // Because `MaybeUninit` will not automatically call drops, we have to drop the - // partially initialized array manually in the case of an error. - if let Err(e) = res { - for msg in &mut msgs[..last_index] { - // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array - // into a pointer, which can be used with `drop_in_place` to call the - // destructor without moving the element, which is impossible. This is safe - // because `last_index` prevents this loop from traversing into the - // uninitialized parts of the array. 
- drop_in_place(msg.as_mut_ptr()); - } - return Err(e); - } - - Ok(fd_offset) -} - -/// Helper used by the types that write a slice of homegenously typed data. -pub fn slice_write_helper<T: MsgOnSocket>( - msgs: &[T], - buffer: &mut [u8], - fds: &mut [RawDescriptor], -) -> MsgResult<usize> { - let mut offset = 0usize; - let mut fd_offset = 0usize; - for msg in msgs { - let element_size = match T::fixed_size() { - Some(s) => s, - None => { - let element_size = msg.msg_size(); - simple_write(element_size as u64, buffer, &mut offset)?; - element_size as usize - } - }; - let fd_size = msg.write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?; - offset += element_size; - fd_offset += fd_size; - } - - Ok(fd_offset) -} - -impl<T: MsgOnSocket> MsgOnSocket for Vec<T> { - fn uses_descriptor() -> bool { - T::uses_descriptor() - } - - fn fixed_size() -> Option<usize> { - None - } - - fn msg_size(&self) -> usize { - let vec_size = match T::fixed_size() { - Some(s) => s * self.len(), - None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(), - }; - size_of::<u64>() + vec_size - } - - fn descriptor_count(&self) -> usize { - if T::uses_descriptor() { - self.iter().map(|i| i.descriptor_count()).sum() - } else { - 0 - } - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - let mut offset = 0; - let len = simple_read::<u64>(buffer, &mut offset)? 
as usize; - let mut msgs: Vec<MaybeUninit<T>> = Vec::with_capacity(len); - msgs.set_len(len); - let fd_count = slice_read_helper(&buffer[offset..], fds, &mut msgs)?; - let mut msgs = ManuallyDrop::new(msgs); - Ok(( - Vec::from_raw_parts(msgs.as_mut_ptr() as *mut T, msgs.len(), msgs.capacity()), - fd_count, - )) - } - - fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> { - let mut offset = 0; - simple_write(self.len() as u64, buffer, &mut offset)?; - slice_write_helper(self, &mut buffer[offset..], fds) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn read_write_1_fixed() { - let vec = vec![1u32]; - let mut buffer = vec![0; vec.msg_size()]; - vec.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) } - .unwrap() - .0; - - assert_eq!(vec, read_vec); - } - - #[test] - fn read_write_8_fixed() { - let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34]; - let mut buffer = vec![0; vec.msg_size()]; - vec.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) } - .unwrap() - .0; - assert_eq!(vec, read_vec); - } - - #[test] - fn read_write_1() { - let vec = vec![Some(1u64)]; - let mut buffer = vec![0; vec.msg_size()]; - println!("{:?}", vec.msg_size()); - vec.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) } - .unwrap() - .0; - - assert_eq!(vec, read_vec); - } - - #[test] - fn read_write_4() { - let vec = vec![Some(12u16), Some(0), None, None]; - let mut buffer = vec![0; vec.msg_size()]; - vec.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) } - .unwrap() - .0; - - assert_eq!(vec, read_vec); - } -} diff --git a/msg_socket/src/msg_on_socket/tuple.rs b/msg_socket/src/msg_on_socket/tuple.rs deleted file mode 100644 index 90784bfef..000000000 --- 
a/msg_socket/src/msg_on_socket/tuple.rs +++ /dev/null @@ -1,205 +0,0 @@ -// Copyright 2020 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -use base::RawDescriptor; -use std::mem::size_of; - -use crate::{MsgOnSocket, MsgResult}; - -use super::{simple_read, simple_write}; - -// Returns the size of one part of a tuple. -fn tuple_size_helper<T: MsgOnSocket>(v: &T) -> usize { - T::fixed_size().unwrap_or_else(|| v.msg_size() + size_of::<u64>()) -} - -unsafe fn tuple_read_helper<T: MsgOnSocket>( - buffer: &[u8], - fds: &[RawDescriptor], - buffer_index: &mut usize, - fd_index: &mut usize, -) -> MsgResult<T> { - let end = match T::fixed_size() { - Some(_) => buffer.len(), - None => { - let len = simple_read::<u64>(buffer, buffer_index)? as usize; - *buffer_index + len - } - }; - let (v, fd_read) = T::read_from_buffer(&buffer[*buffer_index..end], &fds[*fd_index..])?; - *buffer_index += v.msg_size(); - *fd_index += fd_read; - Ok(v) -} - -fn tuple_write_helper<T: MsgOnSocket>( - v: &T, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - buffer_index: &mut usize, - fd_index: &mut usize, -) -> MsgResult<()> { - let end = match T::fixed_size() { - Some(_) => buffer.len(), - None => { - let len = v.msg_size(); - simple_write(len as u64, buffer, buffer_index)?; - *buffer_index + len - } - }; - let fd_written = v.write_to_buffer(&mut buffer[*buffer_index..end], &mut fds[*fd_index..])?; - *buffer_index += v.msg_size(); - *fd_index += fd_written; - Ok(()) -} - -macro_rules! 
tuple_impls { - () => {}; - ($t: ident) => { - #[allow(unused_variables, non_snake_case)] - impl<$t: MsgOnSocket> MsgOnSocket for ($t,) { - fn uses_descriptor() -> bool { - $t::uses_descriptor() - } - - fn descriptor_count(&self) -> usize { - self.0.descriptor_count() - } - - fn fixed_size() -> Option<usize> { - $t::fixed_size() - } - - fn msg_size(&self) -> usize { - self.0.msg_size() - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - let (t, s) = $t::read_from_buffer(buffer, fds)?; - Ok(((t,), s)) - } - - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> MsgResult<usize> { - self.0.write_to_buffer(buffer, fds) - } - } - }; - ($t: ident, $($ts:ident),*) => { - #[allow(unused_variables, non_snake_case)] - impl<$t: MsgOnSocket $(, $ts: MsgOnSocket)*> MsgOnSocket for ($t$(, $ts)*) { - fn uses_descriptor() -> bool { - $t::uses_descriptor() $(|| $ts::uses_descriptor())* - } - - fn descriptor_count(&self) -> usize { - if Self::uses_descriptor() { - return 0; - } - let ($t $(,$ts)*) = self; - $t.descriptor_count() $(+ $ts.descriptor_count())* - } - - fn fixed_size() -> Option<usize> { - // Returns None if any element is not fixed size. - Some($t::fixed_size()? $(+ $ts::fixed_size()?)*) - } - - fn msg_size(&self) -> usize { - if let Some(size) = Self::fixed_size() { - return size - } - - let ($t $(,$ts)*) = self; - tuple_size_helper($t) $(+ tuple_size_helper($ts))* - } - - unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> { - let mut buffer_index = 0; - let mut fd_index = 0; - Ok(( - ( - tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?, - $({ - // Dummy let used to trigger the correct number of iterations. - let $ts = (); - tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)? 
- },)* - ), - fd_index - )) - } - - fn write_to_buffer( - &self, - buffer: &mut [u8], - fds: &mut [RawDescriptor], - ) -> MsgResult<usize> { - let mut buffer_index = 0; - let mut fd_index = 0; - let ($t $(,$ts)*) = self; - tuple_write_helper($t, buffer, fds, &mut buffer_index, &mut fd_index)?; - $( - tuple_write_helper($ts, buffer, fds, &mut buffer_index, &mut fd_index)?; - )* - Ok(fd_index) - } - } - tuple_impls!{ $($ts),* } - } -} - -// Imlpement tuple for up to 8 elements. -tuple_impls! { A, B, C, D, E, F, G, H } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn read_write_1_fixed() { - let tuple = (1,); - let mut buffer = vec![0; tuple.msg_size()]; - tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_tuple = unsafe { <(u32,)>::read_from_buffer(&buffer, &[]) } - .unwrap() - .0; - - assert_eq!(tuple, read_tuple); - } - - #[test] - fn read_write_8_fixed() { - let tuple = (1u32, 2u8, 3u16, 4u64, 5u32, 6u16, 7u8, 8u8); - let mut buffer = vec![0; tuple.msg_size()]; - tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; - - assert_eq!(tuple, read_tuple); - } - - #[test] - fn read_write_1() { - let tuple = (Some(1u64),); - let mut buffer = vec![0; tuple.msg_size()]; - tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; - - assert_eq!(tuple, read_tuple); - } - - #[test] - fn read_write_4() { - let tuple = (Some(12u16), Some(false), None::<u8>, None::<u64>); - let mut buffer = vec![0; tuple.msg_size()]; - println!("{:?}", tuple.msg_size()); - tuple.write_to_buffer(&mut buffer, &mut []).unwrap(); - let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0; - - assert_eq!(tuple, read_tuple); - } -} diff --git a/msg_socket/src/serializable_descriptors.rs b/msg_socket/src/serializable_descriptors.rs deleted file mode 100644 index 21ce84b7e..000000000 --- 
a/msg_socket/src/serializable_descriptors.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2020 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -use crate::msg_on_socket::{MsgError, MsgOnSocket, MsgResult}; -use base::{AsRawDescriptor, Event, FromRawDescriptor, RawDescriptor}; -use std::fs::File; -use std::net::{TcpListener, TcpStream, UdpSocket}; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; -use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream}; - -macro_rules! rawdescriptor_impl { - ($type:ident) => { - impl MsgOnSocket for $type { - fn uses_descriptor() -> bool { - true - } - fn msg_size(&self) -> usize { - 0 - } - fn descriptor_count(&self) -> usize { - 1 - } - unsafe fn read_from_buffer( - _buffer: &[u8], - descriptors: &[RawDescriptor], - ) -> MsgResult<(Self, usize)> { - if descriptors.len() < 1 { - return Err(MsgError::ExpectDescriptor); - } - Ok(($type::from_raw_descriptor(descriptors[0]), 1)) - } - fn write_to_buffer( - &self, - _buffer: &mut [u8], - descriptors: &mut [RawDescriptor], - ) -> MsgResult<usize> { - if descriptors.is_empty() { - return Err(MsgError::WrongDescriptorBufferSize); - } - descriptors[0] = self.as_raw_descriptor(); - Ok(1) - } - } - }; -} - -rawdescriptor_impl!(Event); -rawdescriptor_impl!(File); - -macro_rules! 
rawfd_impl { - ($type:ident) => { - impl MsgOnSocket for $type { - fn uses_descriptor() -> bool { - true - } - fn msg_size(&self) -> usize { - 0 - } - fn descriptor_count(&self) -> usize { - 1 - } - unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> { - if fds.len() < 1 { - return Err(MsgError::ExpectDescriptor); - } - Ok(($type::from_raw_fd(fds[0]), 1)) - } - fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> { - if fds.is_empty() { - return Err(MsgError::WrongDescriptorBufferSize); - } - fds[0] = self.as_raw_fd(); - Ok(1) - } - } - }; -} - -rawfd_impl!(UnixStream); -rawfd_impl!(TcpStream); -rawfd_impl!(TcpListener); -rawfd_impl!(UdpSocket); -rawfd_impl!(UnixListener); -rawfd_impl!(UnixDatagram); diff --git a/msg_socket/tests/enum.rs b/msg_socket/tests/enum.rs deleted file mode 100644 index 7f5998bdc..000000000 --- a/msg_socket/tests/enum.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2019 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -use base::{Event, RawDescriptor}; - -use msg_socket::*; - -#[derive(MsgOnSocket)] -struct DummyRequest {} - -#[derive(MsgOnSocket)] -enum Response { - A(u8), - B, - C(u32, Event), - D([u8; 4]), - E { f0: u8, f1: u32 }, -} - -#[test] -fn sock_send_recv_enum() { - let (req, res) = pair::<DummyRequest, Response>().unwrap(); - let e0 = Event::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - res.send(&Response::C(0xf0f0, e0)).unwrap(); - let r = req.recv().unwrap(); - match r { - Response::C(v, efd) => { - assert_eq!(v, 0xf0f0); - efd.write(0x0f0f).unwrap(); - } - _ => panic!("wrong type"), - }; - assert_eq!(e1.read().unwrap(), 0x0f0f); - - res.send(&Response::B).unwrap(); - match req.recv().unwrap() { - Response::B => {} - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::A(0x3)).unwrap(); - match req.recv().unwrap() { - Response::A(v) => assert_eq!(v, 0x3), - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::D([0, 1, 2, 3])).unwrap(); - match req.recv().unwrap() { - Response::D(v) => assert_eq!(v, [0, 1, 2, 3]), - _ => panic!("Wrong enum type"), - }; - - res.send(&Response::E { - f0: 0x12, - f1: 0x0f0f, - }) - .unwrap(); - match req.recv().unwrap() { - Response::E { f0, f1 } => { - assert_eq!(f0, 0x12); - assert_eq!(f1, 0x0f0f); - } - _ => panic!("Wrong enum type"), - }; -} diff --git a/msg_socket/tests/struct.rs b/msg_socket/tests/struct.rs deleted file mode 100644 index 8e3c93f4a..000000000 --- a/msg_socket/tests/struct.rs +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2019 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -use base::{Event, RawDescriptor}; - -use msg_socket::*; - -#[derive(MsgOnSocket)] -struct Request { - field0: u8, - field1: Event, - field2: u32, - field3: bool, -} - -#[derive(MsgOnSocket)] -struct DummyResponse {} - -#[test] -fn sock_send_recv_struct() { - let (req, res) = pair::<Request, DummyResponse>().unwrap(); - let e0 = Event::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - req.send(&Request { - field0: 2, - field1: e0, - field2: 0xf0f0, - field3: true, - }) - .unwrap(); - let r = res.recv().unwrap(); - assert_eq!(r.field0, 2); - assert_eq!(r.field2, 0xf0f0); - assert_eq!(r.field3, true); - r.field1.write(0x0f0f).unwrap(); - assert_eq!(e1.read().unwrap(), 0x0f0f); -} diff --git a/msg_socket/tests/tuple.rs b/msg_socket/tests/tuple.rs deleted file mode 100644 index ae008135f..000000000 --- a/msg_socket/tests/tuple.rs +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2019 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -use base::{Event, RawDescriptor}; -use msg_socket::*; - -#[derive(MsgOnSocket)] -struct Message(u8, u16, Event); - -#[test] -fn sock_send_recv_tuple() { - let (req, res) = pair::<Message, Message>().unwrap(); - let e0 = Event::new().unwrap(); - let e1 = e0.try_clone().unwrap(); - req.send(&Message(1, 0x12, e0)).unwrap(); - let r = res.recv().unwrap(); - assert_eq!(r.0, 1); - assert_eq!(r.1, 0x12); - r.2.write(0x0f0f).unwrap(); - assert_eq!(e1.read().unwrap(), 0x0f0f); -} diff --git a/msg_socket/tests/unit.rs b/msg_socket/tests/unit.rs deleted file mode 100644 index 9855752f8..000000000 --- a/msg_socket/tests/unit.rs +++ /dev/null @@ -1,12 +0,0 @@ -// Copyright 2018 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. 
- -use msg_socket::*; - -#[test] -fn sock_send_recv_unit() { - let (req, res) = pair::<(), ()>().unwrap(); - req.send(&()).unwrap(); - let _ = res.recv().unwrap(); -} diff --git a/resources/Cargo.toml b/resources/Cargo.toml index 041da9f3c..ca1d4b882 100644 --- a/resources/Cargo.toml +++ b/resources/Cargo.toml @@ -6,5 +6,5 @@ edition = "2018" [dependencies] libc = "*" -msg_socket = { path = "../msg_socket" } base = { path = "../base" } +serde = { version = "1", features = ["derive"] } diff --git a/resources/src/lib.rs b/resources/src/lib.rs index 0a2518199..0983ffaac 100644 --- a/resources/src/lib.rs +++ b/resources/src/lib.rs @@ -4,14 +4,10 @@ //! Manages system resources that can be allocated to VMs and their devices. -extern crate base; -extern crate libc; -extern crate msg_socket; - -use base::RawDescriptor; -use msg_socket::MsgOnSocket; use std::fmt::Display; +use serde::{Deserialize, Serialize}; + pub use crate::address_allocator::AddressAllocator; pub use crate::system_allocator::{MmioType, SystemAllocator}; @@ -19,7 +15,7 @@ mod address_allocator; mod system_allocator; /// Used to tag SystemAllocator allocations. -#[derive(Debug, Eq, PartialEq, Hash, MsgOnSocket, Copy, Clone)] +#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)] pub enum Alloc { /// An anonymous resource allocation. /// Should only be instantiated through `SystemAllocator::get_anon_alloc()`. 
@@ -11,10 +11,9 @@ from typing import List, Dict from ci.test_runner import Requirements, main # A list of all crates and their requirements +# See ci/test_runner.py for documentation on the requirements CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = { "aarch64": [Requirements.AARCH64], - "crosvm": [Requirements.DISABLED], - "aarch64": [Requirements.AARCH64], "acpi_tables": [], "arch": [], "assertions": [], @@ -22,6 +21,7 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = { "bit_field": [], "bit_field_derive": [], "cros_async": [Requirements.DISABLED], + "crosvm": [Requirements.DO_NOT_RUN], "crosvm_plugin": [Requirements.X86_64], "data_model": [], "devices": [ @@ -29,20 +29,24 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = { Requirements.PRIVILEGED, Requirements.X86_64, ], - "disk": [Requirements.DISABLED], + "disk": [Requirements.PRIVILEGED], "enumn": [], "fuse": [], "fuzz": [Requirements.DISABLED], "gpu_display": [], "hypervisor": [Requirements.PRIVILEGED, Requirements.X86_64], - "io_uring": [Requirements.DISABLED], + "integration_tests": [Requirements.PRIVILEGED, Requirements.X86_64], + "io_uring": [ + Requirements.SEPARATE_WORKSPACE, + Requirements.PRIVILEGED, + Requirements.SINGLE_THREADED, + ], "kernel_cmdline": [], "kernel_loader": [Requirements.PRIVILEGED], "kvm_sys": [Requirements.PRIVILEGED], "kvm": [Requirements.PRIVILEGED, Requirements.X86_64], + "libcrosvm_control": [], "linux_input_sys": [], - "msg_socket": [Requirements.PRIVILEGED], - "msg_on_socket_derive": [], "net_sys": [], "net_util": [Requirements.PRIVILEGED], "power_monitor": [], @@ -50,11 +54,10 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = { "qcow_utils": [], "rand_ish": [], "resources": [], - "rutabaga_gfx": [Requirements.CROS_BUILD, Requirements.X86_64], + "rutabaga_gfx": [Requirements.CROS_BUILD, Requirements.PRIVILEGED], "sync": [], "sys_util": [Requirements.SINGLE_THREADED, Requirements.PRIVILEGED], "poll_token_derive": [], - "syscall_defines": [], "tempfile": 
[], "tpm2-sys": [], "tpm2": [], @@ -64,7 +67,7 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = { "vhost": [Requirements.PRIVILEGED], "virtio_sys": [], "vm_control": [], - "vm_memory": [Requirements.DISABLED], + "vm_memory": [Requirements.PRIVILEGED], "x86_64": [Requirements.X86_64, Requirements.PRIVILEGED], } diff --git a/rutabaga_gfx/src/cross_domain/cross_domain.rs b/rutabaga_gfx/src/cross_domain/cross_domain.rs index 5ab738715..a7b3ac999 100644 --- a/rutabaga_gfx/src/cross_domain/cross_domain.rs +++ b/rutabaga_gfx/src/cross_domain/cross_domain.rs @@ -138,6 +138,7 @@ impl RutabagaContext for CrossDomainContext { &mut self, resource_id: u32, resource_create_blob: ResourceCreateBlob, + handle: Option<RutabagaHandle>, ) -> RutabagaResult<RutabagaResource> { let reqs = self .requirements_blobs @@ -152,7 +153,11 @@ impl RutabagaContext for CrossDomainContext { // create blob function, which says "the actual allocation is done via // VIRTIO_GPU_CMD_SUBMIT_3D." However, atomic resource creation is easiest for the // cross-domain use case, so whatever. 
- let handle = self.gralloc.lock().allocate_memory(*reqs)?; + let hnd = match handle { + Some(handle) => handle, + None => self.gralloc.lock().allocate_memory(*reqs)?, + }; + let info_3d = Resource3DInfo { width: reqs.info.width, height: reqs.info.height, @@ -164,7 +169,7 @@ impl RutabagaContext for CrossDomainContext { Ok(RutabagaResource { resource_id, - handle: Some(Arc::new(handle)), + handle: Some(Arc::new(hnd)), blob: true, blob_mem: resource_create_blob.blob_mem, blob_flags: resource_create_blob.blob_flags, @@ -251,7 +256,7 @@ impl RutabagaContext for CrossDomainContext { impl RutabagaComponent for CrossDomain { fn get_capset_info(&self, _capset_id: u32) -> (u32, u32) { - return (0 as u32, size_of::<CrossDomainCapabilities>() as u32); + return (0u32, size_of::<CrossDomainCapabilities>() as u32); } fn get_capset(&self, _capset_id: u32, _version: u32) -> Vec<u8> { diff --git a/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs b/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs index 3f03ad75e..08b8f7a3e 100644 --- a/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs +++ b/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs @@ -11,6 +11,8 @@ pub const VIRGL_RENDERER_USE_GLX: u32 = 4; pub const VIRGL_RENDERER_USE_SURFACELESS: u32 = 8; pub const VIRGL_RENDERER_USE_GLES: u32 = 16; pub const VIRGL_RENDERER_USE_EXTERNAL_BLOB: u32 = 32; +pub const VIRGL_RENDERER_VENUS: u32 = 64; +pub const VIRGL_RENDERER_NO_VIRGL: u32 = 128; pub const VIRGL_RES_BIND_DEPTH_STENCIL: u32 = 1; pub const VIRGL_RES_BIND_RENDER_TARGET: u32 = 2; pub const VIRGL_RES_BIND_SAMPLER_VIEW: u32 = 8; diff --git a/rutabaga_gfx/src/gfxstream.rs b/rutabaga_gfx/src/gfxstream.rs index 09a6a17a8..81d3b8a20 100644 --- a/rutabaga_gfx/src/gfxstream.rs +++ b/rutabaga_gfx/src/gfxstream.rs @@ -221,6 +221,7 @@ impl Gfxstream { Ok(Box::new(Gfxstream { fence_state })) } + #[allow(clippy::unnecessary_wraps)] fn map_info(&self, _resource_id: u32) -> RutabagaResult<u32> { Ok(RUTABAGA_MAP_CACHE_WC) } @@ 
-410,7 +411,7 @@ impl RutabagaComponent for Gfxstream { _ctx_id: u32, resource_id: u32, resource_create_blob: ResourceCreateBlob, - _iovecs: Vec<RutabagaIovec>, + _iovec_opt: Option<Vec<RutabagaIovec>>, ) -> RutabagaResult<RutabagaResource> { unsafe { stream_renderer_resource_create_v2(resource_id, resource_create_blob.blob_id); diff --git a/rutabaga_gfx/src/lib.rs b/rutabaga_gfx/src/lib.rs index 40c131ba3..dc99fa159 100644 --- a/rutabaga_gfx/src/lib.rs +++ b/rutabaga_gfx/src/lib.rs @@ -10,6 +10,7 @@ mod generated; mod gfxstream; #[macro_use] mod macros; +#[cfg(any(feature = "gfxstream", feature = "virgl_renderer"))] mod renderer_utils; mod rutabaga_2d; mod rutabaga_core; diff --git a/rutabaga_gfx/src/renderer_utils.rs b/rutabaga_gfx/src/renderer_utils.rs index 9b3726437..94a271a72 100644 --- a/rutabaga_gfx/src/renderer_utils.rs +++ b/rutabaga_gfx/src/renderer_utils.rs @@ -30,7 +30,7 @@ pub struct VirglBox { * -o vsnprintf.rs */ -#[allow(dead_code, non_snake_case, non_camel_case_types)] +#[allow(non_snake_case, non_camel_case_types)] #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] extern "C" { pub fn vsnprintf( diff --git a/rutabaga_gfx/src/rutabaga_core.rs b/rutabaga_gfx/src/rutabaga_core.rs index 9fef434c3..23b30aca9 100644 --- a/rutabaga_gfx/src/rutabaga_core.rs +++ b/rutabaga_gfx/src/rutabaga_core.rs @@ -132,7 +132,7 @@ pub trait RutabagaComponent { _ctx_id: u32, _resource_id: u32, _resource_create_blob: ResourceCreateBlob, - _backing_iovecs: Vec<RutabagaIovec>, + _iovec_opt: Option<Vec<RutabagaIovec>>, ) -> RutabagaResult<RutabagaResource> { Err(RutabagaError::Unsupported) } @@ -166,6 +166,7 @@ pub trait RutabagaContext { &mut self, _resource_id: u32, _resource_create_blob: ResourceCreateBlob, + _handle: Option<RutabagaHandle>, ) -> RutabagaResult<RutabagaResource> { Err(RutabagaError::Unsupported) } @@ -442,13 +443,15 @@ impl Rutabaga { } /// Creates a blob resource with the `ctx_id` and `resource_create_blob` metadata. 
- /// Associates `iovecs` with the resource, if there are any. + /// Associates `iovecs` with the resource, if there are any. Associates externally + /// created `handle` with the resource, if there is any. pub fn resource_create_blob( &mut self, ctx_id: u32, resource_id: u32, resource_create_blob: ResourceCreateBlob, - iovecs: Vec<RutabagaIovec>, + iovecs: Option<Vec<RutabagaIovec>>, + handle: Option<RutabagaHandle>, ) -> RutabagaResult<()> { if self.resources.contains_key(&resource_id) { return Err(RutabagaError::InvalidResourceId); @@ -463,7 +466,8 @@ impl Rutabaga { .get_mut(&ctx_id) .ok_or(RutabagaError::InvalidContextId)?; - if let Ok(resource) = ctx.context_create_blob(resource_id, resource_create_blob) { + if let Ok(resource) = ctx.context_create_blob(resource_id, resource_create_blob, handle) + { self.resources.insert(resource_id, resource); return Ok(()); } @@ -501,8 +505,18 @@ impl Rutabaga { .get(&resource_id) .ok_or(RutabagaError::InvalidResourceId)?; - let map_info = resource.map_info.ok_or(RutabagaError::SpecViolation)?; - Ok(map_info) + resource.map_info.ok_or(RutabagaError::SpecViolation) + } + + /// Returns the `vulkan_info` of the blob resource, which consists of the physical device + /// index and memory index associated with the resource. + pub fn vulkan_info(&self, resource_id: u32) -> RutabagaResult<VulkanInfo> { + let resource = self + .resources + .get(&resource_id) + .ok_or(RutabagaError::InvalidResourceId)?; + + resource.vulkan_info.ok_or(RutabagaError::Unsupported) } /// Returns the 3D info associated with the resource, if any. @@ -512,8 +526,7 @@ impl Rutabaga { .get(&resource_id) .ok_or(RutabagaError::InvalidResourceId)?; - let info_3d = resource.info_3d.ok_or(RutabagaError::Unsupported)?; - Ok(info_3d) + resource.info_3d.ok_or(RutabagaError::Unsupported) } /// Exports a blob resource. See virtio-gpu spec for blob flag use flags. 
diff --git a/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs b/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs index 4222ba945..9b3f25f03 100644 --- a/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs +++ b/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs @@ -7,7 +7,7 @@ use std::collections::BTreeMap as Map; -use base::round_up_to_page_size; +use base::{round_up_to_page_size, MappedRegion}; use crate::rutabaga_gralloc::formats::*; use crate::rutabaga_gralloc::system_gralloc::SystemGralloc; @@ -89,6 +89,26 @@ impl RutabagaGrallocFlags { } } + /// Sets the SW write flag's presence. + #[inline(always)] + pub fn use_sw_write(self, e: bool) -> RutabagaGrallocFlags { + if e { + RutabagaGrallocFlags(self.0 | RUTABAGA_GRALLOC_USE_SW_WRITE_OFTEN) + } else { + RutabagaGrallocFlags(self.0 & !RUTABAGA_GRALLOC_USE_SW_WRITE_OFTEN) + } + } + + /// Sets the SW read flag's presence. + #[inline(always)] + pub fn use_sw_read(self, e: bool) -> RutabagaGrallocFlags { + if e { + RutabagaGrallocFlags(self.0 | RUTABAGA_GRALLOC_USE_SW_READ_OFTEN) + } else { + RutabagaGrallocFlags(self.0 & !RUTABAGA_GRALLOC_USE_SW_READ_OFTEN) + } + } + /// Returns true if the texturing flag is set. #[inline(always)] pub fn uses_texturing(self) -> bool { @@ -167,6 +187,17 @@ pub trait Gralloc { /// Implementations must allocate memory given the requirements and return a RutabagaHandle /// upon success. fn allocate_memory(&mut self, reqs: ImageMemoryRequirements) -> RutabagaResult<RutabagaHandle>; + + /// Implementations must import the given `handle` and return a mapping, suitable for use with + /// KVM and other hypervisors. This is optional and only works with the Vulkano backend. + fn import_and_map( + &mut self, + _handle: RutabagaHandle, + _vulkan_info: VulkanInfo, + _size: u64, + ) -> RutabagaResult<Box<dyn MappedRegion>> { + Err(RutabagaError::Unsupported) + } } /// Enumeration of possible allocation backends. 
@@ -293,6 +324,22 @@ impl RutabagaGralloc { gralloc.allocate_memory(reqs) } + + /// Imports the `handle` using the given `vulkan_info`. Returns a mapping using Vulkano upon + /// success. Should not be used with minigbm or system gralloc backends. + pub fn import_and_map( + &mut self, + handle: RutabagaHandle, + vulkan_info: VulkanInfo, + size: u64, + ) -> RutabagaResult<Box<dyn MappedRegion>> { + let gralloc = self + .grallocs + .get_mut(&GrallocBackend::Vulkano) + .ok_or(RutabagaError::Unsupported)?; + + gralloc.import_and_map(handle, vulkan_info, size) + } } #[cfg(test)] @@ -363,4 +410,44 @@ mod tests { // Reallocate with same requirements let _handle2 = gralloc.allocate_memory(reqs).unwrap(); } + + #[test] + fn export_and_map() { + let gralloc_result = RutabagaGralloc::new(); + if gralloc_result.is_err() { + return; + } + + let mut gralloc = gralloc_result.unwrap(); + + let info = ImageAllocationInfo { + width: 512, + height: 1024, + drm_format: DrmFormat::new(b'X', b'R', b'2', b'4'), + flags: RutabagaGrallocFlags::empty() + .use_linear(true) + .use_sw_write(true) + .use_sw_read(true), + }; + + let mut reqs = gralloc.get_image_memory_requirements(info).unwrap(); + + // Anything else can use the mmap(..) system call. 
+ if reqs.vulkan_info.is_none() { + return; + } + + let handle = gralloc.allocate_memory(reqs).unwrap(); + let vulkan_info = reqs.vulkan_info.take().unwrap(); + + let mapping = gralloc + .import_and_map(handle, vulkan_info, reqs.size) + .unwrap(); + + let addr = mapping.as_ptr(); + let size = mapping.size(); + + assert_eq!(size as u64, reqs.size); + assert_ne!(addr as *const u8, std::ptr::null()); + } } diff --git a/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs b/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs index bdf62e753..c20438de3 100644 --- a/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs +++ b/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs @@ -143,7 +143,7 @@ impl Gralloc for MinigbmDevice { return Err(RutabagaError::SpecViolation); } - let dmabuf = gbm_buffer.export()?; + let dmabuf = gbm_buffer.export()?.into(); return Ok(RutabagaHandle { os_handle: dmabuf, handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF, @@ -165,7 +165,7 @@ impl Gralloc for MinigbmDevice { } let gbm_buffer = MinigbmBuffer(bo, self.clone()); - let dmabuf = gbm_buffer.export()?; + let dmabuf = gbm_buffer.export()?.into(); Ok(RutabagaHandle { os_handle: dmabuf, handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF, diff --git a/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs b/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs index b74908a08..7c66d210b 100644 --- a/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs +++ b/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs @@ -9,22 +9,27 @@ #![cfg(feature = "vulkano")] +use std::collections::BTreeMap as Map; +use std::convert::TryInto; use std::iter::Empty; use std::sync::Arc; +use base::MappedRegion; + use crate::rutabaga_gralloc::gralloc::{Gralloc, ImageAllocationInfo, ImageMemoryRequirements}; use crate::rutabaga_utils::*; use vulkano::device::{Device, DeviceCreationError, DeviceExtensions}; -use vulkano::image::{sys, ImageCreationError, ImageDimensions, ImageUsage}; +use vulkano::image::{sys, ImageCreateFlags, ImageCreationError, ImageDimensions, 
ImageUsage}; use vulkano::instance::{ Instance, InstanceCreationError, InstanceExtensions, MemoryType, PhysicalDevice, + PhysicalDeviceType, }; use vulkano::memory::{ - DedicatedAlloc, DeviceMemoryAllocError, DeviceMemoryBuilder, ExternalMemoryHandleType, - MemoryRequirements, + DedicatedAlloc, DeviceMemoryAllocError, DeviceMemoryBuilder, DeviceMemoryMapping, + ExternalMemoryHandleType, MemoryRequirements, }; use vulkano::memory::pool::AllocFromRequirementsFilter; @@ -32,7 +37,32 @@ use vulkano::sync::Sharing; /// A gralloc implementation capable of allocation `VkDeviceMemory`. pub struct VulkanoGralloc { - device: Arc<Device>, + devices: Map<PhysicalDeviceType, Arc<Device>>, + has_integrated_gpu: bool, +} + +struct VulkanoMapping { + mapping: DeviceMemoryMapping, + size: usize, +} + +impl VulkanoMapping { + pub fn new(mapping: DeviceMemoryMapping, size: usize) -> VulkanoMapping { + VulkanoMapping { mapping, size } + } +} + +unsafe impl MappedRegion for VulkanoMapping { + /// Used for passing this region for hypervisor memory mappings. We trust crosvm to use this + /// safely. + fn as_ptr(&self) -> *mut u8 { + unsafe { self.mapping.as_ptr() } + } + + /// Returns the size of the memory region in bytes. + fn size(&self) -> usize { + self.size + } } impl VulkanoGralloc { @@ -42,39 +72,52 @@ impl VulkanoGralloc { // explanation of VK initialization. let instance = Instance::new(None, &InstanceExtensions::none(), None)?; - // We should really check for integrated GPU versus dGPU. - let physical = PhysicalDevice::enumerate(&instance) - .next() - .ok_or(RutabagaError::Unsupported)?; + let mut devices: Map<PhysicalDeviceType, Arc<Device>> = Default::default(); + let mut has_integrated_gpu = false; + + for physical in PhysicalDevice::enumerate(&instance) { + let queue_family = physical + .queue_families() + .find(|&q| { + // We take the first queue family that supports graphics. 
+ q.supports_graphics() + }) + .ok_or(RutabagaError::Unsupported)?; + + let supported_extensions = DeviceExtensions::supported_by_device(physical); + + let desired_extensions = DeviceExtensions { + khr_dedicated_allocation: true, + khr_get_memory_requirements2: true, + khr_external_memory: true, + khr_external_memory_fd: true, + ext_external_memory_dmabuf: true, + ..DeviceExtensions::none() + }; - let queue_family = physical - .queue_families() - .find(|&q| { - // We take the first queue family that supports graphics. - q.supports_graphics() - }) - .ok_or(RutabagaError::Unsupported)?; + let intersection = supported_extensions.intersection(&desired_extensions); - let supported_extensions = DeviceExtensions::supported_by_device(physical); - let desired_extensions = DeviceExtensions { - khr_dedicated_allocation: true, - khr_get_memory_requirements2: true, - khr_external_memory: true, - khr_external_memory_fd: true, - ext_external_memory_dmabuf: true, - ..DeviceExtensions::none() - }; + let (device, mut _queues) = Device::new( + physical, + physical.supported_features(), + &intersection, + [(queue_family, 0.5)].iter().cloned(), + )?; - let intersection = supported_extensions.intersection(&desired_extensions); + if device.physical_device().ty() == PhysicalDeviceType::IntegratedGpu { + has_integrated_gpu = true + } - let (device, mut _queues) = Device::new( - physical, - physical.supported_features(), - &intersection, - [(queue_family, 0.5)].iter().cloned(), - )?; + // If we have two devices of the same type (two integrated GPUs), the old value is + // dropped. Vulkano is verbose enough such that a keener selection algorithm may + // be used, but the need for such complexity does not seem to exist now. 
+ devices.insert(device.physical_device().ty(), device); + } - Ok(Box::new(VulkanoGralloc { device })) + Ok(Box::new(VulkanoGralloc { + devices, + has_integrated_gpu, + })) } // This function is used safely in this module because gralloc does not: @@ -88,6 +131,16 @@ impl VulkanoGralloc { &mut self, info: ImageAllocationInfo, ) -> RutabagaResult<(sys::UnsafeImage, MemoryRequirements)> { + let device = if self.has_integrated_gpu { + self.devices + .get(&PhysicalDeviceType::IntegratedGpu) + .ok_or(RutabagaError::Unsupported)? + } else { + self.devices + .get(&PhysicalDeviceType::DiscreteGpu) + .ok_or(RutabagaError::Unsupported)? + }; + let usage = match info.flags.uses_rendering() { true => ImageUsage { color_attachment: true, @@ -111,14 +164,14 @@ impl VulkanoGralloc { let vulkan_format = info.drm_format.vulkan_format()?; let (unsafe_image, memory_requirements) = sys::UnsafeImage::new( - self.device.clone(), + device.clone(), usage, vulkan_format, + ImageCreateFlags::none(), ImageDimensions::Dim2d { width: info.width, height: info.height, array_layers: 1, - cubemap_compatible: false, }, 1, /* number of samples */ 1, /* mipmap count */ @@ -133,11 +186,23 @@ impl VulkanoGralloc { impl Gralloc for VulkanoGralloc { fn supports_external_gpu_memory(&self) -> bool { - self.device.loaded_extensions().khr_external_memory + for device in self.devices.values() { + if !device.loaded_extensions().khr_external_memory { + return false; + } + } + + true } fn supports_dmabuf(&self) -> bool { - self.device.loaded_extensions().ext_external_memory_dmabuf + for device in self.devices.values() { + if !device.loaded_extensions().ext_external_memory_dmabuf { + return false; + } + } + + true } fn get_image_memory_requirements( @@ -148,6 +213,16 @@ impl Gralloc for VulkanoGralloc { let (unsafe_image, memory_requirements) = unsafe { self.create_image(info)? 
}; + let device = if self.has_integrated_gpu { + self.devices + .get(&PhysicalDeviceType::IntegratedGpu) + .ok_or(RutabagaError::Unsupported)? + } else { + self.devices + .get(&PhysicalDeviceType::DiscreteGpu) + .ok_or(RutabagaError::Unsupported)? + }; + let planar_layout = info.drm_format.planar_layout()?; // Safe because we created the image with the linear bit set and verified the format is @@ -188,13 +263,11 @@ impl Gralloc for VulkanoGralloc { AllocFromRequirementsFilter::Allowed }; - let first_loop = self - .device + let first_loop = device .physical_device() .memory_types() .map(|t| (t, AllocFromRequirementsFilter::Preferred)); - let second_loop = self - .device + let second_loop = device .physical_device() .memory_types() .map(|t| (t, AllocFromRequirementsFilter::Allowed)); @@ -219,7 +292,7 @@ impl Gralloc for VulkanoGralloc { reqs.vulkan_info = Some(VulkanInfo { memory_idx: memory_type.id() as u32, - physical_device_idx: self.device.physical_device().index() as u32, + physical_device_idx: device.physical_device().index() as u32, }); Ok(reqs) @@ -227,15 +300,26 @@ impl Gralloc for VulkanoGralloc { fn allocate_memory(&mut self, reqs: ImageMemoryRequirements) -> RutabagaResult<RutabagaHandle> { let (unsafe_image, memory_requirements) = unsafe { self.create_image(reqs.info)? }; + let vulkan_info = reqs.vulkan_info.ok_or(RutabagaError::SpecViolation)?; - let memory_type = self - .device + + let device = if self.has_integrated_gpu { + self.devices + .get(&PhysicalDeviceType::IntegratedGpu) + .ok_or(RutabagaError::Unsupported)? + } else { + self.devices + .get(&PhysicalDeviceType::DiscreteGpu) + .ok_or(RutabagaError::Unsupported)? 
+ }; + + let memory_type = device .physical_device() .memory_type_by_id(vulkan_info.memory_idx) .ok_or(RutabagaError::SpecViolation)?; let (handle_type, rutabaga_type) = - match self.device.loaded_extensions().ext_external_memory_dmabuf { + match device.loaded_extensions().ext_external_memory_dmabuf { true => ( ExternalMemoryHandleType { dma_buf: true, @@ -252,7 +336,7 @@ impl Gralloc for VulkanoGralloc { ), }; - let dedicated = match self.device.loaded_extensions().khr_dedicated_allocation { + let dedicated = match device.loaded_extensions().khr_dedicated_allocation { true => { if memory_requirements.prefer_dedicated { DedicatedAlloc::Image(&unsafe_image) @@ -264,18 +348,55 @@ impl Gralloc for VulkanoGralloc { }; let device_memory = - DeviceMemoryBuilder::new(self.device.clone(), memory_type, reqs.size as usize) + DeviceMemoryBuilder::new(device.clone(), memory_type.id(), reqs.size as usize) .dedicated_info(dedicated) .export_info(handle_type) .build()?; - let file = device_memory.export_fd(handle_type)?; + let descriptor = device_memory.export_fd(handle_type)?.into(); Ok(RutabagaHandle { - os_handle: file, + os_handle: descriptor, handle_type: rutabaga_type, }) } + + /// Implementations must map the memory associated with the `resource_id` upon success. 
+ fn import_and_map( + &mut self, + handle: RutabagaHandle, + vulkan_info: VulkanInfo, + size: u64, + ) -> RutabagaResult<Box<dyn MappedRegion>> { + let device = self + .devices + .values() + .find(|device| { + device.physical_device().index() as u32 == vulkan_info.physical_device_idx + }) + .ok_or(RutabagaError::Unsupported)?; + + let handle_type = match handle.handle_type { + RUTABAGA_MEM_HANDLE_TYPE_DMABUF => ExternalMemoryHandleType { + dma_buf: true, + ..ExternalMemoryHandleType::none() + }, + RUTABAGA_MEM_HANDLE_TYPE_OPAQUE_FD => ExternalMemoryHandleType { + opaque_fd: true, + ..ExternalMemoryHandleType::none() + }, + _ => return Err(RutabagaError::Unsupported), + }; + + let valid_size: usize = size.try_into()?; + let device_memory = + DeviceMemoryBuilder::new(device.clone(), vulkan_info.memory_idx, valid_size) + .import_info(handle.os_handle.into(), handle_type) + .build()?; + let mapping = DeviceMemoryMapping::new(device.clone(), device_memory.clone(), 0, size, 0)?; + + Ok(Box::new(VulkanoMapping::new(mapping, valid_size))) + } } // Vulkano should really define an universal type that wraps all these errors, say diff --git a/rutabaga_gfx/src/rutabaga_utils.rs b/rutabaga_gfx/src/rutabaga_utils.rs index 943f36ccd..06f6fe7d6 100644 --- a/rutabaga_gfx/src/rutabaga_utils.rs +++ b/rutabaga_gfx/src/rutabaga_utils.rs @@ -5,13 +5,13 @@ //! rutabaga_utils: Utility enums, structs, and implementations needed by the rest of the crate. use std::fmt::{self, Display}; -use std::fs::File; use std::io::Error as IoError; +use std::num::TryFromIntError; use std::os::raw::c_void; use std::path::PathBuf; use std::str::Utf8Error; -use base::{Error as SysError, ExternalMappingError}; +use base::{Error as SysError, ExternalMappingError, SafeDescriptor}; use data_model::VolatileMemoryError; #[cfg(feature = "vulkano")] @@ -155,6 +155,8 @@ pub enum RutabagaError { SpecViolation, /// System error returned as a result of rutabaga library operation. 
SysError(SysError), + /// An attempted integer conversion failed. + TryFromIntError(TryFromIntError), /// The command is unsupported. Unsupported, /// Utf8 error. @@ -210,6 +212,7 @@ impl Display for RutabagaError { ComponentError(ret) => write!(f, "rutabaga component failed with error {}", ret), SpecViolation => write!(f, "violation of the rutabaga spec"), SysError(e) => write!(f, "rutabaga received a system error: {}", e), + TryFromIntError(e) => write!(f, "int conversion failed: {}", e), Unsupported => write!(f, "feature or function unsupported"), Utf8Error(e) => write!(f, "an utf8 error occured: {}", e), VolatileMemoryError(e) => write!(f, "noticed a volatile memory error {}", e), @@ -239,6 +242,12 @@ impl From<SysError> for RutabagaError { } } +impl From<TryFromIntError> for RutabagaError { + fn from(e: TryFromIntError) -> RutabagaError { + RutabagaError::TryFromIntError(e) + } +} + impl From<Utf8Error> for RutabagaError { fn from(e: Utf8Error) -> RutabagaError { RutabagaError::Utf8Error(e) @@ -262,6 +271,8 @@ const VIRGLRENDERER_USE_GLX: u32 = 1 << 2; const VIRGLRENDERER_USE_SURFACELESS: u32 = 1 << 3; const VIRGLRENDERER_USE_GLES: u32 = 1 << 4; const VIRGLRENDERER_USE_EXTERNAL_BLOB: u32 = 1 << 5; +const VIRGLRENDERER_VENUS: u32 = 1 << 6; +const VIRGLRENDERER_NO_VIRGL: u32 = 1 << 7; /// virglrenderer flag struct. #[derive(Copy, Clone)] @@ -270,6 +281,8 @@ pub struct VirglRendererFlags(u32); impl Default for VirglRendererFlags { fn default() -> VirglRendererFlags { VirglRendererFlags::new() + .use_virgl(true) + .use_venus(false) .use_egl(true) .use_surfaceless(true) .use_gles(true) @@ -296,6 +309,16 @@ impl VirglRendererFlags { } } + /// Enable virgl support + pub fn use_virgl(self, v: bool) -> VirglRendererFlags { + self.set_flag(VIRGLRENDERER_NO_VIRGL, !v) + } + + /// Enable venus support + pub fn use_venus(self, v: bool) -> VirglRendererFlags { + self.set_flag(VIRGLRENDERER_VENUS, v) + } + /// Use EGL for context creation. 
pub fn use_egl(self, v: bool) -> VirglRendererFlags { self.set_flag(VIRGLRENDERER_USE_EGL, v) @@ -377,7 +400,7 @@ impl GfxstreamFlags { } /// Support using Vulkan. - pub fn support_vulkan(self, v: bool) -> GfxstreamFlags { + pub fn use_vulkan(self, v: bool) -> GfxstreamFlags { self.set_flag(GFXSTREAM_RENDERER_FLAGS_NO_VK_BIT, !v) } @@ -463,7 +486,7 @@ pub const RUTABAGE_FENCE_HANDLE_TYPE_OPAQUE_WIN32: u32 = 0x0006; /// Handle to OS-specific memory or synchronization objects. pub struct RutabagaHandle { - pub os_handle: File, + pub os_handle: SafeDescriptor, pub handle_type: u32, } diff --git a/rutabaga_gfx/src/virgl_renderer.rs b/rutabaga_gfx/src/virgl_renderer.rs index ffc35a230..cdea147a3 100644 --- a/rutabaga_gfx/src/virgl_renderer.rs +++ b/rutabaga_gfx/src/virgl_renderer.rs @@ -9,7 +9,6 @@ use std::cell::RefCell; use std::ffi::CString; -use std::fs::File; use std::mem::{size_of, transmute}; use std::os::raw::{c_char, c_void}; use std::ptr::null_mut; @@ -19,7 +18,7 @@ use std::sync::Arc; use base::{ warn, Error as SysError, ExternalMapping, ExternalMappingError, ExternalMappingResult, - FromRawDescriptor, + FromRawDescriptor, SafeDescriptor, }; use crate::generated::virgl_renderer_bindings::*; @@ -181,7 +180,10 @@ impl VirglRenderer { // Initialize it only once and use the non-send/non-sync Renderer struct to keep things tied // to whichever thread called this function first. 
static INIT_ONCE: AtomicBool = AtomicBool::new(false); - if INIT_ONCE.compare_and_swap(false, true, Ordering::Acquire) { + if INIT_ONCE + .compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire) + .is_err() + { return Err(RutabagaError::AlreadyInUse); } @@ -266,7 +268,7 @@ impl VirglRenderer { return Err(RutabagaError::Unsupported); } - let dmabuf = unsafe { File::from_raw_descriptor(fd) }; + let dmabuf = unsafe { SafeDescriptor::from_raw_descriptor(fd) }; Ok(Arc::new(RutabagaHandle { os_handle: dmabuf, handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF, @@ -478,10 +480,17 @@ impl RutabagaComponent for VirglRenderer { ctx_id: u32, resource_id: u32, resource_create_blob: ResourceCreateBlob, - mut iovecs: Vec<RutabagaIovec>, + mut iovec_opt: Option<Vec<RutabagaIovec>>, ) -> RutabagaResult<RutabagaResource> { #[cfg(feature = "virgl_renderer_next")] { + let mut iovec_ptr = null_mut(); + let mut num_iovecs = 0; + if let Some(ref mut iovecs) = iovec_opt { + iovec_ptr = iovecs.as_mut_ptr(); + num_iovecs = iovecs.len(); + } + let resource_create_args = virgl_renderer_resource_create_blob_args { res_handle: resource_id, ctx_id, @@ -489,17 +498,13 @@ impl RutabagaComponent for VirglRenderer { blob_flags: resource_create_blob.blob_flags, blob_id: resource_create_blob.blob_id, size: resource_create_blob.size, - iovecs: iovecs.as_mut_ptr() as *const iovec, - num_iovs: iovecs.len() as u32, + iovecs: iovec_ptr as *const iovec, + num_iovs: num_iovecs as u32, }; + let ret = unsafe { virgl_renderer_resource_create_blob(&resource_create_args) }; ret_to_res(ret)?; - let iovec_opt = match resource_create_blob.blob_mem { - RUTABAGA_BLOB_MEM_GUEST => Some(iovecs), - _ => None, - }; - Ok(RutabagaResource { resource_id, handle: self.export_blob(resource_id).ok(), @@ -536,7 +541,7 @@ impl RutabagaComponent for VirglRenderer { // Safe because the FD was just returned by a successful virglrenderer call so it must // be valid and owned by us. 
- let fence = unsafe { File::from_raw_descriptor(fd) }; + let fence = unsafe { SafeDescriptor::from_raw_descriptor(fd) }; Ok(RutabagaHandle { os_handle: fence, handle_type: RUTABAGA_FENCE_HANDLE_TYPE_SYNC_FD, diff --git a/seccomp/aarch64/9p_device.policy b/seccomp/aarch64/9p_device.policy index 27c79083f..85344028c 100644 --- a/seccomp/aarch64/9p_device.policy +++ b/seccomp/aarch64/9p_device.policy @@ -6,7 +6,6 @@ openat: 1 @include /usr/share/policy/crosvm/common_device.policy -fcntl: 1 pread64: 1 pwrite64: 1 statx: 1 @@ -23,6 +22,8 @@ unlinkat: 1 socket: arg0 == AF_UNIX utimensat: 1 ftruncate: 1 -fchown: arg1 == 0xffffffff && arg2 == 0xffffffff +fchmod: 1 +fchown: 1 fstatfs: 1 newfstatat: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/balloon_device.policy b/seccomp/aarch64/balloon_device.policy index e1ca95339..57e21cce4 100644 --- a/seccomp/aarch64/balloon_device.policy +++ b/seccomp/aarch64/balloon_device.policy @@ -4,5 +4,5 @@ @include /usr/share/policy/crosvm/common_device.policy -fcntl: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/battery.policy b/seccomp/aarch64/battery.policy index a4fb9fcb6..f26af9caa 100644 --- a/seccomp/aarch64/battery.policy +++ b/seccomp/aarch64/battery.policy @@ -3,3 +3,4 @@ # found in the LICENSE file. 
@include /usr/share/policy/crosvm/common_device.policy +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/block_device.policy b/seccomp/aarch64/block_device.policy index 7697a7e6e..64d5ca5ff 100644 --- a/seccomp/aarch64/block_device.policy +++ b/seccomp/aarch64/block_device.policy @@ -5,7 +5,6 @@ @include /usr/share/policy/crosvm/common_device.policy fallocate: 1 -fcntl: 1 fdatasync: 1 fstat: 1 fsync: 1 @@ -18,3 +17,4 @@ statx: 1 timerfd_create: 1 timerfd_gettime: 1 timerfd_settime: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/common_device.policy b/seccomp/aarch64/common_device.policy index 841e52d09..349afe9aa 100644 --- a/seccomp/aarch64/common_device.policy +++ b/seccomp/aarch64/common_device.policy @@ -25,9 +25,9 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 ppoll: 1 -prctl: arg0 == PR_SET_NAME read: 1 readv: 1 recvfrom: 1 @@ -43,3 +43,4 @@ set_robust_list: 1 sigaltstack: 1 write: 1 writev: 1 +fcntl: 1 diff --git a/seccomp/aarch64/cras_audio_device.policy b/seccomp/aarch64/cras_audio_device.policy index 19419fd4f..60797f978 100644 --- a/seccomp/aarch64/cras_audio_device.policy +++ b/seccomp/aarch64/cras_audio_device.policy @@ -11,3 +11,4 @@ sched_setscheduler: 1 socketpair: arg0 == AF_UNIX clock_gettime: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/fs_device.policy b/seccomp/aarch64/fs_device.policy index 1d8bbd9f7..828003e0c 100644 --- a/seccomp/aarch64/fs_device.policy +++ b/seccomp/aarch64/fs_device.policy @@ -48,3 +48,6 @@ symlinkat: 1 umask: 1 unlinkat: 1 utimensat: 1 +prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS +capget: 1 +capset: 1 diff --git a/seccomp/aarch64/gpu_device.policy b/seccomp/aarch64/gpu_device.policy index bd1f6481d..4ceac5c46 100644 --- a/seccomp/aarch64/gpu_device.policy +++ b/seccomp/aarch64/gpu_device.policy @@ -23,6 +23,7 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == 
MADV_REMOVE mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 ppoll: 1 prctl: arg0 == PR_SET_NAME || arg0 == PR_GET_NAME @@ -58,8 +59,8 @@ newfstatat: 1 getdents64: 1 sysinfo: 1 -# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali) -ioctl: arg1 & 0x6400 || arg1 & 0x8000 +# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali), 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x40087543 == UDMABUF_CREATE_LIST +ioctl: arg1 & 0x6400 || arg1 & 0x8000 || arg1 == 0x40086200 || arg1 == 0x40087543 ## mmap/mprotect differ from the common_device.policy mmap: arg2 == PROT_READ|PROT_WRITE || arg2 == PROT_NONE || arg2 == PROT_READ|PROT_EXEC || arg2 == PROT_WRITE || arg2 == PROT_READ diff --git a/seccomp/aarch64/input_device.policy b/seccomp/aarch64/input_device.policy index 07d3b5f74..728580dc6 100644 --- a/seccomp/aarch64/input_device.policy +++ b/seccomp/aarch64/input_device.policy @@ -5,6 +5,6 @@ @include /usr/share/policy/crosvm/common_device.policy ioctl: 1 -fcntl: 1 getsockname: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/net_device.policy b/seccomp/aarch64/net_device.policy index a1c2eeff4..b77bdfb82 100644 --- a/seccomp/aarch64/net_device.policy +++ b/seccomp/aarch64/net_device.policy @@ -7,3 +7,5 @@ # TUNSETOFFLOAD ioctl: arg1 == 0x400454d0 openat: return ENOENT + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/null_audio_device.policy b/seccomp/aarch64/null_audio_device.policy index 7a88fe297..b55aa1e94 100644 --- a/seccomp/aarch64/null_audio_device.policy +++ b/seccomp/aarch64/null_audio_device.policy @@ -9,3 +9,4 @@ prlimit64: 1 setrlimit: 1 clock_gettime: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/pmem_device.policy b/seccomp/aarch64/pmem_device.policy index 77719a997..cbdf83a23 100644 --- a/seccomp/aarch64/pmem_device.policy +++ b/seccomp/aarch64/pmem_device.policy @@ -7,3 +7,4 @@ fdatasync: 1 fsync: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git 
a/seccomp/aarch64/rng_device.policy b/seccomp/aarch64/rng_device.policy index fa86280a4..57e21cce4 100644 --- a/seccomp/aarch64/rng_device.policy +++ b/seccomp/aarch64/rng_device.policy @@ -5,3 +5,4 @@ @include /usr/share/policy/crosvm/common_device.policy openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/serial.policy b/seccomp/aarch64/serial.policy index 8d23c0f90..3a76beee4 100644 --- a/seccomp/aarch64/serial.policy +++ b/seccomp/aarch64/serial.policy @@ -7,3 +7,4 @@ connect: 1 bind: 1 openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/tpm_device.policy b/seccomp/aarch64/tpm_device.policy index a39d61c6a..98d32b6b1 100644 --- a/seccomp/aarch64/tpm_device.policy +++ b/seccomp/aarch64/tpm_device.policy @@ -25,6 +25,7 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 ppoll: 1 prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/vhost_net_device.policy b/seccomp/aarch64/vhost_net_device.policy index 4de1967b3..44247e097 100644 --- a/seccomp/aarch64/vhost_net_device.policy +++ b/seccomp/aarch64/vhost_net_device.policy @@ -22,3 +22,5 @@ # arg1 == VHOST_NET_SET_BACKEND ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30 openat: return ENOENT + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/vhost_vsock_device.policy b/seccomp/aarch64/vhost_vsock_device.policy index 82b66502d..a2774e16b 100644 --- a/seccomp/aarch64/vhost_vsock_device.policy +++ b/seccomp/aarch64/vhost_vsock_device.policy @@ -23,3 +23,5 @@ # arg1 == VHOST_VSOCK_SET_RUNNING ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 
0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af60 || arg1 == 0x4004af61 openat: return ENOENT + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/vios_audio_device.policy b/seccomp/aarch64/vios_audio_device.policy index df54139f5..d425ab279 100644 --- a/seccomp/aarch64/vios_audio_device.policy +++ b/seccomp/aarch64/vios_audio_device.policy @@ -5,9 +5,9 @@ @include /usr/share/policy/crosvm/common_device.policy clock_gettime: 1 -clock_nanosleep: 1 lseek: 1 openat: return ENOENT prlimit64: 1 sched_setscheduler: 1 setrlimit: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/wl_device.policy b/seccomp/aarch64/wl_device.policy index 864aefb88..cd6804637 100644 --- a/seccomp/aarch64/wl_device.policy +++ b/seccomp/aarch64/wl_device.policy @@ -16,5 +16,5 @@ memfd_create: arg1 == 3 ftruncate: 1 # Used to determine shm size after recvmsg with fd lseek: 1 -# Allow F_GETFL only -fcntl: arg1 == 3 + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/aarch64/xhci.policy b/seccomp/aarch64/xhci.policy index d69514090..684ae0d2f 100644 --- a/seccomp/aarch64/xhci.policy +++ b/seccomp/aarch64/xhci.policy @@ -16,7 +16,6 @@ getsockname: 1 openat: 1 setsockopt: 1 bind: 1 -fcntl: 1 socket: arg0 == AF_NETLINK uname: 1 # The following ioctls are: @@ -37,3 +36,4 @@ ioctl: arg1 == 0xc0105500 || arg1 == 0x802c550a || arg1 == 0x8004551a || arg1 == fstat: 1 getrandom: 1 lseek: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/9p_device.policy b/seccomp/arm/9p_device.policy index 95d0b320d..a7b877b27 100644 --- a/seccomp/arm/9p_device.policy +++ b/seccomp/arm/9p_device.policy @@ -4,7 +4,6 @@ @include /usr/share/policy/crosvm/common_device.policy -fcntl64: 1 pread64: 1 pwrite64: 1 stat64: 1 @@ -24,7 +23,10 @@ linkat: 1 unlinkat: 1 socket: arg0 == AF_UNIX utimensat: 1 +utimensat_time64: 1 ftruncate64: 1 -fchown: arg1 == 0xffffffff && arg2 == 
0xffffffff +fchmod: 1 +fchown: 1 fstatfs64: 1 fstatat64: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/balloon_device.policy b/seccomp/arm/balloon_device.policy index 868ae3110..e0e444270 100644 --- a/seccomp/arm/balloon_device.policy +++ b/seccomp/arm/balloon_device.policy @@ -4,6 +4,6 @@ @include /usr/share/policy/crosvm/common_device.policy -fcntl64: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/battery.policy b/seccomp/arm/battery.policy index a4fb9fcb6..f26af9caa 100644 --- a/seccomp/arm/battery.policy +++ b/seccomp/arm/battery.policy @@ -3,3 +3,4 @@ # found in the LICENSE file. @include /usr/share/policy/crosvm/common_device.policy +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/block_device.policy b/seccomp/arm/block_device.policy index 785af4582..75b769f43 100644 --- a/seccomp/arm/block_device.policy +++ b/seccomp/arm/block_device.policy @@ -5,7 +5,6 @@ @include /usr/share/policy/crosvm/common_device.policy fallocate: 1 -fcntl64: 1 fdatasync: 1 fstat64: 1 fsync: 1 @@ -20,4 +19,7 @@ pwritev: 1 statx: 1 timerfd_create: 1 timerfd_gettime: 1 +timerfd_gettime64: 1 timerfd_settime: 1 +timerfd_settime64: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/common_device.policy b/seccomp/arm/common_device.policy index cbbfd7d43..165bfda6e 100644 --- a/seccomp/arm/common_device.policy +++ b/seccomp/arm/common_device.policy @@ -3,6 +3,8 @@ # found in the LICENSE file. 
brk: 1 +clock_gettime: 1 +clock_gettime64: 1 clone: arg0 & CLONE_THREAD close: 1 dup2: 1 @@ -14,6 +16,7 @@ eventfd2: 1 exit: 1 exit_group: 1 futex: 1 +futex_time64: 1 getpid: 1 gettid: 1 gettimeofday: 1 @@ -26,15 +29,18 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 +clock_nanosleep_time64: 1 pipe2: 1 poll: 1 ppoll: 1 -prctl: arg0 == PR_SET_NAME +ppoll_time64: 1 read: 1 readv: 1 recv: 1 recvfrom: 1 recvmsg: 1 +recvmmsg_time64: 1 restart_syscall: 1 rt_sigaction: 1 rt_sigprocmask: 1 @@ -46,3 +52,4 @@ set_robust_list: 1 sigaltstack: 1 write: 1 writev: 1 +fcntl64: 1 diff --git a/seccomp/arm/cras_audio_device.policy b/seccomp/arm/cras_audio_device.policy index 505208b15..20bf60e1f 100644 --- a/seccomp/arm/cras_audio_device.policy +++ b/seccomp/arm/cras_audio_device.policy @@ -11,4 +11,4 @@ prlimit64: 1 setrlimit: 1 sched_setscheduler: 1 socketpair: arg0 == AF_UNIX -clock_gettime: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/fs_device.policy b/seccomp/arm/fs_device.policy index 02dff2916..e84fd08fc 100644 --- a/seccomp/arm/fs_device.policy +++ b/seccomp/arm/fs_device.policy @@ -50,4 +50,8 @@ statx: 1 symlinkat: 1 umask: 1 unlinkat: 1 -utimensat: 1
\ No newline at end of file +utimensat: 1 +utimensat_time64: 1 +prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS +capget: 1 +capset: 1 diff --git a/seccomp/arm/gpu_device.policy b/seccomp/arm/gpu_device.policy index 1bdea6d0d..ec5a5b481 100644 --- a/seccomp/arm/gpu_device.policy +++ b/seccomp/arm/gpu_device.policy @@ -16,6 +16,7 @@ eventfd2: 1 exit: 1 exit_group: 1 futex: 1 +futex_time64: 1 getpid: 1 gettimeofday: 1 kill: 1 @@ -23,15 +24,19 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == MADV_REMOVE mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 +clock_nanosleep_time64: 1 pipe2: 1 poll: 1 ppoll: 1 +ppoll_time64: 1 prctl: arg0 == PR_SET_NAME || arg0 == PR_GET_NAME read: 1 readv: 1 recv: 1 recvfrom: 1 recvmsg: 1 +recvmmsg_time64: 1 restart_syscall: 1 rt_sigaction: 1 rt_sigprocmask: 1 @@ -60,8 +65,8 @@ getdents: 1 getdents64: 1 sysinfo: 1 -# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali) -ioctl: arg1 & 0x6400 || arg1 & 0x8000 +# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali), 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x40087543 == UDMABUF_CREATE_LIST +ioctl: arg1 & 0x6400 || arg1 & 0x8000 || arg1 == 0x40086200 || arg1 == 0x40087543 # Used for sharing memory with wayland. arg1 == MFD_CLOEXEC|MFD_ALLOW_SEALING memfd_create: arg1 == 3 @@ -81,6 +86,7 @@ gettid: 1 fcntl64: 1 tgkill: 1 clock_gettime: 1 +clock_gettime64: 1 # Rules specific to Mesa. 
uname: 1 diff --git a/seccomp/arm/input_device.policy b/seccomp/arm/input_device.policy index d32c31222..bb6985315 100644 --- a/seccomp/arm/input_device.policy +++ b/seccomp/arm/input_device.policy @@ -5,7 +5,7 @@ @include /usr/share/policy/crosvm/common_device.policy ioctl: 1 -fcntl: 1 getsockname: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/net_device.policy b/seccomp/arm/net_device.policy index cf0584ce6..4cd4815ca 100644 --- a/seccomp/arm/net_device.policy +++ b/seccomp/arm/net_device.policy @@ -8,3 +8,4 @@ ioctl: arg1 == 0x400454d0 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/null_audio_device.policy b/seccomp/arm/null_audio_device.policy index f89397b9e..c87441c38 100644 --- a/seccomp/arm/null_audio_device.policy +++ b/seccomp/arm/null_audio_device.policy @@ -9,4 +9,4 @@ open: return ENOENT openat: return ENOENT prlimit64: 1 setrlimit: 1 -clock_gettime: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/pmem_device.policy b/seccomp/arm/pmem_device.policy index 12a3b04f8..e7321f747 100644 --- a/seccomp/arm/pmem_device.policy +++ b/seccomp/arm/pmem_device.policy @@ -8,3 +8,4 @@ fdatasync: 1 fsync: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/rng_device.policy b/seccomp/arm/rng_device.policy index 0c7d2583a..e0e444270 100644 --- a/seccomp/arm/rng_device.policy +++ b/seccomp/arm/rng_device.policy @@ -6,3 +6,4 @@ open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/serial.policy b/seccomp/arm/serial.policy index f0456e931..1d8140d44 100644 --- a/seccomp/arm/serial.policy +++ b/seccomp/arm/serial.policy @@ -8,3 +8,4 @@ connect: 1 bind: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/tpm_device.policy b/seccomp/arm/tpm_device.policy index d17f67cd1..5653f7794 100644 --- a/seccomp/arm/tpm_device.policy +++ 
b/seccomp/arm/tpm_device.policy @@ -4,6 +4,8 @@ # common policy brk: 1 +clock_gettime: 1 +clock_gettime64: 1 clone: arg0 & CLONE_THREAD close: 1 dup2: 1 @@ -15,6 +17,7 @@ eventfd2: 1 exit: 1 exit_group: 1 futex: 1 +futex_time64: 1 getpid: 1 getrandom: 1 gettimeofday: 1 @@ -25,14 +28,18 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 +clock_nanosleep_time64: 1 pipe2: 1 poll: 1 ppoll: 1 +ppoll_time64: 1 prctl: arg0 == PR_SET_NAME read: 1 recv: 1 recvfrom: 1 recvmsg: 1 +recvmmsg_time64: 1 restart_syscall: 1 rt_sigaction: 1 rt_sigprocmask: 1 diff --git a/seccomp/arm/vhost_net_device.policy b/seccomp/arm/vhost_net_device.policy index 4571a938e..fbee418b2 100644 --- a/seccomp/arm/vhost_net_device.policy +++ b/seccomp/arm/vhost_net_device.policy @@ -23,3 +23,4 @@ ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/vhost_vsock_device.policy b/seccomp/arm/vhost_vsock_device.policy index c6a984c51..a793afc9e 100644 --- a/seccomp/arm/vhost_vsock_device.policy +++ b/seccomp/arm/vhost_vsock_device.policy @@ -24,3 +24,4 @@ ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af60 || arg1 == 0x4004af61 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/video_device.policy b/seccomp/arm/video_device.policy index 784cc7cd4..5c5a4a5a1 100644 --- 
a/seccomp/arm/video_device.policy +++ b/seccomp/arm/video_device.policy @@ -6,12 +6,12 @@ # Syscalls specific to video devices. clock_getres: 1 -clock_gettime: 1 +clock_getres_time64: 1 connect: 1 -fcntl64: arg1 == F_GETFL || arg1 == F_SETFL || arg1 == F_DUPFD_CLOEXEC || arg1 == F_GETFD || arg1 == F_SETFD getegid32: 1 geteuid32: 1 getgid32: 1 +getrandom: 1 getresgid32: 1 getresuid32: 1 getsockname: 1 @@ -24,3 +24,4 @@ send: 1 setpriority: 1 socket: arg0 == AF_UNIX stat64: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/vios_audio_device.policy b/seccomp/arm/vios_audio_device.policy index ad27b0e36..3a1fb0811 100644 --- a/seccomp/arm/vios_audio_device.policy +++ b/seccomp/arm/vios_audio_device.policy @@ -4,11 +4,10 @@ @include /usr/share/policy/crosvm/common_device.policy -clock_gettime: 1 -clock_nanosleep: 1 lseek: 1 open: return ENOENT openat: return ENOENT prlimit64: 1 sched_setscheduler: 1 setrlimit: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/wl_device.policy b/seccomp/arm/wl_device.policy index 0b84c4b42..0a3de4f7f 100644 --- a/seccomp/arm/wl_device.policy +++ b/seccomp/arm/wl_device.policy @@ -17,5 +17,5 @@ memfd_create: arg1 == 3 ftruncate64: 1 # Used to determine shm size after recvmsg with fd _llseek: 1 -# Allow F_GETFL only -fcntl64: arg1 == 3 + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/arm/xhci.policy b/seccomp/arm/xhci.policy index 6c51ddf8d..ca1a73dfc 100644 --- a/seccomp/arm/xhci.policy +++ b/seccomp/arm/xhci.policy @@ -5,20 +5,17 @@ @include /usr/share/policy/crosvm/common_device.policy stat64: 1 -fcntl64: 1 lstat64: 1 readlink: 1 readlinkat: 1 getdents64: 1 name_to_handle_at: 1 access: 1 -clock_gettime: 1 timerfd_create: 1 getsockname: 1 pipe: 1 setsockopt: 1 bind: 1 -fcntl: 1 socket: arg0 == AF_NETLINK stat: 1 statx: 1 @@ -44,3 +41,4 @@ getdents: 1 _llseek: 1 open: return ENOENT openat: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/9p_device.policy b/seccomp/x86_64/9p_device.policy index 6f14c0af6..48eb0415d 
100644 --- a/seccomp/x86_64/9p_device.policy +++ b/seccomp/x86_64/9p_device.policy @@ -7,7 +7,6 @@ openat: 1 @include /usr/share/policy/crosvm/common_device.policy -fcntl: 1 pwrite64: 1 stat: 1 statx: 1 @@ -25,6 +24,8 @@ fsync: 1 fdatasync: 1 utimensat: 1 ftruncate: 1 -fchown: arg1 == 0xffffffff && arg2 == 0xffffffff +fchmod: 1 +fchown: 1 fstatfs: 1 newfstatat: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/balloon_device.policy b/seccomp/x86_64/balloon_device.policy index 49cf785ab..f717ad476 100644 --- a/seccomp/x86_64/balloon_device.policy +++ b/seccomp/x86_64/balloon_device.policy @@ -4,6 +4,6 @@ @include /usr/share/policy/crosvm/common_device.policy -fcntl: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/battery.policy b/seccomp/x86_64/battery.policy index f6bbe2889..ce6d41271 100644 --- a/seccomp/x86_64/battery.policy +++ b/seccomp/x86_64/battery.policy @@ -7,7 +7,6 @@ # Syscalls used by power_monitor's powerd implementation. 
clock_getres: 1 connect: 1 -fcntl: 1 getcwd: 1 getegid: 1 geteuid: 1 @@ -19,3 +18,4 @@ openat: 1 readlink: 1 socket: arg0 == AF_UNIX tgkill: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/block_device.policy b/seccomp/x86_64/block_device.policy index f1130d90d..8f68c9be5 100644 --- a/seccomp/x86_64/block_device.policy +++ b/seccomp/x86_64/block_device.policy @@ -5,7 +5,6 @@ @include /usr/share/policy/crosvm/common_device.policy fallocate: 1 -fcntl: 1 fdatasync: 1 fstat: 1 fsync: 1 @@ -21,3 +20,4 @@ statx: 1 timerfd_create: 1 timerfd_gettime: 1 timerfd_settime: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/common_device.policy b/seccomp/x86_64/common_device.policy index bf8dd1581..49d452051 100644 --- a/seccomp/x86_64/common_device.policy +++ b/seccomp/x86_64/common_device.policy @@ -27,10 +27,10 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 poll: 1 ppoll: 1 -prctl: arg0 == PR_SET_NAME read: 1 readv: 1 recvfrom: 1 @@ -46,3 +46,4 @@ set_robust_list: 1 sigaltstack: 1 write: 1 writev: 1 +fcntl: 1 diff --git a/seccomp/x86_64/cras_audio_device.policy b/seccomp/x86_64/cras_audio_device.policy index 505208b15..bbaffb080 100644 --- a/seccomp/x86_64/cras_audio_device.policy +++ b/seccomp/x86_64/cras_audio_device.policy @@ -12,3 +12,4 @@ setrlimit: 1 sched_setscheduler: 1 socketpair: arg0 == AF_UNIX clock_gettime: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/fs_device.policy b/seccomp/x86_64/fs_device.policy index dea28aef1..bd03307a2 100644 --- a/seccomp/x86_64/fs_device.policy +++ b/seccomp/x86_64/fs_device.policy @@ -50,4 +50,7 @@ symlinkat: 1 statx: 1 umask: 1 unlinkat: 1 -utimensat: 1
\ No newline at end of file +utimensat: 1 +prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS +capget: 1 +capset: 1 diff --git a/seccomp/x86_64/gpu_device.policy b/seccomp/x86_64/gpu_device.policy index 7f167d728..2b4f4b2bb 100644 --- a/seccomp/x86_64/gpu_device.policy +++ b/seccomp/x86_64/gpu_device.policy @@ -25,6 +25,7 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == MADV_REMOVE mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 poll: 1 ppoll: 1 @@ -53,10 +54,12 @@ fstat: 1 # Used to set of size new memfd. ftruncate: 1 getdents: 1 +getdents64: 1 geteuid: 1 getrandom: 1 getuid: 1 -ioctl: arg1 == FIONBIO || arg1 == FIOCLEX || arg1 == 0x40086200 || arg1 & 0x6400 +# 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x6400 == DRM_IOCTL_BASE, 0x40087543 == UDMABUF_CREATE_LIST +ioctl: arg1 == FIONBIO || arg1 == FIOCLEX || arg1 == 0x40086200 || arg1 & 0x6400 || arg1 == 0x40087543 lseek: 1 lstat: 1 # Used for sharing memory with wayland. Also internally by Intel anv. 
diff --git a/seccomp/x86_64/input_device.policy b/seccomp/x86_64/input_device.policy index d32c31222..bb6985315 100644 --- a/seccomp/x86_64/input_device.policy +++ b/seccomp/x86_64/input_device.policy @@ -5,7 +5,7 @@ @include /usr/share/policy/crosvm/common_device.policy ioctl: 1 -fcntl: 1 getsockname: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/net_device.policy b/seccomp/x86_64/net_device.policy index 5d6535a94..b8c9d41ca 100644 --- a/seccomp/x86_64/net_device.policy +++ b/seccomp/x86_64/net_device.policy @@ -8,3 +8,4 @@ ioctl: arg1 == 0x400454d0 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/null_audio_device.policy b/seccomp/x86_64/null_audio_device.policy index f118d88de..5c360c945 100644 --- a/seccomp/x86_64/null_audio_device.policy +++ b/seccomp/x86_64/null_audio_device.policy @@ -9,3 +9,5 @@ open: return ENOENT openat: return ENOENT prlimit64: 1 setrlimit: 1 +sched_setscheduler: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/pmem_device.policy b/seccomp/x86_64/pmem_device.policy index 12a3b04f8..e7321f747 100644 --- a/seccomp/x86_64/pmem_device.policy +++ b/seccomp/x86_64/pmem_device.policy @@ -8,3 +8,4 @@ fdatasync: 1 fsync: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/rng_device.policy b/seccomp/x86_64/rng_device.policy index c6681634f..f717ad476 100644 --- a/seccomp/x86_64/rng_device.policy +++ b/seccomp/x86_64/rng_device.policy @@ -6,3 +6,4 @@ open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/serial.policy b/seccomp/x86_64/serial.policy index f0456e931..1d8140d44 100644 --- a/seccomp/x86_64/serial.policy +++ b/seccomp/x86_64/serial.policy @@ -8,3 +8,4 @@ connect: 1 bind: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/tpm_device.policy b/seccomp/x86_64/tpm_device.policy index 
50536f8aa..bfd64a838 100644 --- a/seccomp/x86_64/tpm_device.policy +++ b/seccomp/x86_64/tpm_device.policy @@ -25,6 +25,7 @@ mprotect: arg2 in ~PROT_EXEC mremap: 1 munmap: 1 nanosleep: 1 +clock_nanosleep: 1 pipe2: 1 poll: 1 ppoll: 1 diff --git a/seccomp/x86_64/vfio_device.policy b/seccomp/x86_64/vfio_device.policy index aa28d1ad4..bf7a00d12 100644 --- a/seccomp/x86_64/vfio_device.policy +++ b/seccomp/x86_64/vfio_device.policy @@ -10,3 +10,4 @@ openat: return ENOENT readlink: 1 pread64: 1 pwrite64: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/vhost_net_device.policy b/seccomp/x86_64/vhost_net_device.policy index c9182e6ec..55a3d188d 100644 --- a/seccomp/x86_64/vhost_net_device.policy +++ b/seccomp/x86_64/vhost_net_device.policy @@ -23,3 +23,4 @@ ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/vhost_vsock_device.policy b/seccomp/x86_64/vhost_vsock_device.policy index 69fca47fb..e558c4e79 100644 --- a/seccomp/x86_64/vhost_vsock_device.policy +++ b/seccomp/x86_64/vhost_vsock_device.policy @@ -25,3 +25,4 @@ ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == connect: 1 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/video_device.policy b/seccomp/x86_64/video_device.policy index e43900ae1..4c54d9d17 100644 --- a/seccomp/x86_64/video_device.policy +++ b/seccomp/x86_64/video_device.policy @@ -7,8 +7,8 @@ # Syscalls specific to video devices. 
clock_getres: 1 connect: 1 -fcntl: arg1 == F_GETFL || arg1 == F_SETFL || arg1 == F_DUPFD_CLOEXEC || arg1 == F_GETFD || arg1 == F_SETFD getdents: 1 +getdents64: 1 getegid: 1 geteuid: 1 getgid: 1 @@ -38,3 +38,5 @@ uname: 1 # Required by mesa on AMD GPU sysinfo: 1 + +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/vios_audio_device.policy b/seccomp/x86_64/vios_audio_device.policy index ad27b0e36..a3b7f1961 100644 --- a/seccomp/x86_64/vios_audio_device.policy +++ b/seccomp/x86_64/vios_audio_device.policy @@ -5,10 +5,10 @@ @include /usr/share/policy/crosvm/common_device.policy clock_gettime: 1 -clock_nanosleep: 1 lseek: 1 open: return ENOENT openat: return ENOENT prlimit64: 1 sched_setscheduler: 1 setrlimit: 1 +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/wl_device.policy b/seccomp/x86_64/wl_device.policy index f79b08aab..f2cda7f24 100644 --- a/seccomp/x86_64/wl_device.policy +++ b/seccomp/x86_64/wl_device.policy @@ -15,7 +15,6 @@ memfd_create: arg1 == 3 ftruncate: 1 # Used to determine shm size after recvmsg with fd lseek: 1 -# Allow F_GETFL only -fcntl: arg1 == 3 open: return ENOENT openat: return ENOENT +prctl: arg0 == PR_SET_NAME diff --git a/seccomp/x86_64/xhci.policy b/seccomp/x86_64/xhci.policy index a548d9ea0..9ef376698 100644 --- a/seccomp/x86_64/xhci.policy +++ b/seccomp/x86_64/xhci.policy @@ -14,7 +14,6 @@ getsockname: 1 pipe: 1 setsockopt: 1 bind: 1 -fcntl: 1 open: return ENOENT openat: 1 socket: arg0 == AF_NETLINK @@ -39,4 +38,6 @@ ioctl: arg1 == 0xc0185500 || arg1 == 0x41045508 || arg1 == 0x8004550f || arg1 == fstat: 1 getrandom: 1 getdents: 1 +getdents64: 1 lseek: 1 +prctl: arg0 == PR_SET_NAME diff --git a/src/argument.rs b/src/argument.rs index 525506df7..64be18227 100644 --- a/src/argument.rs +++ b/src/argument.rs @@ -364,7 +364,7 @@ where /// Prints command line usage information to stdout. /// /// Usage information is printed according to the help fields in `args` with a leading usage line. 
-/// The usage line is of the format "`program_name` [ARGUMENTS] `required_arg`". +/// The usage line is of the format "`program_name` \[ARGUMENTS\] `required_arg`". pub fn print_help(program_name: &str, required_arg: &str, args: &[Argument]) { println!( "Usage: {} {}{}\n", diff --git a/src/crosvm.rs b/src/crosvm.rs index 1a44031bf..7a4318ab3 100644 --- a/src/crosvm.rs +++ b/src/crosvm.rs @@ -29,6 +29,9 @@ use devices::ProtectionType; use libc::{getegid, geteuid}; use vm_control::BatteryType; +static KVM_PATH: &str = "/dev/kvm"; +static VHOST_VSOCK_PATH: &str = "/dev/vhost-vsock"; +static VHOST_NET_PATH: &str = "/dev/vhost-net"; static SECCOMP_POLICY_DIR: &str = "/usr/share/policy/crosvm"; /// Indicates the location and kind of executable kernel for a VM. @@ -55,6 +58,15 @@ pub struct DiskOption { pub id: Option<[u8; DISK_ID_LEN]>, } +pub struct VhostUserOption { + pub socket: PathBuf, +} + +pub struct VhostUserFsOption { + pub socket: PathBuf, + pub tag: String, +} + /// A bind mount for directories in the plugin process. pub struct BindMount { pub src: PathBuf, @@ -69,6 +81,13 @@ pub struct GidMap { pub count: u32, } +/// Direct IO forwarding options +#[cfg(feature = "direct")] +pub struct DirectIoOption { + pub path: PathBuf, + pub ranges: Vec<(u64, u64)>, +} + pub const DEFAULT_TOUCH_DEVICE_HEIGHT: u32 = 1024; pub const DEFAULT_TOUCH_DEVICE_WIDTH: u32 = 1280; @@ -174,11 +193,16 @@ impl Default for SharedDir { /// Aggregate of all configurable options for a running VM. 
pub struct Config { + pub kvm_device_path: PathBuf, + pub vhost_vsock_device_path: PathBuf, + pub vhost_net_device_path: PathBuf, pub vcpu_count: Option<usize>, pub rt_cpus: Vec<usize>, pub vcpu_affinity: Option<VcpuAffinity>, pub no_smt: bool, pub memory: Option<u64>, + pub hugepages: bool, + pub memory_file: Option<PathBuf>, pub executable_path: Option<Executable>, pub android_fstab: Option<PathBuf>, pub initrd_path: Option<PathBuf>, @@ -230,16 +254,31 @@ pub struct Config { #[cfg(all(target_arch = "x86_64", feature = "gdb"))] pub gdb: Option<u32>, pub balloon_bias: i64, + pub vhost_user_blk: Vec<VhostUserOption>, + pub vhost_user_fs: Vec<VhostUserFsOption>, + pub vhost_user_net: Vec<VhostUserOption>, + #[cfg(feature = "direct")] + pub direct_pmio: Option<DirectIoOption>, + #[cfg(feature = "direct")] + pub direct_level_irq: Vec<u32>, + #[cfg(feature = "direct")] + pub direct_edge_irq: Vec<u32>, + pub dmi_path: Option<PathBuf>, } impl Default for Config { fn default() -> Config { Config { + kvm_device_path: PathBuf::from(KVM_PATH), + vhost_vsock_device_path: PathBuf::from(VHOST_VSOCK_PATH), + vhost_net_device_path: PathBuf::from(VHOST_NET_PATH), vcpu_count: None, rt_cpus: Vec::new(), vcpu_affinity: None, no_smt: false, memory: None, + hugepages: false, + memory_file: None, executable_path: None, android_fstab: None, initrd_path: None, @@ -291,6 +330,16 @@ impl Default for Config { #[cfg(all(target_arch = "x86_64", feature = "gdb"))] gdb: None, balloon_bias: 0, + vhost_user_blk: Vec::new(), + vhost_user_fs: Vec::new(), + vhost_user_net: Vec::new(), + #[cfg(feature = "direct")] + direct_pmio: None, + #[cfg(feature = "direct")] + direct_level_irq: Vec::new(), + #[cfg(feature = "direct")] + direct_edge_irq: Vec::new(), + dmi_path: None, } } } diff --git a/src/gdb.rs b/src/gdb.rs index 65626f594..6a18f5564 100644 --- a/src/gdb.rs +++ b/src/gdb.rs @@ -6,12 +6,11 @@ use std::net::TcpListener; use std::sync::mpsc; use std::time::Duration; -use base::{error, info}; -use 
msg_socket::{MsgReceiver, MsgSender}; +use base::{error, info, Tube, TubeError}; + use sync::Mutex; use vm_control::{ - VcpuControl, VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmControlRequestSocket, - VmRequest, VmResponse, + VcpuControl, VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmRequest, VmResponse, }; use vm_memory::GuestAddress; @@ -82,15 +81,15 @@ enum Error { VcpuResponse(mpsc::RecvTimeoutError), /// Failed to send a VM request. #[error("failed to send a VM request: {0}")] - VmRequest(msg_socket::MsgError), + VmRequest(TubeError), /// Failed to receive a VM request. #[error("failed to receive a VM response: {0}")] - VmResponse(msg_socket::MsgError), + VmResponse(TubeError), } type GdbResult<T> = std::result::Result<T, Error>; pub struct GdbStub { - vm_socket: Mutex<VmControlRequestSocket>, + vm_tube: Mutex<Tube>, vcpu_com: Vec<mpsc::Sender<VcpuControl>>, from_vcpu: mpsc::Receiver<VcpuDebugStatusMessage>, @@ -99,12 +98,12 @@ pub struct GdbStub { impl GdbStub { pub fn new( - vm_socket: VmControlRequestSocket, + vm_tube: Tube, vcpu_com: Vec<mpsc::Sender<VcpuControl>>, from_vcpu: mpsc::Receiver<VcpuDebugStatusMessage>, ) -> Self { GdbStub { - vm_socket: Mutex::new(vm_socket), + vm_tube: Mutex::new(vm_tube), vcpu_com, from_vcpu, hw_breakpoints: Default::default(), @@ -122,9 +121,9 @@ impl GdbStub { } fn vm_request(&self, request: VmRequest) -> GdbResult<()> { - let vm_socket = self.vm_socket.lock(); - vm_socket.send(&request).map_err(Error::VmRequest)?; - match vm_socket.recv() { + let vm_tube = self.vm_tube.lock(); + vm_tube.send(&request).map_err(Error::VmRequest)?; + match vm_tube.recv() { Ok(VmResponse::Ok) => Ok(()), Ok(r) => Err(Error::UnexpectedVmResponse(r)), Err(e) => Err(Error::VmResponse(e)), diff --git a/src/linux.rs b/src/linux.rs index b89ce6df5..5efa2bb1f 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -32,7 +32,11 @@ use libc::{self, c_int, gid_t, uid_t}; use acpi_tables::sdt::SDT; -use base::net::{UnixSeqpacket, 
UnixSeqpacketListener, UnlinkUnixSeqpacketListener}; +use base::net::{UnixSeqpacketListener, UnlinkUnixSeqpacketListener}; +use base::*; +use devices::virtio::vhost::user::{ + Block as VhostUserBlock, Error as VhostUserError, Fs as VhostUserFs, Net as VhostUserNet, +}; #[cfg(feature = "gpu")] use devices::virtio::EventDevice; use devices::virtio::{self, Console, VirtioDevice}; @@ -45,38 +49,20 @@ use devices::{ use hypervisor::kvm::{Kvm, KvmVcpu, KvmVm}; use hypervisor::{HypervisorCap, Vcpu, VcpuExit, VcpuRunHandle, Vm, VmCap}; use minijail::{self, Minijail}; -use msg_socket::{MsgError, MsgReceiver, MsgSender, MsgSocket}; use net_util::{Error as NetError, MacAddress, Tap}; use remain::sorted; use resources::{Alloc, MmioType, SystemAllocator}; use rutabaga_gfx::RutabagaGralloc; use sync::Mutex; - -use base::{ - self, block_signal, clear_signal, drop_capabilities, error, flock, get_blocked_signals, - get_group_id, get_user_id, getegid, geteuid, info, register_rt_signal_handler, - set_cpu_affinity, set_rt_prio_limit, set_rt_round_robin, signal, validate_raw_descriptor, warn, - AsRawDescriptor, Event, EventType, ExternalMapping, FlockOperation, FromRawDescriptor, - Killable, MemoryMappingArena, PollToken, Protection, RawDescriptor, ScopedEvent, SignalFd, - Terminal, Timer, WaitContext, SIGRTMIN, -}; -use vm_control::{ - BalloonControlCommand, BalloonControlRequestSocket, BalloonControlResponseSocket, - BalloonControlResult, BalloonStats, DiskControlCommand, DiskControlRequestSocket, - DiskControlResponseSocket, DiskControlResult, FsMappingRequest, FsMappingRequestSocket, - FsMappingResponseSocket, IrqSetup, UsbControlSocket, VcpuControl, VmControlResponseSocket, - VmIrqRequest, VmIrqRequestSocket, VmIrqResponse, VmIrqResponseSocket, - VmMemoryControlRequestSocket, VmMemoryControlResponseSocket, VmMemoryRequest, VmMemoryResponse, - VmMsyncRequest, VmMsyncRequestSocket, VmMsyncResponse, VmMsyncResponseSocket, VmResponse, - VmRunMode, -}; -#[cfg(all(target_arch = 
"x86_64", feature = "gdb"))] -use vm_control::{VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmRequest}; -use vm_memory::{GuestAddress, GuestMemory}; +use vm_control::*; +use vm_memory::{GuestAddress, GuestMemory, MemoryPolicy}; #[cfg(all(target_arch = "x86_64", feature = "gdb"))] use crate::gdb::{gdb_thread, GdbStub}; -use crate::{Config, DiskOption, Executable, SharedDir, SharedDirKind, TouchDeviceOption}; +use crate::{ + Config, DiskOption, Executable, SharedDir, SharedDirKind, TouchDeviceOption, VhostUserFsOption, + VhostUserOption, +}; use arch::{ self, LinuxArch, RunnableLinuxVm, SerialHardware, SerialParameters, VcpuAffinity, VirtioDeviceStub, VmComponents, VmImage, @@ -115,20 +101,28 @@ pub enum Error { #[cfg(feature = "audio")] CreateAc97(devices::PciDeviceError), CreateConsole(arch::serial::Error), + CreateControlServer(io::Error), CreateDiskError(disk::Error), CreateEvent(base::Error), CreateGrallocError(rutabaga_gfx::RutabagaError), + CreateKvm(base::Error), CreateSignalFd(base::SignalFdError), CreateSocket(io::Error), CreateTapDevice(NetError), CreateTimer(base::Error), CreateTpmStorage(PathBuf, io::Error), + CreateTube(TubeError), CreateUsbProvider(devices::usb::host_backend::error::Error), CreateVcpu(base::Error), CreateVfioDevice(devices::vfio::VfioError), + CreateVm(base::Error), CreateWaitContext(base::Error), DeviceJail(minijail::Error), DevicePivotRoot(minijail::Error), + #[cfg(feature = "direct")] + DirectIo(io::Error), + #[cfg(feature = "direct")] + DirectIrq(devices::DirectIrqError), Disk(PathBuf, io::Error), DiskImageLock(base::Error), DropCapabilities(base::Error), @@ -139,6 +133,7 @@ pub enum Error { GuestCachedTooLarge(std::num::TryFromIntError), GuestFreeMissing(), GuestFreeTooLarge(std::num::TryFromIntError), + GuestMemoryLayout(<Arch as LinuxArch>::Error), #[cfg(all(target_arch = "x86_64", feature = "gdb"))] HandleDebugCommand(<Arch as LinuxArch>::Error), InputDeviceNew(virtio::InputError), @@ -189,6 +184,10 @@ pub enum Error { 
Timer(base::Error), ValidateRawDescriptor(base::Error), VhostNetDeviceNew(virtio::vhost::Error), + VhostUserBlockDeviceNew(VhostUserError), + VhostUserFsDeviceNew(VhostUserError), + VhostUserNetDeviceNew(VhostUserError), + VhostUserNetWithNetArgs, VhostVsockDeviceNew(virtio::vhost::Error), VirtioPciDev(base::Error), WaitContextAdd(base::Error), @@ -222,9 +221,11 @@ impl Display for Error { #[cfg(feature = "audio")] CreateAc97(e) => write!(f, "failed to create ac97 device: {}", e), CreateConsole(e) => write!(f, "failed to create console device: {}", e), + CreateControlServer(e) => write!(f, "failed to create control server: {}", e), CreateDiskError(e) => write!(f, "failed to create virtual disk: {}", e), CreateEvent(e) => write!(f, "failed to create event: {}", e), CreateGrallocError(e) => write!(f, "failed to create gralloc: {}", e), + CreateKvm(e) => write!(f, "failed to create kvm: {}", e), CreateSignalFd(e) => write!(f, "failed to create signalfd: {}", e), CreateSocket(e) => write!(f, "failed to create socket: {}", e), CreateTapDevice(e) => write!(f, "failed to create tap device: {}", e), @@ -232,12 +233,18 @@ impl Display for Error { CreateTpmStorage(p, e) => { write!(f, "failed to create tpm storage dir {}: {}", p.display(), e) } + CreateTube(e) => write!(f, "failed to create tube: {}", e), CreateUsbProvider(e) => write!(f, "failed to create usb provider: {}", e), CreateVcpu(e) => write!(f, "failed to create vcpu: {}", e), CreateVfioDevice(e) => write!(f, "Failed to create vfio device {}", e), + CreateVm(e) => write!(f, "failed to create vm: {}", e), CreateWaitContext(e) => write!(f, "failed to create wait context: {}", e), DeviceJail(e) => write!(f, "failed to jail device: {}", e), DevicePivotRoot(e) => write!(f, "failed to pivot root device: {}", e), + #[cfg(feature = "direct")] + DirectIo(e) => write!(f, "failed to open direct io device: {}", e), + #[cfg(feature = "direct")] + DirectIrq(e) => write!(f, "failed to enable interrupt forwarding: {}", e), 
Disk(p, e) => write!(f, "failed to load disk image {}: {}", p.display(), e), DiskImageLock(e) => write!(f, "failed to lock disk image: {}", e), DropCapabilities(e) => write!(f, "failed to drop process capabilities: {}", e), @@ -248,6 +255,7 @@ impl Display for Error { GuestCachedTooLarge(e) => write!(f, "guest cached is too large: {}", e), GuestFreeMissing() => write!(f, "guest free is missing from balloon stats"), GuestFreeTooLarge(e) => write!(f, "guest free is too large: {}", e), + GuestMemoryLayout(e) => write!(f, "failed to create guest memory layout: {}", e), #[cfg(all(target_arch = "x86_64", feature = "gdb"))] HandleDebugCommand(e) => write!(f, "failed to handle a gdb command: {}", e), InputDeviceNew(e) => write!(f, "failed to set up input device: {}", e), @@ -309,6 +317,15 @@ impl Display for Error { Timer(e) => write!(f, "failed to read timer fd: {}", e), ValidateRawDescriptor(e) => write!(f, "failed to validate raw descriptor: {}", e), VhostNetDeviceNew(e) => write!(f, "failed to set up vhost networking: {}", e), + VhostUserBlockDeviceNew(e) => { + write!(f, "failed to set up vhost-user block device: {}", e) + } + VhostUserFsDeviceNew(e) => write!(f, "failed to set up vhost-user fs device: {}", e), + VhostUserNetDeviceNew(e) => write!(f, "failed to set up vhost-user net device: {}", e), + VhostUserNetWithNetArgs => write!( + f, + "vhost-user-net cannot be used with any of --host_ip, --netmask or --mac" + ), VhostVsockDeviceNew(e) => write!(f, "failed to set up virtual socket device: {}", e), VirtioPciDev(e) => write!(f, "failed to create virtio pci dev: {}", e), WaitContextAdd(e) => write!(f, "failed to add descriptor to wait context: {}", e), @@ -330,28 +347,24 @@ impl std::error::Error for Error {} type Result<T> = std::result::Result<T, Error>; -enum TaggedControlSocket { - Fs(FsMappingResponseSocket), - Vm(VmControlResponseSocket), - VmMemory(VmMemoryControlResponseSocket), - VmIrq(VmIrqResponseSocket), - VmMsync(VmMsyncResponseSocket), +enum 
TaggedControlTube { + Fs(Tube), + Vm(Tube), + VmMemory(Tube), + VmIrq(Tube), + VmMsync(Tube), } -impl AsRef<UnixSeqpacket> for TaggedControlSocket { - fn as_ref(&self) -> &UnixSeqpacket { - use self::TaggedControlSocket::*; +impl AsRef<Tube> for TaggedControlTube { + fn as_ref(&self) -> &Tube { + use self::TaggedControlTube::*; match &self { - Fs(ref socket) => socket.as_ref(), - Vm(ref socket) => socket.as_ref(), - VmMemory(ref socket) => socket.as_ref(), - VmIrq(ref socket) => socket.as_ref(), - VmMsync(ref socket) => socket.as_ref(), + Fs(tube) | Vm(tube) | VmMemory(tube) | VmIrq(tube) | VmMsync(tube) => tube, } } } -impl AsRawDescriptor for TaggedControlSocket { +impl AsRawDescriptor for TaggedControlTube { fn as_raw_descriptor(&self) -> RawDescriptor { self.as_ref().as_raw_descriptor() } @@ -477,11 +490,7 @@ fn simple_jail(cfg: &Config, policy: &str) -> Result<Option<Minijail>> { type DeviceResult<T = VirtioDeviceStub> = std::result::Result<T, Error>; -fn create_block_device( - cfg: &Config, - disk: &DiskOption, - disk_device_socket: DiskControlResponseSocket, -) -> DeviceResult { +fn create_block_device(cfg: &Config, disk: &DiskOption, disk_device_tube: Tube) -> DeviceResult { // Special case '/proc/self/fd/*' paths. The FD is already open, just use it. let raw_image: File = if disk.path.parent() == Some(Path::new("/proc/self/fd")) { // Safe because we will validate |raw_fd|. 
@@ -510,7 +519,8 @@ fn create_block_device( disk.read_only, disk.sparse, disk.block_size, - Some(disk_device_socket), + disk.id, + Some(disk_device_tube), ) .map_err(Error::BlockDeviceNew)?, ) as Box<dyn VirtioDevice> @@ -524,7 +534,7 @@ fn create_block_device( disk.sparse, disk.block_size, disk.id, - Some(disk_device_socket), + Some(disk_device_tube), ) .map_err(Error::BlockDeviceNew)?, ) as Box<dyn VirtioDevice> @@ -536,6 +546,32 @@ fn create_block_device( }) } +fn create_vhost_user_block_device(cfg: &Config, opt: &VhostUserOption) -> DeviceResult { + let dev = VhostUserBlock::new(virtio::base_features(cfg.protected_vm), &opt.socket) + .map_err(Error::VhostUserBlockDeviceNew)?; + + Ok(VirtioDeviceStub { + dev: Box::new(dev), + // no sandbox here because virtqueue handling is exported to a different process. + jail: None, + }) +} + +fn create_vhost_user_fs_device(cfg: &Config, option: &VhostUserFsOption) -> DeviceResult { + let dev = VhostUserFs::new( + virtio::base_features(cfg.protected_vm), + &option.socket, + &option.tag, + ) + .map_err(Error::VhostUserFsDeviceNew)?; + + Ok(VirtioDeviceStub { + dev: Box::new(dev), + // no sandbox here because virtqueue handling is exported to a different process. 
+ jail: None, + }) +} + fn create_rng_device(cfg: &Config) -> DeviceResult { let dev = virtio::Rng::new(virtio::base_features(cfg.protected_vm)).map_err(Error::RngDeviceNew)?; @@ -548,7 +584,6 @@ fn create_rng_device(cfg: &Config) -> DeviceResult { #[cfg(feature = "tpm")] fn create_tpm_device(cfg: &Config) -> DeviceResult { - use base::chown; use std::ffi::CString; use std::fs; use std::process; @@ -724,8 +759,8 @@ fn create_vinput_device(cfg: &Config, dev_path: &Path) -> DeviceResult { }) } -fn create_balloon_device(cfg: &Config, socket: BalloonControlResponseSocket) -> DeviceResult { - let dev = virtio::Balloon::new(virtio::base_features(cfg.protected_vm), socket) +fn create_balloon_device(cfg: &Config, tube: Tube) -> DeviceResult { + let dev = virtio::Balloon::new(virtio::base_features(cfg.protected_vm), tube) .map_err(Error::BalloonDeviceNew)?; Ok(VirtioDeviceStub { @@ -775,6 +810,7 @@ fn create_net_device( let features = virtio::base_features(cfg.protected_vm); let dev = if cfg.vhost_net { let dev = virtio::vhost::Net::<Tap, vhost::Net<Tap>>::new( + &cfg.vhost_net_device_path, features, host_ip, netmask, @@ -801,16 +837,28 @@ fn create_net_device( }) } +fn create_vhost_user_net_device(cfg: &Config, opt: &VhostUserOption) -> DeviceResult { + let dev = VhostUserNet::new(virtio::base_features(cfg.protected_vm), &opt.socket) + .map_err(Error::VhostUserNetDeviceNew)?; + + Ok(VirtioDeviceStub { + dev: Box::new(dev), + // no sandbox here because virtqueue handling is exported to a different process. 
+ jail: None, + }) +} + #[cfg(feature = "gpu")] fn create_gpu_device( cfg: &Config, exit_evt: &Event, - gpu_device_socket: VmMemoryControlRequestSocket, - gpu_sockets: Vec<virtio::resource_bridge::ResourceResponseSocket>, + gpu_device_tube: Tube, + resource_bridges: Vec<Tube>, wayland_socket_path: Option<&PathBuf>, x_display: Option<String>, event_devices: Vec<EventDevice>, map_request: Arc<Mutex<Option<ExternalMapping>>>, + mem: &GuestMemory, ) -> DeviceResult { let jailed_wayland_path = Path::new("/wayland-0"); @@ -832,9 +880,9 @@ fn create_gpu_device( let dev = virtio::Gpu::new( exit_evt.try_clone().map_err(Error::CloneEvent)?, - Some(gpu_device_socket), + Some(gpu_device_tube), NonZeroU8::new(1).unwrap(), // number of scanouts - gpu_sockets, + resource_bridges, display_backends, cfg.gpu_parameters.as_ref().unwrap(), event_devices, @@ -842,6 +890,7 @@ fn create_gpu_device( cfg.sandbox, virtio::base_features(cfg.protected_vm), cfg.wayland_socket_paths.clone(), + mem.clone(), ); let jail = match simple_jail(&cfg, "gpu_device")? { @@ -902,6 +951,12 @@ fn create_gpu_device( jail.mount_bind(pvr_sync_path, pvr_sync_path, true)?; } + // If the udmabuf driver exists on the host, bind mount it in. + let udmabuf_path = Path::new("/dev/udmabuf"); + if udmabuf_path.exists() { + jail.mount_bind(udmabuf_path, udmabuf_path, true)?; + } + // Libraries that are required when mesa drivers are dynamically loaded. 
let lib_dirs = &[ "/usr/lib", @@ -957,8 +1012,8 @@ fn create_gpu_device( fn create_wayland_device( cfg: &Config, - socket: VmMemoryControlRequestSocket, - resource_bridge: Option<virtio::resource_bridge::ResourceRequestSocket>, + control_tube: Tube, + resource_bridge: Option<Tube>, ) -> DeviceResult { let wayland_socket_dirs = cfg .wayland_socket_paths @@ -971,7 +1026,7 @@ fn create_wayland_device( let dev = virtio::Wl::new( features, cfg.wayland_socket_paths.clone(), - socket, + control_tube, resource_bridge, ) .map_err(Error::WaylandDeviceNew)?; @@ -1012,7 +1067,7 @@ fn create_wayland_device( fn create_video_device( cfg: &Config, typ: devices::virtio::VideoDeviceType, - resource_bridge: virtio::resource_bridge::ResourceRequestSocket, + resource_bridge: Tube, ) -> DeviceResult { let jail = match simple_jail(&cfg, "video_device")? { Some(mut jail) => { @@ -1075,20 +1130,18 @@ fn create_video_device( #[cfg(any(feature = "video-decoder", feature = "video-encoder"))] fn register_video_device( devs: &mut Vec<VirtioDeviceStub>, - resource_bridges: &mut Vec<virtio::resource_bridge::ResourceResponseSocket>, + video_tube: Tube, cfg: &Config, typ: devices::virtio::VideoDeviceType, ) -> std::result::Result<(), Error> { - let (video_socket, gpu_socket) = - virtio::resource_bridge::pair().map_err(Error::CreateSocket)?; - resource_bridges.push(gpu_socket); - devs.push(create_video_device(cfg, typ, video_socket)?); + devs.push(create_video_device(cfg, typ, video_tube)?); Ok(()) } fn create_vhost_vsock_device(cfg: &Config, cid: u64, mem: &GuestMemory) -> DeviceResult { let features = virtio::base_features(cfg.protected_vm); - let dev = virtio::vhost::Vsock::new(features, cid, mem).map_err(Error::VhostVsockDeviceNew)?; + let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid, mem) + .map_err(Error::VhostVsockDeviceNew)?; Ok(VirtioDeviceStub { dev: Box::new(dev), @@ -1103,7 +1156,7 @@ fn create_fs_device( src: &Path, tag: &str, fs_cfg: 
virtio::fs::passthrough::Config, - device_socket: FsMappingRequestSocket, + device_tube: Tube, ) -> DeviceResult { let max_open_files = get_max_open_files()?; let j = if cfg.sandbox { @@ -1129,7 +1182,7 @@ fn create_fs_device( // TODO(chirantan): Use more than one worker once the kernel driver has been fixed to not panic // when num_queues > 1. let dev = - virtio::fs::Fs::new(features, tag, 1, fs_cfg, device_socket).map_err(Error::FsDeviceNew)?; + virtio::fs::Fs::new(features, tag, 1, fs_cfg, device_tube).map_err(Error::FsDeviceNew)?; Ok(VirtioDeviceStub { dev: Box::new(dev), @@ -1186,13 +1239,19 @@ fn create_pmem_device( resources: &mut SystemAllocator, disk: &DiskOption, index: usize, - pmem_device_socket: VmMsyncRequestSocket, + pmem_device_tube: Tube, ) -> DeviceResult { - let fd = OpenOptions::new() - .read(true) - .write(!disk.read_only) - .open(&disk.path) - .map_err(|e| Error::Disk(disk.path.to_path_buf(), e))?; + // Special case '/proc/self/fd/*' paths. The FD is already open, just use it. + let fd: File = if disk.path.parent() == Some(Path::new("/proc/self/fd")) { + // Safe because we will validate |raw_fd|. + unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&disk.path)?) } + } else { + OpenOptions::new() + .read(true) + .write(!disk.read_only) + .open(&disk.path) + .map_err(|e| Error::Disk(disk.path.to_path_buf(), e))? + }; let arena_size = { let metadata = @@ -1259,7 +1318,7 @@ fn create_pmem_device( GuestAddress(mapping_address), slot, arena_size, - Some(pmem_device_socket), + Some(pmem_device_tube), ) .map_err(Error::PmemDeviceNew)?; @@ -1304,7 +1363,7 @@ fn create_console_device(cfg: &Config, param: &SerialParameters) -> DeviceResult }) } -// gpu_device_socket is not used when GPU support is disabled. +// gpu_device_tube is not used when GPU support is disabled. 
#[cfg_attr(not(feature = "gpu"), allow(unused_variables))] fn create_virtio_devices( cfg: &Config, @@ -1312,13 +1371,13 @@ fn create_virtio_devices( vm: &mut impl Vm, resources: &mut SystemAllocator, _exit_evt: &Event, - wayland_device_socket: VmMemoryControlRequestSocket, - gpu_device_socket: VmMemoryControlRequestSocket, - balloon_device_socket: BalloonControlResponseSocket, - disk_device_sockets: &mut Vec<DiskControlResponseSocket>, - pmem_device_sockets: &mut Vec<VmMsyncRequestSocket>, + wayland_device_tube: Tube, + gpu_device_tube: Tube, + balloon_device_tube: Tube, + disk_device_tubes: &mut Vec<Tube>, + pmem_device_tubes: &mut Vec<Tube>, map_request: Arc<Mutex<Option<ExternalMapping>>>, - fs_device_sockets: &mut Vec<FsMappingRequestSocket>, + fs_device_tubes: &mut Vec<Tube>, ) -> DeviceResult<Vec<VirtioDeviceStub>> { let mut devs = Vec::new(); @@ -1332,19 +1391,23 @@ fn create_virtio_devices( } for disk in &cfg.disks { - let disk_device_socket = disk_device_sockets.remove(0); - devs.push(create_block_device(cfg, disk, disk_device_socket)?); + let disk_device_tube = disk_device_tubes.remove(0); + devs.push(create_block_device(cfg, disk, disk_device_tube)?); + } + + for blk in &cfg.vhost_user_blk { + devs.push(create_vhost_user_block_device(cfg, blk)?); } for (index, pmem_disk) in cfg.pmem_devices.iter().enumerate() { - let pmem_device_socket = pmem_device_sockets.remove(0); + let pmem_device_tube = pmem_device_tubes.remove(0); devs.push(create_pmem_device( cfg, vm, resources, pmem_disk, index, - pmem_device_socket, + pmem_device_tube, )?); } @@ -1385,7 +1448,7 @@ fn create_virtio_devices( devs.push(create_vinput_device(cfg, dev_path)?); } - devs.push(create_balloon_device(cfg, balloon_device_socket)?); + devs.push(create_balloon_device(cfg, balloon_device_tube)?); // We checked above that if the IP is defined, then the netmask is, too. 
for tap_fd in &cfg.tap_fd { @@ -1395,21 +1458,27 @@ fn create_virtio_devices( if let (Some(host_ip), Some(netmask), Some(mac_address)) = (cfg.host_ip, cfg.netmask, cfg.mac_address) { + if !cfg.vhost_user_net.is_empty() { + return Err(Error::VhostUserNetWithNetArgs); + } devs.push(create_net_device(cfg, host_ip, netmask, mac_address, mem)?); } + for net in &cfg.vhost_user_net { + devs.push(create_vhost_user_net_device(cfg, net)?); + } + #[cfg_attr(not(feature = "gpu"), allow(unused_mut))] - let mut resource_bridges = Vec::<virtio::resource_bridge::ResourceResponseSocket>::new(); + let mut resource_bridges = Vec::<Tube>::new(); if !cfg.wayland_socket_paths.is_empty() { #[cfg_attr(not(feature = "gpu"), allow(unused_mut))] - let mut wl_resource_bridge = None::<virtio::resource_bridge::ResourceRequestSocket>; + let mut wl_resource_bridge = None::<Tube>; #[cfg(feature = "gpu")] { if cfg.gpu_parameters.is_some() { - let (wl_socket, gpu_socket) = - virtio::resource_bridge::pair().map_err(Error::CreateSocket)?; + let (wl_socket, gpu_socket) = Tube::pair().map_err(Error::CreateTube)?; resource_bridges.push(gpu_socket); wl_resource_bridge = Some(wl_socket); } @@ -1417,34 +1486,28 @@ fn create_virtio_devices( devs.push(create_wayland_device( cfg, - wayland_device_socket, + wayland_device_tube, wl_resource_bridge, )?); } #[cfg(feature = "video-decoder")] - { - if cfg.video_dec { - register_video_device( - &mut devs, - &mut resource_bridges, - cfg, - devices::virtio::VideoDeviceType::Decoder, - )?; - } - } + let video_dec_tube = if cfg.video_dec { + let (video_tube, gpu_tube) = Tube::pair().map_err(Error::CreateTube)?; + resource_bridges.push(gpu_tube); + Some(video_tube) + } else { + None + }; #[cfg(feature = "video-encoder")] - { - if cfg.video_enc { - register_video_device( - &mut devs, - &mut resource_bridges, - cfg, - devices::virtio::VideoDeviceType::Encoder, - )?; - } - } + let video_enc_tube = if cfg.video_enc { + let (video_tube, gpu_tube) = 
Tube::pair().map_err(Error::CreateTube)?; + resource_bridges.push(gpu_tube); + Some(video_tube) + } else { + None + }; #[cfg(feature = "gpu")] { @@ -1488,21 +1551,50 @@ fn create_virtio_devices( devs.push(create_gpu_device( cfg, _exit_evt, - gpu_device_socket, + gpu_device_tube, resource_bridges, // Use the unnamed socket for GPU display screens. cfg.wayland_socket_paths.get(""), cfg.x_display.clone(), event_devices, map_request, + mem, )?); } } + #[cfg(feature = "video-decoder")] + { + if let Some(video_dec_tube) = video_dec_tube { + register_video_device( + &mut devs, + video_dec_tube, + cfg, + devices::virtio::VideoDeviceType::Decoder, + )?; + } + } + + #[cfg(feature = "video-encoder")] + { + if let Some(video_enc_tube) = video_enc_tube { + register_video_device( + &mut devs, + video_enc_tube, + cfg, + devices::virtio::VideoDeviceType::Encoder, + )?; + } + } + if let Some(cid) = cfg.cid { devs.push(create_vhost_vsock_device(cfg, cid, mem)?); } + for vhost_user_fs in &cfg.vhost_user_fs { + devs.push(create_vhost_user_fs_device(cfg, &vhost_user_fs)?); + } + for shared_dir in &cfg.shared_dirs { let SharedDir { src, @@ -1516,16 +1608,8 @@ fn create_virtio_devices( let dev = match kind { SharedDirKind::FS => { - let device_socket = fs_device_sockets.remove(0); - create_fs_device( - cfg, - uid_map, - gid_map, - src, - tag, - fs_cfg.clone(), - device_socket, - )? + let device_tube = fs_device_tubes.remove(0); + create_fs_device(cfg, uid_map, gid_map, src, tag, fs_cfg.clone(), device_tube)? 
} SharedDirKind::P9 => create_9p_device(cfg, uid_map, gid_map, src, tag, p9_cfg.clone())?, }; @@ -1541,13 +1625,13 @@ fn create_devices( vm: &mut impl Vm, resources: &mut SystemAllocator, exit_evt: &Event, - control_sockets: &mut Vec<TaggedControlSocket>, - wayland_device_socket: VmMemoryControlRequestSocket, - gpu_device_socket: VmMemoryControlRequestSocket, - balloon_device_socket: BalloonControlResponseSocket, - disk_device_sockets: &mut Vec<DiskControlResponseSocket>, - pmem_device_sockets: &mut Vec<VmMsyncRequestSocket>, - fs_device_sockets: &mut Vec<FsMappingRequestSocket>, + control_tubes: &mut Vec<TaggedControlTube>, + wayland_device_tube: Tube, + gpu_device_tube: Tube, + balloon_device_tube: Tube, + disk_device_tubes: &mut Vec<Tube>, + pmem_device_tubes: &mut Vec<Tube>, + fs_device_tubes: &mut Vec<Tube>, usb_provider: HostBackendDeviceProvider, map_request: Arc<Mutex<Option<ExternalMapping>>>, ) -> DeviceResult<Vec<(Box<dyn PciDevice>, Option<Minijail>)>> { @@ -1557,22 +1641,21 @@ fn create_devices( vm, resources, exit_evt, - wayland_device_socket, - gpu_device_socket, - balloon_device_socket, - disk_device_sockets, - pmem_device_sockets, + wayland_device_tube, + gpu_device_tube, + balloon_device_tube, + disk_device_tubes, + pmem_device_tubes, map_request, - fs_device_sockets, + fs_device_tubes, )?; let mut pci_devices = Vec::new(); for stub in stubs { - let (msi_host_socket, msi_device_socket) = - msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmIrq(msi_host_socket)); - let dev = VirtioPciDevice::new(mem.clone(), stub.dev, msi_device_socket) + let (msi_host_tube, msi_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmIrq(msi_host_tube)); + let dev = VirtioPciDevice::new(mem.clone(), stub.dev, msi_device_tube) .map_err(Error::VirtioPciDev)?; let dev = Box::new(dev) as Box<dyn PciDevice>; pci_devices.push((dev, stub.jail)); @@ 
-1596,26 +1679,25 @@ fn create_devices( for vfio_path in &cfg.vfio { // create MSI, MSI-X, and Mem request sockets for each vfio device - let (vfio_host_socket_msi, vfio_device_socket_msi) = - msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmIrq(vfio_host_socket_msi)); + let (vfio_host_tube_msi, vfio_device_tube_msi) = + Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmIrq(vfio_host_tube_msi)); - let (vfio_host_socket_msix, vfio_device_socket_msix) = - msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmIrq(vfio_host_socket_msix)); + let (vfio_host_tube_msix, vfio_device_tube_msix) = + Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmIrq(vfio_host_tube_msix)); - let (vfio_host_socket_mem, vfio_device_socket_mem) = - msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>() - .map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmMemory(vfio_host_socket_mem)); + let (vfio_host_tube_mem, vfio_device_tube_mem) = + Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmMemory(vfio_host_tube_mem)); let vfiodevice = VfioDevice::new(vfio_path.as_path(), vm, mem, vfio_container.clone()) .map_err(Error::CreateVfioDevice)?; let mut vfiopcidevice = Box::new(VfioPciDevice::new( vfiodevice, - vfio_device_socket_msi, - vfio_device_socket_msix, - vfio_device_socket_mem, + vfio_device_tube_msi, + vfio_device_tube_msix, + vfio_device_tube_mem, )); // early reservation for pass-through PCI devices. 
if vfiopcidevice.allocate_address(resources).is_err() { @@ -1713,7 +1795,7 @@ impl IntoUnixStream for UnixStream { fn setup_vcpu_signal_handler<T: Vcpu>(use_hypervisor_signals: bool) -> Result<()> { if use_hypervisor_signals { unsafe { - extern "C" fn handle_signal() {} + extern "C" fn handle_signal(_: c_int) {} // Our signal handler does nothing and is trivially async signal safe. register_rt_signal_handler(SIGRTMIN() + 0, handle_signal) .map_err(Error::RegisterSignalHandler)?; @@ -1721,7 +1803,7 @@ fn setup_vcpu_signal_handler<T: Vcpu>(use_hypervisor_signals: bool) -> Result<() block_signal(SIGRTMIN() + 0).map_err(Error::BlockSignal)?; } else { unsafe { - extern "C" fn handle_signal<T: Vcpu>() { + extern "C" fn handle_signal<T: Vcpu>(_: c_int) { T::set_local_immediate_exit(true); } register_rt_signal_handler(SIGRTMIN() + 0, handle_signal::<T>) @@ -1818,7 +1900,7 @@ fn handle_debug_msg<V>( vcpu: &V, guest_mem: &GuestMemory, d: VcpuDebug, - reply_channel: &mpsc::Sender<VcpuDebugStatusMessage>, + reply_tube: &mpsc::Sender<VcpuDebugStatusMessage>, ) -> Result<()> where V: VcpuArch + 'static, @@ -1831,13 +1913,13 @@ where Arch::debug_read_registers(vcpu as &V).map_err(Error::HandleDebugCommand)?, ), }; - reply_channel + reply_tube .send(msg) .map_err(|e| Error::SendDebugStatus(Box::new(e))) } VcpuDebug::WriteRegs(regs) => { Arch::debug_write_registers(vcpu as &V, ®s).map_err(Error::HandleDebugCommand)?; - reply_channel + reply_tube .send(VcpuDebugStatusMessage { cpu: cpu_id as usize, msg: VcpuDebugStatus::CommandComplete, @@ -1852,14 +1934,14 @@ where .unwrap_or(Vec::new()), ), }; - reply_channel + reply_tube .send(msg) .map_err(|e| Error::SendDebugStatus(Box::new(e))) } VcpuDebug::WriteMem(vaddr, buf) => { Arch::debug_write_memory(vcpu as &V, guest_mem, vaddr, &buf) .map_err(Error::HandleDebugCommand)?; - reply_channel + reply_tube .send(VcpuDebugStatusMessage { cpu: cpu_id as usize, msg: VcpuDebugStatus::CommandComplete, @@ -1868,7 +1950,7 @@ where } 
VcpuDebug::EnableSinglestep => { Arch::debug_enable_singlestep(vcpu as &V).map_err(Error::HandleDebugCommand)?; - reply_channel + reply_tube .send(VcpuDebugStatusMessage { cpu: cpu_id as usize, msg: VcpuDebugStatus::CommandComplete, @@ -1878,7 +1960,7 @@ where VcpuDebug::SetHwBreakPoint(addrs) => { Arch::debug_set_hw_breakpoints(vcpu as &V, &addrs) .map_err(Error::HandleDebugCommand)?; - reply_channel + reply_tube .send(VcpuDebugStatusMessage { cpu: cpu_id as usize, msg: VcpuDebugStatus::CommandComplete, @@ -1903,9 +1985,9 @@ fn run_vcpu<V>( mmio_bus: devices::Bus, exit_evt: Event, requires_pvclock_ctrl: bool, - from_main_channel: mpsc::Receiver<VcpuControl>, + from_main_tube: mpsc::Receiver<VcpuControl>, use_hypervisor_signals: bool, - #[cfg(all(target_arch = "x86_64", feature = "gdb"))] to_gdb_channel: Option< + #[cfg(all(target_arch = "x86_64", feature = "gdb"))] to_gdb_tube: Option< mpsc::Sender<VcpuDebugStatusMessage>, >, ) -> Result<JoinHandle<()>> @@ -1946,7 +2028,7 @@ where let mut run_mode = VmRunMode::Running; #[cfg(all(target_arch = "x86_64", feature = "gdb"))] - if to_gdb_channel.is_some() { + if to_gdb_tube.is_some() { // Wait until a GDB client attaches run_mode = VmRunMode::Breakpoint; } @@ -1960,7 +2042,7 @@ where if interrupted_by_signal || run_mode != VmRunMode::Running { 'state_loop: loop { // Tries to get a pending message without blocking first. - let msg = match from_main_channel.try_recv() { + let msg = match from_main_tube.try_recv() { Ok(m) => m, Err(mpsc::TryRecvError::Empty) if run_mode == VmRunMode::Running => { // If the VM is running and no message is pending, the state won't @@ -1969,23 +2051,23 @@ where } Err(mpsc::TryRecvError::Empty) => { // If the VM is not running, wait until a message is ready. 
- match from_main_channel.recv() { + match from_main_tube.recv() { Ok(m) => m, Err(mpsc::RecvError) => { - error!("Failed to read from main channel in vcpu"); + error!("Failed to read from main tube in vcpu"); break 'vcpu_loop; } } } Err(mpsc::TryRecvError::Disconnected) => { - error!("Failed to read from main channel in vcpu"); + error!("Failed to read from main tube in vcpu"); break 'vcpu_loop; } }; // Collect all pending messages. let mut messages = vec![msg]; - messages.append(&mut from_main_channel.try_iter().collect()); + messages.append(&mut from_main_tube.try_iter().collect()); for msg in messages { match msg { @@ -2015,7 +2097,7 @@ where } #[cfg(all(target_arch = "x86_64", feature = "gdb"))] VcpuControl::Debug(d) => { - match &to_gdb_channel { + match &to_gdb_tube { Some(ref ch) => { if let Err(e) = handle_debug_msg( cpu_id, &vcpu, &guest_mem, d, &ch, @@ -2112,7 +2194,7 @@ where cpu: cpu_id as usize, msg: VcpuDebugStatus::HitBreakPoint, }; - if let Some(ref ch) = to_gdb_channel { + if let Some(ref ch) = to_gdb_tube { if let Err(e) = ch.send(msg) { error!("failed to notify breakpoint to GDB thread: {}", e); break; @@ -2183,16 +2265,10 @@ fn file_to_i64<P: AsRef<Path>>(path: P, nth: usize) -> io::Result<i64> { .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "empty file")) } -fn create_kvm(mem: GuestMemory) -> base::Result<KvmVm> { - let kvm = Kvm::new()?; - let vm = KvmVm::new(&kvm, mem)?; - Ok(vm) -} - fn create_kvm_kernel_irq_chip( vm: &KvmVm, vcpu_count: usize, - _ioapic_device_socket: VmIrqRequestSocket, + _ioapic_device_tube: Tube, ) -> base::Result<impl IrqChipArch> { let irq_chip = KvmKernelIrqChip::new(vm.try_clone()?, vcpu_count)?; Ok(irq_chip) @@ -2202,13 +2278,27 @@ fn create_kvm_kernel_irq_chip( fn create_kvm_split_irq_chip( vm: &KvmVm, vcpu_count: usize, - ioapic_device_socket: VmIrqRequestSocket, + ioapic_device_tube: Tube, ) -> base::Result<impl IrqChipArch> { - let irq_chip = KvmSplitIrqChip::new(vm.try_clone()?, vcpu_count, 
ioapic_device_socket)?; + let irq_chip = + KvmSplitIrqChip::new(vm.try_clone()?, vcpu_count, ioapic_device_tube, Some(120))?; Ok(irq_chip) } pub fn run_config(cfg: Config) -> Result<()> { + let components = setup_vm_components(&cfg)?; + + let guest_mem_layout = + Arch::guest_memory_layout(&components).map_err(Error::GuestMemoryLayout)?; + let guest_mem = GuestMemory::new(&guest_mem_layout).unwrap(); + let mut mem_policy = MemoryPolicy::empty(); + if components.hugepages { + mem_policy |= MemoryPolicy::USE_HUGEPAGES; + } + guest_mem.set_memory_policy(mem_policy); + let kvm = Kvm::new_with_path(&cfg.kvm_device_path).map_err(Error::CreateKvm)?; + let vm = KvmVm::new(&kvm, guest_mem).map_err(Error::CreateVm)?; + if cfg.split_irqchip { #[cfg(any(target_arch = "arm", target_arch = "aarch64"))] { @@ -2217,39 +2307,14 @@ pub fn run_config(cfg: Config) -> Result<()> { #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { - run_vm::<_, KvmVcpu, _, _, _>(cfg, create_kvm, create_kvm_split_irq_chip) + run_vm::<KvmVcpu, _, _, _>(cfg, components, vm, create_kvm_split_irq_chip) } } else { - run_vm::<_, KvmVcpu, _, _, _>(cfg, create_kvm, create_kvm_kernel_irq_chip) + run_vm::<KvmVcpu, _, _, _>(cfg, components, vm, create_kvm_kernel_irq_chip) } } -fn run_vm<V, Vcpu, I, FV, FI>(cfg: Config, create_vm: FV, create_irq_chip: FI) -> Result<()> -where - V: VmArch + 'static, - Vcpu: VcpuArch + 'static, - I: IrqChipArch + 'static, - FV: FnOnce(GuestMemory) -> base::Result<V>, - FI: FnOnce( - &V, - usize, // vcpu_count - VmIrqRequestSocket, // ioapic_device_socket - ) -> base::Result<I>, -{ - if cfg.sandbox { - // Printing something to the syslog before entering minijail so that libc's syslogger has a - // chance to open files necessary for its operation, like `/etc/localtime`. After jailing, - // access to those files will not be possible. 
- info!("crosvm entering multiprocess mode"); - } - - let (usb_control_socket, usb_provider) = - HostBackendDeviceProvider::new().map_err(Error::CreateUsbProvider)?; - // Masking signals is inherently dangerous, since this can persist across clones/execs. Do this - // before any jailed devices have been spawned, so that we can catch any of them that fail very - // quickly. - let sigchld_fd = SignalFd::new(libc::SIGCHLD).map_err(Error::CreateSignalFd)?; - +fn setup_vm_components(cfg: &Config) -> Result<VmComponents> { let initrd_image = if let Some(initrd_path) = &cfg.initrd_path { Some(File::open(initrd_path).map_err(|e| Error::OpenInitrd(initrd_path.clone(), e))?) } else { @@ -2266,19 +2331,7 @@ where _ => panic!("Did not receive a bios or kernel, should be impossible."), }; - let mut control_sockets = Vec::new(); - #[cfg(all(target_arch = "x86_64", feature = "gdb"))] - let gdb_socket = if let Some(port) = cfg.gdb { - // GDB needs a control socket to interrupt vcpus. - let (gdb_host_socket, gdb_control_socket) = - msg_socket::pair::<VmResponse, VmRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::Vm(gdb_host_socket)); - Some((port, gdb_control_socket)) - } else { - None - }; - - let components = VmComponents { + Ok(VmComponents { memory_size: cfg .memory .unwrap_or(256) @@ -2287,6 +2340,7 @@ where vcpu_count: cfg.vcpu_count.unwrap_or(1), vcpu_affinity: cfg.vcpu_affinity.clone(), no_smt: cfg.no_smt, + hugepages: cfg.hugepages, vm_image, android_fstab: cfg .android_fstab @@ -2305,52 +2359,86 @@ where rt_cpus: cfg.rt_cpus.clone(), protected_vm: cfg.protected_vm, #[cfg(all(target_arch = "x86_64", feature = "gdb"))] - gdb: gdb_socket, - }; + gdb: None, + dmi_path: cfg.dmi_path.clone(), + }) +} + +fn run_vm<Vcpu, V, I, FI>( + cfg: Config, + #[allow(unused_mut)] mut components: VmComponents, + vm: V, + create_irq_chip: FI, +) -> Result<()> +where + Vcpu: VcpuArch + 'static, + V: VmArch + 'static, + I: IrqChipArch + 'static, + FI: 
FnOnce( + &V, + usize, // vcpu_count + Tube, // ioapic_device_tube + ) -> base::Result<I>, +{ + if cfg.sandbox { + // Printing something to the syslog before entering minijail so that libc's syslogger has a + // chance to open files necessary for its operation, like `/etc/localtime`. After jailing, + // access to those files will not be possible. + info!("crosvm entering multiprocess mode"); + } + + let (usb_control_tube, usb_provider) = + HostBackendDeviceProvider::new().map_err(Error::CreateUsbProvider)?; + // Masking signals is inherently dangerous, since this can persist across clones/execs. Do this + // before any jailed devices have been spawned, so that we can catch any of them that fail very + // quickly. + let sigchld_fd = SignalFd::new(libc::SIGCHLD).map_err(Error::CreateSignalFd)?; let control_server_socket = match &cfg.socket_path { Some(path) => Some(UnlinkUnixSeqpacketListener( - UnixSeqpacketListener::bind(path).map_err(Error::CreateSocket)?, + UnixSeqpacketListener::bind(path).map_err(Error::CreateControlServer)?, )), None => None, }; - let (wayland_host_socket, wayland_device_socket) = - msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmMemory(wayland_host_socket)); + let mut control_tubes = Vec::new(); + + #[cfg(all(target_arch = "x86_64", feature = "gdb"))] + if let Some(port) = cfg.gdb { + // GDB needs a control socket to interrupt vcpus. + let (gdb_host_tube, gdb_control_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::Vm(gdb_host_tube)); + components.gdb = Some((port, gdb_control_tube)); + } + + let (wayland_host_tube, wayland_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmMemory(wayland_host_tube)); // Balloon gets a special socket so balloon requests can be forwarded from the main process. 
- let (balloon_host_socket, balloon_device_socket) = - msg_socket::pair::<BalloonControlCommand, BalloonControlResult>() - .map_err(Error::CreateSocket)?; + let (balloon_host_tube, balloon_device_tube) = Tube::pair().map_err(Error::CreateTube)?; // Create one control socket per disk. - let mut disk_device_sockets = Vec::new(); - let mut disk_host_sockets = Vec::new(); + let mut disk_device_tubes = Vec::new(); + let mut disk_host_tubes = Vec::new(); let disk_count = cfg.disks.len(); for _ in 0..disk_count { - let (disk_host_socket, disk_device_socket) = - msg_socket::pair::<DiskControlCommand, DiskControlResult>() - .map_err(Error::CreateSocket)?; - disk_host_sockets.push(disk_host_socket); - disk_device_sockets.push(disk_device_socket); + let (disk_host_tub, disk_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + disk_host_tubes.push(disk_host_tub); + disk_device_tubes.push(disk_device_tube); } - let mut pmem_device_sockets = Vec::new(); + let mut pmem_device_tubes = Vec::new(); let pmem_count = cfg.pmem_devices.len(); for _ in 0..pmem_count { - let (pmem_host_socket, pmem_device_socket) = - msg_socket::pair::<VmMsyncResponse, VmMsyncRequest>().map_err(Error::CreateSocket)?; - pmem_device_sockets.push(pmem_device_socket); - control_sockets.push(TaggedControlSocket::VmMsync(pmem_host_socket)); + let (pmem_host_tube, pmem_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + pmem_device_tubes.push(pmem_device_tube); + control_tubes.push(TaggedControlTube::VmMsync(pmem_host_tube)); } - let (gpu_host_socket, gpu_device_socket) = - msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmMemory(gpu_host_socket)); + let (gpu_host_tube, gpu_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmMemory(gpu_host_tube)); - let (ioapic_host_socket, ioapic_device_socket) = - msg_socket::pair::<VmIrqResponse, 
VmIrqRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::VmIrq(ioapic_host_socket)); + let (ioapic_host_tube, ioapic_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::VmIrq(ioapic_host_tube)); let battery = if cfg.battery_type.is_some() { let jail = match simple_jail(&cfg, "battery")? { @@ -2390,19 +2478,20 @@ where .iter() .filter(|sd| sd.kind == SharedDirKind::FS) .count(); - let mut fs_device_sockets = Vec::with_capacity(fs_count); + let mut fs_device_tubes = Vec::with_capacity(fs_count); for _ in 0..fs_count { - let (fs_host_socket, fs_device_socket) = - msg_socket::pair::<VmResponse, FsMappingRequest>().map_err(Error::CreateSocket)?; - control_sockets.push(TaggedControlSocket::Fs(fs_host_socket)); - fs_device_sockets.push(fs_device_socket); + let (fs_host_tube, fs_device_tube) = Tube::pair().map_err(Error::CreateTube)?; + control_tubes.push(TaggedControlTube::Fs(fs_host_tube)); + fs_device_tubes.push(fs_device_tube); } - let linux: RunnableLinuxVm<_, Vcpu, _> = Arch::build_vm( + #[cfg_attr(not(feature = "direct"), allow(unused_mut))] + let mut linux: RunnableLinuxVm<_, Vcpu, _> = Arch::build_vm( components, &cfg.serial_parameters, simple_jail(&cfg, "serial")?, battery, + vm, |mem, vm, sys_allocator, exit_evt| { create_devices( &cfg, @@ -2410,29 +2499,75 @@ where vm, sys_allocator, exit_evt, - &mut control_sockets, - wayland_device_socket, - gpu_device_socket, - balloon_device_socket, - &mut disk_device_sockets, - &mut pmem_device_sockets, - &mut fs_device_sockets, + &mut control_tubes, + wayland_device_tube, + gpu_device_tube, + balloon_device_tube, + &mut disk_device_tubes, + &mut pmem_device_tubes, + &mut fs_device_tubes, usb_provider, Arc::clone(&map_request), ) }, - create_vm, - |vm, vcpu_count| create_irq_chip(vm, vcpu_count, ioapic_device_socket), + |vm, vcpu_count| create_irq_chip(vm, vcpu_count, ioapic_device_tube), ) .map_err(Error::BuildVm)?; + #[cfg(feature = "direct")] 
+ if let Some(pmio) = &cfg.direct_pmio { + let direct_io = + Arc::new(devices::DirectIo::new(&pmio.path, false).map_err(Error::DirectIo)?); + for range in pmio.ranges.iter() { + linux + .io_bus + .insert_sync(direct_io.clone(), range.0, range.1) + .unwrap(); + } + }; + + #[cfg(feature = "direct")] + let mut irqs = Vec::new(); + + #[cfg(feature = "direct")] + for irq in &cfg.direct_level_irq { + if !linux.resources.reserve_irq(*irq) { + warn!("irq {} already reserved.", irq); + } + let trigger = Event::new().map_err(Error::CreateEvent)?; + let resample = Event::new().map_err(Error::CreateEvent)?; + linux + .irq_chip + .register_irq_event(*irq, &trigger, Some(&resample)) + .unwrap(); + let direct_irq = + devices::DirectIrq::new(trigger, Some(resample)).map_err(Error::DirectIrq)?; + direct_irq.irq_enable(*irq).map_err(Error::DirectIrq)?; + irqs.push(direct_irq); + } + + #[cfg(feature = "direct")] + for irq in &cfg.direct_edge_irq { + if !linux.resources.reserve_irq(*irq) { + warn!("irq {} already reserved.", irq); + } + let trigger = Event::new().map_err(Error::CreateEvent)?; + linux + .irq_chip + .register_irq_event(*irq, &trigger, None) + .unwrap(); + let direct_irq = devices::DirectIrq::new(trigger, None).map_err(Error::DirectIrq)?; + direct_irq.irq_enable(*irq).map_err(Error::DirectIrq)?; + irqs.push(direct_irq); + } + run_control( linux, control_server_socket, - control_sockets, - balloon_host_socket, - &disk_host_sockets, - usb_control_socket, + control_tubes, + balloon_host_tube, + &disk_host_tubes, + usb_control_tube, sigchld_fd, cfg.sandbox, Arc::clone(&map_request), @@ -2441,8 +2576,8 @@ where ) } -/// Signals all running VCPUs to vmexit, sends VmRunMode message to each VCPU channel, and tells -/// `irq_chip` to stop blocking halted VCPUs. The channel message is set first because both the +/// Signals all running VCPUs to vmexit, sends VmRunMode message to each VCPU tube, and tells +/// `irq_chip` to stop blocking halted VCPUs. 
The tube message is set first because both the /// signal and the irq_chip kick could cause the VCPU thread to continue through the VCPU run /// loop. fn kick_all_vcpus( @@ -2450,8 +2585,8 @@ fn kick_all_vcpus( irq_chip: &impl IrqChip, run_mode: &VmRunMode, ) { - for (handle, channel) in vcpu_handles { - if let Err(e) = channel.send(VcpuControl::RunState(run_mode.clone())) { + for (handle, tube) in vcpu_handles { + if let Err(e) = tube.send(VcpuControl::RunState(run_mode.clone())) { error!("failed to send VmRunMode: {}", e); } let _ = handle.kill(SIGRTMIN() + 0); @@ -2626,10 +2761,10 @@ impl BalloonPolicy { fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + 'static>( mut linux: RunnableLinuxVm<V, Vcpu, I>, control_server_socket: Option<UnlinkUnixSeqpacketListener>, - mut control_sockets: Vec<TaggedControlSocket>, - balloon_host_socket: BalloonControlRequestSocket, - disk_host_sockets: &[DiskControlRequestSocket], - usb_control_socket: UsbControlSocket, + mut control_tubes: Vec<TaggedControlTube>, + balloon_host_tube: Tube, + disk_host_tubes: &[Tube], + usb_control_tube: Tube, sigchld_fd: SignalFd, sandbox: bool, map_request: Arc<Mutex<Option<ExternalMapping>>>, @@ -2664,7 +2799,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' .add(socket_server, Token::VmControlServer) .map_err(Error::WaitContextAdd)?; } - for (index, socket) in control_sockets.iter().enumerate() { + for (index, socket) in control_tubes.iter().enumerate() { wait_ctx .add(socket.as_ref(), Token::VmControl { index }) .map_err(Error::WaitContextAdd)?; @@ -2696,7 +2831,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' // Listen for balloon statistics from the guest so we can balance. 
wait_ctx - .add(&balloon_host_socket, Token::BalloonResult) + .add(&balloon_host_tube, Token::BalloonResult) .map_err(Error::WaitContextAdd)?; Some(BalloonPolicy::new( linux.vm.get_memory().memory_size() as i64, @@ -2766,13 +2901,13 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' #[cfg(all(target_arch = "x86_64", feature = "gdb"))] // Spawn GDB thread. - if let Some((gdb_port_num, gdb_control_socket)) = linux.gdb.take() { + if let Some((gdb_port_num, gdb_control_tube)) = linux.gdb.take() { let to_vcpu_channels = vcpu_handles .iter() .map(|(_handle, channel)| channel.clone()) .collect(); let target = GdbStub::new( - gdb_control_socket, + gdb_control_tube, to_vcpu_channels, from_vcpu_channel.unwrap(), // Must succeed to unwrap() ); @@ -2834,12 +2969,12 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' Token::BalanceMemory => { balancemem_timer.wait().map_err(Error::Timer)?; let command = BalloonControlCommand::Stats {}; - if let Err(e) = balloon_host_socket.send(&command) { + if let Err(e) = balloon_host_tube.send(&command) { warn!("failed to send stats request to balloon device: {}", e); } } Token::BalloonResult => { - match balloon_host_socket.recv() { + match balloon_host_tube.recv() { Ok(BalloonControlResult::Stats { stats, balloon_actual: balloon_actual_u, @@ -2860,7 +2995,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' let target = max((balloon_actual_u as i64) + delta, 0) as u64; let command = BalloonControlCommand::Adjust { num_bytes: target }; - if let Err(e) = balloon_host_socket.send(&command) { + if let Err(e) = balloon_host_tube.send(&command) { warn!( "failed to send memory value to balloon device: {}", e @@ -2883,31 +3018,30 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' .add( &socket, Token::VmControl { - index: control_sockets.len(), + index: control_tubes.len(), }, ) .map_err(Error::WaitContextAdd)?; - 
control_sockets - .push(TaggedControlSocket::Vm(MsgSocket::new(socket))); + control_tubes.push(TaggedControlTube::Vm(Tube::new(socket))); } Err(e) => error!("failed to accept socket: {}", e), } } } Token::VmControl { index } => { - if let Some(socket) = control_sockets.get(index) { + if let Some(socket) = control_tubes.get(index) { match socket { - TaggedControlSocket::Vm(socket) => match socket.recv() { + TaggedControlTube::Vm(tube) => match tube.recv::<VmRequest>() { Ok(request) => { let mut run_mode_opt = None; let response = request.execute( &mut run_mode_opt, - &balloon_host_socket, - disk_host_sockets, - &usb_control_socket, + &balloon_host_tube, + disk_host_tubes, + &usb_control_tube, &mut linux.bat_control, ); - if let Err(e) = socket.send(&response) { + if let Err(e) = tube.send(&response) { error!("failed to send VmResponse: {}", e); } if let Some(run_mode) = run_mode_opt { @@ -2930,34 +3064,36 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' } } Err(e) => { - if let MsgError::RecvZero = e { + if let TubeError::Disconnected = e { vm_control_indices_to_remove.push(index); } else { error!("failed to recv VmRequest: {}", e); } } }, - TaggedControlSocket::VmMemory(socket) => match socket.recv() { - Ok(request) => { - let response = request.execute( - &mut linux.vm, - &mut linux.resources, - Arc::clone(&map_request), - &mut gralloc, - ); - if let Err(e) = socket.send(&response) { - error!("failed to send VmMemoryControlResponse: {}", e); + TaggedControlTube::VmMemory(tube) => { + match tube.recv::<VmMemoryRequest>() { + Ok(request) => { + let response = request.execute( + &mut linux.vm, + &mut linux.resources, + Arc::clone(&map_request), + &mut gralloc, + ); + if let Err(e) = tube.send(&response) { + error!("failed to send VmMemoryControlResponse: {}", e); + } } - } - Err(e) => { - if let MsgError::RecvZero = e { - vm_control_indices_to_remove.push(index); - } else { - error!("failed to recv VmMemoryControlRequest: {}", e); + 
Err(e) => { + if let TubeError::Disconnected = e { + vm_control_indices_to_remove.push(index); + } else { + error!("failed to recv VmMemoryControlRequest: {}", e); + } } } - }, - TaggedControlSocket::VmIrq(socket) => match socket.recv() { + } + TaggedControlTube::VmIrq(tube) => match tube.recv::<VmIrqRequest>() { Ok(request) => { let response = { let irq_chip = &mut linux.irq_chip; @@ -2990,43 +3126,45 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' &mut linux.resources, ) }; - if let Err(e) = socket.send(&response) { + if let Err(e) = tube.send(&response) { error!("failed to send VmIrqResponse: {}", e); } } Err(e) => { - if let MsgError::RecvZero = e { + if let TubeError::Disconnected = e { vm_control_indices_to_remove.push(index); } else { error!("failed to recv VmIrqRequest: {}", e); } } }, - TaggedControlSocket::VmMsync(socket) => match socket.recv() { - Ok(request) => { - let response = request.execute(&mut linux.vm); - if let Err(e) = socket.send(&response) { - error!("failed to send VmMsyncResponse: {}", e); + TaggedControlTube::VmMsync(tube) => { + match tube.recv::<VmMsyncRequest>() { + Ok(request) => { + let response = request.execute(&mut linux.vm); + if let Err(e) = tube.send(&response) { + error!("failed to send VmMsyncResponse: {}", e); + } } - } - Err(e) => { - if let MsgError::BadRecvSize { actual: 0, .. 
} = e { - vm_control_indices_to_remove.push(index); - } else { - error!("failed to recv VmMsyncRequest: {}", e); + Err(e) => { + if let TubeError::Disconnected = e { + vm_control_indices_to_remove.push(index); + } else { + error!("failed to recv VmMsyncRequest: {}", e); + } } } - }, - TaggedControlSocket::Fs(socket) => match socket.recv() { + } + TaggedControlTube::Fs(tube) => match tube.recv::<FsMappingRequest>() { Ok(request) => { let response = request.execute(&mut linux.vm, &mut linux.resources); - if let Err(e) = socket.send(&response) { + if let Err(e) = tube.send(&response) { error!("failed to send VmResponse: {}", e); } } Err(e) => { - if let MsgError::BadRecvSize { actual: 0, .. } = e { + if let TubeError::Disconnected = e { vm_control_indices_to_remove.push(index); } else { error!("failed to recv VmResponse: {}", e); @@ -3050,15 +3188,14 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' Token::VmControlServer => {} Token::VmControl { index } => { // It's possible more data is readable and buffered while the socket is hungup, - // so don't delete the socket from the poll context until we're sure all the + // so don't delete the tube from the poll context until we're sure all the // data is read. - match control_sockets + if control_tubes .get(index) - .map(|s| s.as_ref().get_readable_bytes()) + .map(|s| !s.as_ref().is_packet_ready()) + .unwrap_or(false) { - Some(Ok(0)) | Some(Err(_)) => vm_control_indices_to_remove.push(index), - Some(Ok(x)) => info!("control index {} has {} bytes readable", index, x), - _ => {} + vm_control_indices_to_remove.push(index); } } } @@ -3077,7 +3214,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' // now belongs to a different socket, the control loop will start to interact with // sockets that might not be ready to use. This can cause incorrect hangup detection or // blocking on a socket that will never be ready. 
See also: crbug.com/1019986 - if let Some(socket) = control_sockets.get(index) { + if let Some(socket) = control_tubes.get(index) { wait_ctx.delete(socket).map_err(Error::WaitContextDelete)?; } @@ -3085,10 +3222,10 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + ' // `swap_remove`. After this line, the socket at `index` is not the one from // `vm_control_indices_to_remove`. Because of this socket's change in index, we need to // use `wait_ctx.modify` to change the associated index in its `Token::VmControl`. - control_sockets.swap_remove(index); - if let Some(socket) = control_sockets.get(index) { + control_tubes.swap_remove(index); + if let Some(tube) = control_tubes.get(index) { wait_ctx - .modify(socket, EventType::Read, Token::VmControl { index }) + .modify(tube, EventType::Read, Token::VmControl { index }) .map_err(Error::WaitContextAdd)?; } } diff --git a/src/main.rs b/src/main.rs index 4fedde676..e86845a65 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,10 +8,8 @@ pub mod panic_hook; use std::collections::BTreeMap; use std::default::Default; -use std::fmt; use std::fs::{File, OpenOptions}; use std::io::{BufRead, BufReader}; -use std::num::ParseIntError; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::string::String; @@ -22,15 +20,13 @@ use arch::{ set_default_serial_parameters, Pstore, SerialHardware, SerialParameters, SerialType, VcpuAffinity, }; -use base::{ - debug, error, getpid, info, kill_process_group, net::UnixSeqpacket, reap_child, syslog, - validate_raw_descriptor, warn, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, - SafeDescriptor, -}; +use base::{debug, error, getpid, info, kill_process_group, reap_child, syslog, warn}; +#[cfg(feature = "direct")] +use crosvm::DirectIoOption; use crosvm::{ argument::{self, print_help, set_arguments, Argument}, platform, BindMount, Config, DiskOption, Executable, GidMap, SharedDir, TouchDeviceOption, - DISK_ID_LEN, + VhostUserFsOption, VhostUserOption, 
DISK_ID_LEN, }; #[cfg(feature = "gpu")] use devices::virtio::gpu::{GpuMode, GpuParameters}; @@ -38,11 +34,12 @@ use devices::ProtectionType; #[cfg(feature = "audio")] use devices::{Ac97Backend, Ac97Parameters}; use disk::QcowFile; -use msg_socket::{MsgReceiver, MsgSender, MsgSocket}; use vm_control::{ - BalloonControlCommand, BatControlCommand, BatControlResult, BatteryType, DiskControlCommand, - MaybeOwnedDescriptor, UsbControlCommand, UsbControlResult, VmControlRequestSocket, VmRequest, - VmResponse, USB_CONTROL_MAX_PORTS, + client::{ + do_modify_battery, do_usb_attach, do_usb_detach, do_usb_list, handle_request, vms_request, + ModifyUsbError, ModifyUsbResult, + }, + BalloonControlCommand, BatteryType, DiskControlCommand, UsbControlResult, VmRequest, }; fn executable_is_plugin(executable: &Option<Executable>) -> bool { @@ -175,12 +172,12 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { for (k, v) in opts { match k { // Deprecated: Specifying --gpu=<mode> Not great as the mode can be set multiple - // times if the user specifies several modes (--gpu=2d,3d,gfxstream) + // times if the user specifies several modes (--gpu=2d,virglrenderer,gfxstream) "2d" | "2D" => { gpu_params.mode = GpuMode::Mode2D; } - "3d" | "3D" => { - gpu_params.mode = GpuMode::Mode3D; + "3d" | "3D" | "virglrenderer" => { + gpu_params.mode = GpuMode::ModeVirglRenderer; } #[cfg(feature = "gfxstream")] "gfxstream" => { @@ -191,8 +188,8 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { "2d" | "2D" => { gpu_params.mode = GpuMode::Mode2D; } - "3d" | "3D" => { - gpu_params.mode = GpuMode::Mode3D; + "3d" | "3D" | "virglrenderer" => { + gpu_params.mode = GpuMode::ModeVirglRenderer; } #[cfg(feature = "gfxstream")] "gfxstream" => { @@ -202,7 +199,7 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { return Err(argument::Error::InvalidValue { value: v.to_string(), expected: String::from( - "gpu parameter 'backend' should be 
one of (2d|3d|gfxstream)", + "gpu parameter 'backend' should be one of (2d|virglrenderer|gfxstream)", ), }); } @@ -303,15 +300,17 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { } } } - #[cfg(feature = "gfxstream")] "vulkan" => { - vulkan_specified = true; + #[cfg(feature = "gfxstream")] + { + vulkan_specified = true; + } match v { "true" | "" => { - gpu_params.gfxstream_support_vulkan = true; + gpu_params.use_vulkan = true; } "false" => { - gpu_params.gfxstream_support_vulkan = false; + gpu_params.use_vulkan = false; } _ => { return Err(argument::Error::InvalidValue { @@ -345,6 +344,20 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { } "cache-path" => gpu_params.cache_path = Some(v.to_string()), "cache-size" => gpu_params.cache_size = Some(v.to_string()), + "udmabuf" => match v { + "true" | "" => { + gpu_params.udmabuf = true; + } + "false" => { + gpu_params.udmabuf = false; + } + _ => { + return Err(argument::Error::InvalidValue { + value: v.to_string(), + expected: String::from("gpu parameter 'udmabuf' should be a boolean"), + }); + } + }, "" => {} _ => { return Err(argument::Error::UnknownArgument(format!( @@ -358,12 +371,16 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> { #[cfg(feature = "gfxstream")] { - if vulkan_specified || syncfd_specified || angle_specified { + if !vulkan_specified && gpu_params.mode == GpuMode::ModeGfxstream { + gpu_params.use_vulkan = true; + } + + if syncfd_specified || angle_specified { match gpu_params.mode { GpuMode::ModeGfxstream => {} _ => { return Err(argument::Error::UnknownArgument( - "gpu parameter vulkan and syncfd are only supported for gfxstream backend" + "gpu parameter syncfd and angle are only supported for gfxstream backend" .to_string(), )); } @@ -398,7 +415,15 @@ fn parse_ac97_options(s: &str) -> argument::Result<Ac97Parameters> { argument::Error::Syntax(format!("invalid capture option: {}", e)) })?; } - #[cfg(target_os = 
"linux")] + "client_type" => { + ac97_params + .set_client_type(v) + .map_err(|e| argument::Error::InvalidValue { + value: v.to_string(), + expected: e.to_string(), + })?; + } + #[cfg(any(target_os = "linux", target_os = "android"))] "server" => { ac97_params.vios_server_path = Some( @@ -418,7 +443,7 @@ fn parse_ac97_options(s: &str) -> argument::Result<Ac97Parameters> { } // server is required for and exclusive to vios backend - #[cfg(target_os = "linux")] + #[cfg(any(target_os = "linux", target_os = "android"))] match ac97_params.backend { Ac97Backend::VIOS => { if ac97_params.vios_server_path.is_none() { @@ -658,6 +683,64 @@ fn parse_battery_options(s: Option<&str>) -> argument::Result<BatteryType> { Ok(battery_type) } +#[cfg(feature = "direct")] +fn parse_direct_io_options(s: Option<&str>) -> argument::Result<DirectIoOption> { + let s = s.ok_or(argument::Error::ExpectedValue(String::from( + "expected path@range[,range] value", + )))?; + let parts: Vec<&str> = s.splitn(2, '@').collect(); + if parts.len() != 2 { + return Err(argument::Error::InvalidValue { + value: s.to_string(), + expected: String::from("missing port range, use /path@X-Y,Z,.. 
syntax"), + }); + } + let path = PathBuf::from(parts[0]); + if !path.exists() { + return Err(argument::Error::InvalidValue { + value: parts[0].to_owned(), + expected: String::from("the path does not exist"), + }); + }; + let ranges: argument::Result<Vec<(u64, u64)>> = parts[1] + .split(',') + .map(|frag| frag.split('-')) + .map(|mut range| { + let base = range + .next() + .map(|v| v.parse::<u64>()) + .map_or(Ok(None), |r| r.map(Some)); + let last = range + .next() + .map(|v| v.parse::<u64>()) + .map_or(Ok(None), |r| r.map(Some)); + (base, last) + }) + .map(|range| match range { + (Ok(Some(base)), Ok(None)) => Ok((base, 1)), + (Ok(Some(base)), Ok(Some(last))) => { + Ok((base, last.saturating_sub(base).saturating_add(1))) + } + (Err(e), _) => Err(argument::Error::InvalidValue { + value: e.to_string(), + expected: String::from("invalid base range value"), + }), + (_, Err(e)) => Err(argument::Error::InvalidValue { + value: e.to_string(), + expected: String::from("invalid last range value"), + }), + _ => Err(argument::Error::InvalidValue { + value: s.to_owned(), + expected: String::from("invalid range format"), + }), + }) + .collect(); + Ok(DirectIoOption { + path, + ranges: ranges?, + }) +} + fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument::Result<()> { match name { "" => { @@ -676,6 +759,39 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument:: } cfg.executable_path = Some(Executable::Kernel(kernel_path)); } + "kvm-device" => { + let kvm_device_path = PathBuf::from(value.unwrap()); + if !kvm_device_path.exists() { + return Err(argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from("this kvm device path does not exist"), + }); + } + + cfg.kvm_device_path = kvm_device_path; + } + "vhost-vsock-device" => { + let vhost_vsock_device_path = PathBuf::from(value.unwrap()); + if !vhost_vsock_device_path.exists() { + return Err(argument::Error::InvalidValue { + value: 
value.unwrap().to_owned(), + expected: String::from("this vhost-vsock device path does not exist"), + }); + } + + cfg.vhost_vsock_device_path = vhost_vsock_device_path; + } + "vhost-net-device" => { + let vhost_net_device_path = PathBuf::from(value.unwrap()); + if !vhost_net_device_path.exists() { + return Err(argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from("this vhost-vsock device path does not exist"), + }); + } + + cfg.vhost_net_device_path = vhost_net_device_path; + } "android-fstab" => { if cfg.android_fstab.is_some() && !cfg.android_fstab.as_ref().unwrap().as_os_str().is_empty() @@ -750,6 +866,9 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument:: })?, ) } + "hugepages" => { + cfg.hugepages = true; + } #[cfg(feature = "audio")] "ac97" => { let ac97_params = parse_ac97_options(value.unwrap())?; @@ -1544,6 +1663,94 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument:: * 1024 * 1024; // cfg.balloon_bias is in bytes. } + "vhost-user-blk" => cfg.vhost_user_blk.push(VhostUserOption { + socket: PathBuf::from(value.unwrap()), + }), + "vhost-user-net" => cfg.vhost_user_net.push(VhostUserOption { + socket: PathBuf::from(value.unwrap()), + }), + "vhost-user-fs" => { + // (socket:tag) + let param = value.unwrap(); + let mut components = param.split(':'); + let socket = + PathBuf::from( + components + .next() + .ok_or_else(|| argument::Error::InvalidValue { + value: param.to_owned(), + expected: String::from("missing socket path for `vhost-user-fs`"), + })?, + ); + let tag = components + .next() + .ok_or_else(|| argument::Error::InvalidValue { + value: param.to_owned(), + expected: String::from("missing tag for `vhost-user-fs`"), + })? 
+ .to_owned(); + cfg.vhost_user_fs.push(VhostUserFsOption { socket, tag }); + } + #[cfg(feature = "direct")] + "direct-pmio" => { + if cfg.direct_pmio.is_some() { + return Err(argument::Error::TooManyArguments( + "`direct_pmio` already given".to_owned(), + )); + } + cfg.direct_pmio = Some(parse_direct_io_options(value)?); + } + #[cfg(feature = "direct")] + "direct-level-irq" => { + cfg.direct_level_irq + .push( + value + .unwrap() + .parse() + .map_err(|_| argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from( + "this value for `direct-level-irq` must be an unsigned integer", + ), + })?, + ); + } + #[cfg(feature = "direct")] + "direct-edge-irq" => { + cfg.direct_edge_irq + .push( + value + .unwrap() + .parse() + .map_err(|_| argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from( + "this value for `direct-edge-irq` must be an unsigned integer", + ), + })?, + ); + } + "dmi" => { + if cfg.dmi_path.is_some() { + return Err(argument::Error::TooManyArguments( + "`dmi` already given".to_owned(), + )); + } + let dmi_path = PathBuf::from(value.unwrap()); + if !dmi_path.exists() { + return Err(argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from("the dmi path does not exist"), + }); + } + if !dmi_path.is_dir() { + return Err(argument::Error::InvalidValue { + value: value.unwrap().to_owned(), + expected: String::from("the dmi path should be directory"), + }); + } + cfg.dmi_path = Some(dmi_path); + } "help" => return Err(argument::Error::PrintHelp), _ => unreachable!(), } @@ -1603,6 +1810,9 @@ fn validate_arguments(cfg: &mut Config) -> std::result::Result<(), argument::Err fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { let arguments = &[Argument::positional("KERNEL", "bzImage of kernel to run"), + Argument::value("kvm-device", "PATH", "Path to the KVM device. 
(default /dev/kvm)"), + Argument::value("vhost-vsock-device", "PATH", "Path to the vhost-vsock device. (default /dev/vhost-vsock)"), + Argument::value("vhost-net-device", "PATH", "Path to the vhost-net device. (default /dev/vhost-net)"), Argument::value("android-fstab", "PATH", "Path to Android fstab"), Argument::short_value('i', "initrd", "PATH", "Initial ramdisk to load."), Argument::short_value('p', @@ -1618,6 +1828,7 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { "mem", "N", "Amount of guest memory in MiB. (default: 256)"), + Argument::flag("hugepages", "Advise the kernel to use Huge Pages for guest memory mappings."), Argument::short_value('r', "root", "PATH[,key=value[,key=value[,...]]", @@ -1644,13 +1855,14 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { Argument::value("net-vq-pairs", "N", "virtio net virtual queue paris. (default: 1)"), #[cfg(feature = "audio")] Argument::value("ac97", - "[backend=BACKEND,capture=true,capture_effect=EFFECT,shm-fd=FD,client-fd=FD,server-fd=FD]", + "[backend=BACKEND,capture=true,capture_effect=EFFECT,client_type=TYPE,shm-fd=FD,client-fd=FD,server-fd=FD]", "Comma separated key=value pairs for setting up Ac97 devices. Can be given more than once . Possible key values: backend=(null, cras, vios) - Where to route the audio device. If not provided, backend will default to null. `null` for /dev/null, cras for CRAS server and vios for VioS server. capture - Enable audio capture capture_effects - | separated effects to be enabled for recording. The only supported effect value now is EchoCancellation or aec. + client_type - Set specific client type for cras backend. 
server - The to the VIOS server (unix socket)."), Argument::value("serial", "type=TYPE,[hardware=HW,num=NUM,path=PATH,input=PATH,console,earlycon,stdin]", @@ -1712,15 +1924,15 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa "[width=INT,height=INT]", "(EXPERIMENTAL) Comma separated key=value pairs for setting up a virtio-gpu device Possible key values: - backend=(2d|3d|gfxstream) - Which backend to use for virtio-gpu (determining rendering protocol) + backend=(2d|virglrenderer|gfxstream) - Which backend to use for virtio-gpu (determining rendering protocol) width=INT - The width of the virtual display connected to the virtio-gpu. height=INT - The height of the virtual display connected to the virtio-gpu. - egl[=true|=false] - If the virtio-gpu backend should use a EGL context for rendering. - glx[=true|=false] - If the virtio-gpu backend should use a GLX context for rendering. - surfaceless[=true|=false] - If the virtio-gpu backend should use a surfaceless context for rendering. - angle[=true|=false] - If the guest is using ANGLE (OpenGL on Vulkan) as its native OpenGL driver. + egl[=true|=false] - If the backend should use a EGL context for rendering. + glx[=true|=false] - If the backend should use a GLX context for rendering. + surfaceless[=true|=false] - If the backend should use a surfaceless context for rendering. + angle[=true|=false] - If the gfxstream backend should use ANGLE (OpenGL on Vulkan) as its native OpenGL driver. 
syncfd[=true|=false] - If the gfxstream backend should support EGL_ANDROID_native_fence_sync - vulkan[=true|=false] - If the gfxstream backend should support vulkan + vulkan[=true|=false] - If the backend should support vulkan "), #[cfg(feature = "tpm")] Argument::flag("software-tpm", "enable a software emulated trusted platform module device"), @@ -1749,6 +1961,17 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa "), Argument::value("gdb", "PORT", "(EXPERIMENTAL) gdb on the given port"), Argument::value("balloon_bias_mib", "N", "Amount to bias balance of memory between host and guest as the balloon inflates, in MiB."), + Argument::value("vhost-user-blk", "SOCKET_PATH", "Path to a socket for vhost-user block"), + Argument::value("vhost-user-net", "SOCKET_PATH", "Path to a socket for vhost-user net"), + Argument::value("vhost-user-fs", "SOCKET_PATH:TAG", + "Path to a socket path for vhost-user fs, and tag for the shared dir"), + #[cfg(feature = "direct")] + Argument::value("direct-pmio", "PATH@RANGE[,RANGE[,...]]", "Path and ranges for direct port I/O access"), + #[cfg(feature = "direct")] + Argument::value("direct-level-irq", "irq", "Enable interrupt passthrough"), + #[cfg(feature = "direct")] + Argument::value("direct-edge-irq", "irq", "Enable interrupt passthrough"), + Argument::value("dmi", "DIR", "Directory with smbios_entry_point/DMI files"), Argument::short_flag('h', "help", "Print help message.")]; let mut cfg = Config::default(); @@ -1792,76 +2015,37 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa } } -fn handle_request( - request: &VmRequest, - args: std::env::Args, -) -> std::result::Result<VmResponse, ()> { - let mut return_result = Err(()); - for socket_path in args { - match UnixSeqpacket::connect(&socket_path) { - Ok(s) => { - let socket: VmControlRequestSocket = MsgSocket::new(s); - if let Err(e) = socket.send(request) { - error!( - "failed to send request to socket at '{}': {}", - 
socket_path, e - ); - return_result = Err(()); - continue; - } - match socket.recv() { - Ok(response) => return_result = Ok(response), - Err(e) => { - error!( - "failed to send request to socket at2 '{}': {}", - socket_path, e - ); - return_result = Err(()); - continue; - } - } - } - Err(e) => { - error!("failed to connect to socket at '{}': {}", socket_path, e); - return_result = Err(()); - } - } - } - - return_result -} - -fn vms_request(request: &VmRequest, args: std::env::Args) -> std::result::Result<(), ()> { - let response = handle_request(request, args)?; - info!("request response was {}", response); - Ok(()) -} - -fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { +fn stop_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { if args.len() == 0 { print_help("crosvm stop", "VM_SOCKET...", &[]); println!("Stops the crosvm instance listening on each `VM_SOCKET` given."); return Err(()); } - vms_request(&VmRequest::Exit, args) + let socket_path = &args.next().unwrap(); + let socket_path = Path::new(&socket_path); + vms_request(&VmRequest::Exit, socket_path) } -fn suspend_vms(args: std::env::Args) -> std::result::Result<(), ()> { +fn suspend_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { if args.len() == 0 { print_help("crosvm suspend", "VM_SOCKET...", &[]); println!("Suspends the crosvm instance listening on each `VM_SOCKET` given."); return Err(()); } - vms_request(&VmRequest::Suspend, args) + let socket_path = &args.next().unwrap(); + let socket_path = Path::new(&socket_path); + vms_request(&VmRequest::Suspend, socket_path) } -fn resume_vms(args: std::env::Args) -> std::result::Result<(), ()> { +fn resume_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { if args.len() == 0 { print_help("crosvm resume", "VM_SOCKET...", &[]); println!("Resumes the crosvm instance listening on each `VM_SOCKET` given."); return Err(()); } - vms_request(&VmRequest::Resume, args) + let socket_path = &args.next().unwrap(); + let 
socket_path = Path::new(&socket_path); + vms_request(&VmRequest::Resume, socket_path) } fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { @@ -1879,10 +2063,12 @@ fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { }; let command = BalloonControlCommand::Adjust { num_bytes }; - vms_request(&VmRequest::BalloonCommand(command), args) + let socket_path = &args.next().unwrap(); + let socket_path = Path::new(&socket_path); + vms_request(&VmRequest::BalloonCommand(command), socket_path) } -fn balloon_stats(args: std::env::Args) -> std::result::Result<(), ()> { +fn balloon_stats(mut args: std::env::Args) -> std::result::Result<(), ()> { if args.len() != 1 { print_help("crosvm balloon_stats", "VM_SOCKET", &[]); println!("Prints virtio balloon statistics for a `VM_SOCKET`."); @@ -1890,7 +2076,9 @@ fn balloon_stats(args: std::env::Args) -> std::result::Result<(), ()> { } let command = BalloonControlCommand::Stats {}; let request = &VmRequest::BalloonCommand(command); - let response = handle_request(request, args)?; + let socket_path = &args.next().unwrap(); + let socket_path = Path::new(&socket_path); + let response = handle_request(request, socket_path)?; println!("{}", response); Ok(()) } @@ -2013,47 +2201,11 @@ fn disk_cmd(mut args: std::env::Args) -> std::result::Result<(), ()> { } }; - vms_request(&request, args) -} - -enum ModifyUsbError { - ArgMissing(&'static str), - ArgParse(&'static str, String), - ArgParseInt(&'static str, String, ParseIntError), - FailedDescriptorValidate(base::Error), - PathDoesNotExist(PathBuf), - SocketFailed, - UnexpectedResponse(VmResponse), - UnknownCommand(String), - UsbControl(UsbControlResult), -} - -impl fmt::Display for ModifyUsbError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use self::ModifyUsbError::*; - - match self { - ArgMissing(a) => write!(f, "argument missing: {}", a), - ArgParse(name, value) => { - write!(f, "failed to parse argument {} value `{}`", name, value) - 
} - ArgParseInt(name, value, e) => write!( - f, - "failed to parse integer argument {} value `{}`: {}", - name, value, e - ), - FailedDescriptorValidate(e) => write!(f, "failed to validate file descriptor: {}", e), - PathDoesNotExist(p) => write!(f, "path `{}` does not exist", p.display()), - SocketFailed => write!(f, "socket failed"), - UnexpectedResponse(r) => write!(f, "unexpected response: {}", r), - UnknownCommand(c) => write!(f, "unknown command: `{}`", c), - UsbControl(e) => write!(f, "{}", e), - } - } + let socket_path = &args.next().unwrap(); + let socket_path = Path::new(&socket_path); + vms_request(&request, socket_path) } -type ModifyUsbResult<T> = std::result::Result<T, ModifyUsbError>; - fn parse_bus_id_addr(v: &str) -> ModifyUsbResult<(u8, u8, u16, u16)> { debug!("parse_bus_id_addr: {}", v); let mut ids = v.split(':'); @@ -2078,27 +2230,6 @@ fn parse_bus_id_addr(v: &str) -> ModifyUsbResult<(u8, u8, u16, u16)> { } } -fn raw_descriptor_from_path(path: &Path) -> ModifyUsbResult<RawDescriptor> { - if !path.exists() { - return Err(ModifyUsbError::PathDoesNotExist(path.to_owned())); - } - let raw_descriptor = path - .file_name() - .and_then(|fd_osstr| fd_osstr.to_str()) - .map_or( - Err(ModifyUsbError::ArgParse( - "USB_DEVICE_PATH", - path.to_string_lossy().into_owned(), - )), - |fd_str| { - fd_str.parse::<libc::c_int>().map_err(|e| { - ModifyUsbError::ArgParseInt("USB_DEVICE_PATH", fd_str.to_owned(), e) - }) - }, - )?; - validate_raw_descriptor(raw_descriptor).map_err(ModifyUsbError::FailedDescriptorValidate) -} - fn usb_attach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { let val = args .next() @@ -2108,39 +2239,13 @@ fn usb_attach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { args.next() .ok_or(ModifyUsbError::ArgMissing("usb device path"))?, ); - let usb_file: Option<File> = if dev_path == Path::new("-") { - None - } else if dev_path.parent() == Some(Path::new("/proc/self/fd")) { - // Special case '/proc/self/fd/*' 
paths. The FD is already open, just use it. - // Safe because we will validate |raw_fd|. - Some(unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&dev_path)?) }) - } else { - Some( - OpenOptions::new() - .read(true) - .write(true) - .open(&dev_path) - .map_err(|_| ModifyUsbError::UsbControl(UsbControlResult::FailedToOpenDevice))?, - ) - }; - let request = VmRequest::UsbCommand(UsbControlCommand::AttachDevice { - bus, - addr, - vid, - pid, - // Safe because we are transferring ownership to the rawdescriptor - descriptor: usb_file.map(|file| { - MaybeOwnedDescriptor::Owned(unsafe { - SafeDescriptor::from_raw_descriptor(file.into_raw_descriptor()) - }) - }), - }); - let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?; - match response { - VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), - r => Err(ModifyUsbError::UnexpectedResponse(r)), - } + let socket_path = args + .next() + .ok_or(ModifyUsbError::ArgMissing("control socket path"))?; + let socket_path = Path::new(&socket_path); + + do_usb_attach(&socket_path, bus, addr, vid, pid, &dev_path) } fn usb_detach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { @@ -2150,25 +2255,19 @@ fn usb_detach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { p.parse::<u8>() .map_err(|e| ModifyUsbError::ArgParseInt("PORT", p.to_owned(), e)) })?; - let request = VmRequest::UsbCommand(UsbControlCommand::DetachDevice { port }); - let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?; - match response { - VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), - r => Err(ModifyUsbError::UnexpectedResponse(r)), - } + let socket_path = args + .next() + .ok_or(ModifyUsbError::ArgMissing("control socket path"))?; + let socket_path = Path::new(&socket_path); + do_usb_detach(&socket_path, port) } -fn usb_list(args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { - let mut ports: [u8; USB_CONTROL_MAX_PORTS] = Default::default(); - 
for (index, port) in ports.iter_mut().enumerate() { - *port = index as u8 - } - let request = VmRequest::UsbCommand(UsbControlCommand::ListDevice { ports }); - let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?; - match response { - VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), - r => Err(ModifyUsbError::UnexpectedResponse(r)), - } +fn usb_list(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> { + let socket_path = args + .next() + .ok_or(ModifyUsbError::ArgMissing("control socket path"))?; + let socket_path = Path::new(&socket_path); + do_usb_list(&socket_path) } fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> { @@ -2179,7 +2278,7 @@ fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> { } // This unwrap will not panic because of the above length check. - let command = args.next().unwrap(); + let command = &args.next().unwrap(); let result = match command.as_ref() { "attach" => usb_attach(args), "detach" => usb_detach(args), @@ -2199,16 +2298,22 @@ fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> { } fn print_usage() { - print_help("crosvm", "[stop|run]", &[]); + print_help("crosvm", "[command]", &[]); println!("Commands:"); - println!(" stop - Stops crosvm instances via their control sockets."); - println!(" run - Start a new crosvm instance."); + println!(" balloon - Set balloon size of the crosvm instance."); + println!(" balloon_stats - Prints virtio balloon statistics."); + println!(" battery - Modify battery."); println!(" create_qcow2 - Create a new qcow2 disk image file."); println!(" disk - Manage attached virtual disk devices."); + println!(" resume - Resumes the crosvm instance."); + println!(" run - Start a new crosvm instance."); + println!(" stop - Stops crosvm instances via their control sockets."); + println!(" suspend - Suspends the crosvm instance."); println!(" usb - Manage attached virtual USB devices."); println!(" version - 
Show package version."); } +#[allow(clippy::unnecessary_wraps)] fn pkg_version() -> std::result::Result<(), ()> { const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION"); const PKG_VERSION: Option<&'static str> = option_env!("PKG_VERSION"); @@ -2221,20 +2326,6 @@ fn pkg_version() -> std::result::Result<(), ()> { Ok(()) } -enum ModifyBatError { - BatControlErr(BatControlResult), -} - -impl fmt::Display for ModifyBatError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - use self::ModifyBatError::*; - - match self { - BatControlErr(e) => write!(f, "{}", e), - } - } -} - fn modify_battery(mut args: std::env::Args) -> std::result::Result<(), ()> { if args.len() < 4 { print_help("crosvm battery BATTERY_TYPE ", @@ -2247,27 +2338,10 @@ fn modify_battery(mut args: std::env::Args) -> std::result::Result<(), ()> { let property = args.next().unwrap(); let target = args.next().unwrap(); - let response = match battery_type.parse::<BatteryType>() { - Ok(type_) => match BatControlCommand::new(property, target) { - Ok(cmd) => { - let request = VmRequest::BatCommand(type_, cmd); - Ok(handle_request(&request, args)?) 
- } - Err(e) => Err(ModifyBatError::BatControlErr(e)), - }, - Err(e) => Err(ModifyBatError::BatControlErr(e)), - }; + let socket_path = args.next().unwrap(); + let socket_path = Path::new(&socket_path); - match response { - Ok(response) => { - println!("{}", response); - Ok(()) - } - Err(e) => { - println!("error {}", e); - Err(()) - } - } + do_modify_battery(&socket_path, &*battery_type, &*property, &*target) } fn crosvm_main() -> std::result::Result<(), ()> { @@ -2444,6 +2518,17 @@ mod tests { #[cfg(feature = "audio")] #[test] + fn parse_ac97_client_type() { + parse_ac97_options("backend=cras,capture=true,client_type=crosvm") + .expect("parse should have succeded"); + parse_ac97_options("backend=cras,capture=true,client_type=arcvm") + .expect("parse should have succeded"); + parse_ac97_options("backend=cras,capture=true,client_type=none") + .expect_err("parse should have failed"); + } + + #[cfg(feature = "audio")] + #[test] fn parse_ac97_vios_valid() { parse_ac97_options("backend=vios,server=/path/to/server") .expect("parse should have succeded"); @@ -2717,41 +2802,54 @@ mod tests { validate_arguments(&mut config).unwrap(); assert_eq!( config.virtio_switches.unwrap(), - PathBuf::from("/dev/switches-test")); + PathBuf::from("/dev/switches-test") + ); } - #[cfg(all(feature = "gpu", feature = "gfxstream"))] + #[cfg(feature = "gpu")] + #[test] + fn parse_gpu_options_default_vulkan_support() { + assert!( + !parse_gpu_options(Some("backend=virglrenderer")) + .unwrap() + .use_vulkan + ); + + #[cfg(feature = "gfxstream")] + assert!( + parse_gpu_options(Some("backend=gfxstream")) + .unwrap() + .use_vulkan + ); + } + + #[cfg(feature = "gpu")] #[test] - fn parse_gpu_options_gfxstream_with_vulkan_specified() { + fn parse_gpu_options_with_vulkan_specified() { + assert!(parse_gpu_options(Some("vulkan=true")).unwrap().use_vulkan); assert!( - parse_gpu_options(Some("backend=gfxstream,vulkan=true")) + parse_gpu_options(Some("backend=virglrenderer,vulkan=true")) .unwrap() - 
.gfxstream_support_vulkan + .use_vulkan ); assert!( - parse_gpu_options(Some("vulkan=true,backend=gfxstream")) + parse_gpu_options(Some("vulkan=true,backend=virglrenderer")) .unwrap() - .gfxstream_support_vulkan + .use_vulkan ); + assert!(!parse_gpu_options(Some("vulkan=false")).unwrap().use_vulkan); assert!( - !parse_gpu_options(Some("backend=gfxstream,vulkan=false")) + !parse_gpu_options(Some("backend=virglrenderer,vulkan=false")) .unwrap() - .gfxstream_support_vulkan + .use_vulkan ); assert!( - !parse_gpu_options(Some("vulkan=false,backend=gfxstream")) + !parse_gpu_options(Some("vulkan=false,backend=virglrenderer")) .unwrap() - .gfxstream_support_vulkan + .use_vulkan ); - assert!(parse_gpu_options(Some("backend=gfxstream,vulkan=invalid_value")).is_err()); - assert!(parse_gpu_options(Some("vulkan=invalid_value,backend=gfxstream")).is_err()); - } - - #[cfg(all(feature = "gpu", feature = "gfxstream"))] - #[test] - fn parse_gpu_options_not_gfxstream_with_vulkan_specified() { - assert!(parse_gpu_options(Some("backend=3d,vulkan=true")).is_err()); - assert!(parse_gpu_options(Some("vulkan=true,backend=3d")).is_err()); + assert!(parse_gpu_options(Some("backend=virglrenderer,vulkan=invalid_value")).is_err()); + assert!(parse_gpu_options(Some("vulkan=invalid_value,backend=virglrenderer")).is_err()); } #[cfg(all(feature = "gpu", feature = "gfxstream"))] @@ -2784,8 +2882,8 @@ mod tests { #[cfg(all(feature = "gpu", feature = "gfxstream"))] #[test] fn parse_gpu_options_not_gfxstream_with_syncfd_specified() { - assert!(parse_gpu_options(Some("backend=3d,syncfd=true")).is_err()); - assert!(parse_gpu_options(Some("syncfd=true,backend=3d")).is_err()); + assert!(parse_gpu_options(Some("backend=virglrenderer,syncfd=true")).is_err()); + assert!(parse_gpu_options(Some("syncfd=true,backend=virglrenderer")).is_err()); } #[test] diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs index f900bf599..8236425c0 100644 --- a/src/plugin/mod.rs +++ b/src/plugin/mod.rs @@ -34,7 +34,7 @@ use 
base::{ use kvm::{Cap, Datamatch, IoeventAddress, Kvm, Vcpu, VcpuExit, Vm}; use minijail::{self, Minijail}; use net_util::{Error as TapError, Tap, TapT}; -use vm_memory::GuestMemory; +use vm_memory::{GuestMemory, MemoryPolicy}; use self::process::*; use self::vcpu::*; @@ -418,7 +418,7 @@ pub fn run_vcpus( if use_kvm_signals { unsafe { - extern "C" fn handle_signal() {} + extern "C" fn handle_signal(_: c_int) {} // Our signal handler does nothing and is trivially async signal safe. // We need to install this signal handler even though we do block // the signal below, to ensure that this signal will interrupt @@ -430,7 +430,7 @@ pub fn run_vcpus( block_signal(SIGRTMIN() + 0).expect("failed to block signal"); } else { unsafe { - extern "C" fn handle_signal() { + extern "C" fn handle_signal(_: c_int) { Vcpu::set_local_immediate_exit(true); } register_rt_signal_handler(SIGRTMIN() + 0, handle_signal) @@ -691,7 +691,12 @@ pub fn run_config(cfg: Config) -> Result<()> { }; let vcpu_count = cfg.vcpu_count.unwrap_or(1) as u32; let mem = GuestMemory::new(&[]).unwrap(); - let kvm = Kvm::new().map_err(Error::CreateKvm)?; + let mut mem_policy = MemoryPolicy::empty(); + if cfg.hugepages { + mem_policy |= MemoryPolicy::USE_HUGEPAGES; + } + mem.set_memory_policy(mem_policy); + let kvm = Kvm::new_with_path(&cfg.kvm_device_path).map_err(Error::CreateKvm)?; let mut vm = Vm::new(&kvm, mem).map_err(Error::CreateVm)?; vm.create_irq_chip().map_err(Error::CreateIrqChip)?; vm.create_pit().map_err(Error::CreatePIT)?; diff --git a/src/plugin/process.rs b/src/plugin/process.rs index 7c61c0908..c34b708bb 100644 --- a/src/plugin/process.rs +++ b/src/plugin/process.rs @@ -365,7 +365,7 @@ impl Process { _ => {} } let mem = MemoryMappingBuilder::new(length as usize) - .from_descriptor(&shm) + .from_shared_memory(&shm) .offset(offset) .build() .map_err(mmap_to_sys_err)?; diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml index 99b994cdb..7edffd9f5 100644 --- a/sys_util/Cargo.toml +++ 
b/sys_util/Cargo.toml @@ -9,8 +9,9 @@ include = ["src/**/*", "Cargo.toml"] data_model = { path = "../data_model" } # provided by ebuild libc = "*" poll_token_derive = { version = "*", path = "poll_token_derive" } +serde = { version = "1", features = [ "derive" ] } +serde_json = "1" sync = { path = "../sync" } # provided by ebuild -syscall_defines = { path = "../syscall_defines" } # provided by ebuild tempfile = { path = "../tempfile" } # provided by ebuild [target.'cfg(target_os = "android")'.dependencies] diff --git a/sys_util/src/descriptor.rs b/sys_util/src/descriptor.rs index 5ee075a42..325b2b450 100644 --- a/sys_util/src/descriptor.rs +++ b/sys_util/src/descriptor.rs @@ -2,6 +2,7 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::convert::TryFrom; use std::fs::File; use std::io::{Stderr, Stdin, Stdout}; use std::mem; @@ -10,6 +11,8 @@ use std::ops::Drop; use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixStream}; +use serde::{Deserialize, Serialize}; + use crate::net::UnlinkUnixSeqpacketListener; use crate::{errno_result, PollToken, Result}; @@ -33,9 +36,24 @@ pub trait FromRawDescriptor { unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self; } +/// Clones `fd`, returning a new file descriptor that refers to the same open file description as +/// `fd`. The cloned fd will have the `FD_CLOEXEC` flag set but will not share any other file +/// descriptor flags with `fd`. +pub fn clone_fd(fd: &dyn AsRawFd) -> Result<RawFd> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_DUPFD_CLOEXEC, 0) }; + if ret < 0 { + errno_result() + } else { + Ok(ret) + } +} + /// Wraps a RawDescriptor and safely closes it when self falls out of scope. 
-#[derive(Debug, Eq)] +#[derive(Serialize, Deserialize, Debug, Eq)] +#[serde(transparent)] pub struct SafeDescriptor { + #[serde(with = "crate::with_raw_descriptor")] descriptor: RawDescriptor, } @@ -98,18 +116,41 @@ impl AsRawFd for SafeDescriptor { } } +impl TryFrom<&dyn AsRawFd> for SafeDescriptor { + type Error = std::io::Error; + + fn try_from(fd: &dyn AsRawFd) -> std::result::Result<Self, Self::Error> { + Ok(SafeDescriptor { + descriptor: clone_fd(fd)?, + }) + } +} + impl SafeDescriptor { /// Clones this descriptor, internally creating a new descriptor. The new SafeDescriptor will /// share the same underlying count within the kernel. pub fn try_clone(&self) -> Result<SafeDescriptor> { - // Safe because self.as_raw_descriptor() returns a valid value - let copy_fd = unsafe { libc::dup(self.as_raw_descriptor()) }; - if copy_fd < 0 { - return errno_result(); + // Safe because this doesn't modify any memory and we check the return value. + let descriptor = unsafe { libc::fcntl(self.descriptor, libc::F_DUPFD_CLOEXEC, 0) }; + if descriptor < 0 { + errno_result() + } else { + Ok(SafeDescriptor { descriptor }) } - // Safe becuase we just successfully duplicated and this object will uniquely - // own the raw descriptor. - Ok(unsafe { SafeDescriptor::from_raw_descriptor(copy_fd) }) + } +} + +impl From<SafeDescriptor> for File { + fn from(s: SafeDescriptor) -> File { + // Safe because we own the SafeDescriptor at this point. + unsafe { File::from_raw_fd(s.into_raw_descriptor()) } + } +} + +impl From<File> for SafeDescriptor { + fn from(f: File) -> SafeDescriptor { + // Safe because we own the File at this point. + unsafe { SafeDescriptor::from_raw_descriptor(f.into_raw_descriptor()) } } } diff --git a/sys_util/src/descriptor_reflection.rs b/sys_util/src/descriptor_reflection.rs new file mode 100644 index 000000000..9b18ffc37 --- /dev/null +++ b/sys_util/src/descriptor_reflection.rs @@ -0,0 +1,541 @@ +// Copyright 2020 The Chromium OS Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Provides infrastructure for de/serializing descriptors embedded in Rust data structures. +//! +//! # Example +//! +//! ``` +//! use serde_json::to_string; +//! use sys_util::{ +//! FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors, +//! deserialize_with_descriptors, +//! }; +//! use tempfile::tempfile; +//! +//! let tmp_f = tempfile().unwrap(); +//! +//! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File. +//! let data = FileSerdeWrapper(tmp_f); +//! +//! // Wraps Serialize types to collect side channel descriptors as Serialize is called. +//! let data_wrapper = SerializeDescriptors::new(&data); +//! +//! // Use the wrapper with any serializer to serialize data is normal, grabbing descriptors +//! // as the data structures are serialized by the serializer. +//! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); +//! +//! // If data_wrapper contains any side channel descriptor refs +//! // (it contains tmp_f in this case), we can retrieve the actual descriptors +//! // from the side channel using into_descriptors(). +//! let out_descriptors = data_wrapper.into_descriptors(); +//! +//! // When sending out_json over some transport, also send out_descriptors. +//! +//! // For this example, we aren't really transporting data across the process, but we do need to +//! // convert the descriptor type. +//! let mut safe_descriptors = out_descriptors +//! .iter() +//! .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) })) +//! .collect(); +//! std::mem::forget(data); // Prevent double drop of tmp_f. +//! +//! // The deserialize_with_descriptors function is used give the descriptor deserializers access +//! // to side channel descriptors. +//! let res: FileSerdeWrapper = +//! 
deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors) +//! .expect("failed to deserialize"); +//! ``` + +use std::cell::{Cell, RefCell}; +use std::convert::TryInto; +use std::fmt; +use std::fs::File; +use std::ops::{Deref, DerefMut}; +use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe}; + +use serde::de::{self, Error, Visitor}; +use serde::ser; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{RawDescriptor, SafeDescriptor}; + +thread_local! { + static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default(); +} + +/// Initializes the thread local storage for descriptor serialization. Fails if it was already +/// initialized without an intervening `take_descriptor_dst` on this thread. +fn init_descriptor_dst() -> Result<(), &'static str> { + DESCRIPTOR_DST.with(|d| { + let mut descriptors = d.borrow_mut(); + if descriptors.is_some() { + return Err( + "attempt to initialize descriptor destination that was already initialized", + ); + } + *descriptors = Some(Default::default()); + Ok(()) + }) +} + +/// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call +/// to `init_descriptor_dst` on this thread. +fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> { + match DESCRIPTOR_DST.with(|d| d.replace(None)) { + Some(d) => Ok(d), + None => Err("attempt to take descriptor destination before it was initialized"), + } +} + +/// Pushes a descriptor on the thread local destination of descriptors, returning the index in which +/// the descriptor was pushed. +// +/// Returns Err if the thread local destination was not already initialized. 
+fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> { + DESCRIPTOR_DST.with(|d| { + d.borrow_mut() + .as_mut() + .ok_or("attempt to serialize descriptor without descriptor destination") + .map(|descriptors| { + let index = descriptors.len(); + descriptors.push(rd); + index + }) + }) +} + +/// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct. +/// +/// If there is no parent `SerializeDescriptors` being serialized, this will return an error. +/// +/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = +/// "...")]` attribute which will make use of this function. +pub fn serialize_descriptor<S: Serializer>( + rd: &RawDescriptor, + se: S, +) -> std::result::Result<S::Ok, S::Error> { + let index = push_descriptor(*rd).map_err(ser::Error::custom)?; + se.serialize_u32( + index + .try_into() + .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?, + ) +} + +/// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when +/// given to an ordinary `Serializer`. +/// +/// This is the corresponding type to use for serialization before using +/// `deserialize_with_descriptors`. +/// +/// # Examples +/// +/// ``` +/// use serde_json::to_string; +/// use sys_util::{FileSerdeWrapper, SerializeDescriptors}; +/// use tempfile::tempfile; +/// +/// let tmp_f = tempfile().unwrap(); +/// let data = FileSerdeWrapper(tmp_f); +/// let data_wrapper = SerializeDescriptors::new(&data); +/// +/// // Serializes `v` as normal... +/// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize"); +/// // If `serialize_descriptor` was called, we can capture the descriptors from here. 
+/// let out_descriptors = data_wrapper.into_descriptors(); +/// ``` +pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>); + +impl<'a, T: Serialize> SerializeDescriptors<'a, T> { + pub fn new(inner: &'a T) -> Self { + Self(inner, Default::default()) + } + + pub fn into_descriptors(self) -> Vec<RawDescriptor> { + self.1.into_inner() + } +} + +impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + init_descriptor_dst().map_err(ser::Error::custom)?; + + // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to + // take_descriptor_dst afterwards. + let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer))); + self.1.set(take_descriptor_dst().unwrap()); + match res { + Ok(r) => r, + Err(e) => resume_unwind(e), + } + } +} + +thread_local! { + static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default(); +} + +/// Sets the thread local storage of descriptors for deserialization. Fails if this was already +/// called without a call to `take_descriptor_src` on this thread. +/// +/// This is given as a collection of `Option` so that unused descriptors can be returned. +fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> { + DESCRIPTOR_SRC.with(|d| { + let mut src = d.borrow_mut(); + if src.is_some() { + return Err("attempt to set descriptor source that was already set"); + } + *src = Some(descriptors); + Ok(()) + }) +} + +/// Takes the thread local storage of descriptors for deserialization. Fails if the storage was +/// already taken or never set with `set_descriptor_src`. +/// +/// If deserialization was done, the descriptors will mostly come back as `None` unless some of them +/// were unused. 
+fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> { + DESCRIPTOR_SRC.with(|d| { + d.replace(None) + .ok_or("attempt to take descriptor source which was never set") + }) +} + +/// Takes a descriptor at the given index from the thread local source of descriptors. +// +/// Returns None if the thread local source was not already initialized. +fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> { + DESCRIPTOR_SRC.with(|d| { + d.borrow_mut() + .as_mut() + .ok_or("attempt to deserialize descriptor without descriptor source")? + .get_mut(index) + .ok_or("attempt to deserialize out of bounds descriptor")? + .take() + .ok_or("attempt to deserialize descriptor that was already taken") + }) +} + +/// Deserializes a descriptor provided via `deserialize_with_descriptors`. +/// +/// If `deserialize_with_descriptors` is not in the call chain, this will return an error. +/// +/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with = +/// "...")]` attribute which will make use of this function. 
+pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error> +where + D: Deserializer<'de>, +{ + struct DescriptorVisitor; + + impl<'de> Visitor<'de> for DescriptorVisitor { + type Value = u32; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("an integer which fits into a u32") + } + + fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> { + Ok(value as _) + } + + fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> { + Ok(value as _) + } + + fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> { + Ok(value) + } + + fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + + fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> { + value.try_into().map_err(E::custom) + } + } + + let index = de.deserialize_u32(DescriptorVisitor)? as usize; + take_descriptor(index).map_err(D::Error::custom) +} + +/// Allows the use of any serde deserializer within a closure while providing access to the a set of +/// descriptors for use in `deserialize_descriptor`. +/// +/// This is the corresponding call to use deserialize after using `SerializeDescriptors`. +/// +/// If `deserialize_with_descriptors` is called anywhere within the given closure, it return an +/// error. 
+pub fn deserialize_with_descriptors<F, T, E>( + f: F, + descriptors: &mut Vec<Option<SafeDescriptor>>, +) -> Result<T, E> +where + F: FnOnce() -> Result<T, E>, + E: de::Error, +{ + let swap_descriptors = std::mem::take(descriptors); + set_descriptor_src(swap_descriptors).map_err(E::custom)?; + + // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to + // take_descriptor_src afterwards. + let res = catch_unwind(AssertUnwindSafe(f)); + + // unwrap is used because set_descriptor_src is always called before this, so it should never + // panic. + *descriptors = take_descriptor_src().unwrap(); + + match res { + Ok(r) => r, + Err(e) => resume_unwind(e), + } +} + +/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` +/// attribute. It only works with fields with `RawDescriptor` type. +/// +/// # Examples +/// +/// ``` +/// use serde::{Deserialize, Serialize}; +/// use sys_util::RawDescriptor; +/// +/// #[derive(Serialize, Deserialize)] +/// struct RawContainer { +/// #[serde(with = "sys_util::with_raw_descriptor")] +/// rd: RawDescriptor, +/// } +/// ``` +pub mod with_raw_descriptor { + use crate::{IntoRawDescriptor, RawDescriptor}; + use serde::Deserializer; + + pub use super::serialize_descriptor as serialize; + + pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error> + where + D: Deserializer<'de>, + { + super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor) + } +} + +/// Module that exports `serialize`/`deserialize` functions for use with `#[serde(with = "...")]` +/// attribute. 
+/// +/// # Examples +/// +/// ``` +/// use std::fs::File; +/// use serde::{Deserialize, Serialize}; +/// use sys_util::RawDescriptor; +/// +/// #[derive(Serialize, Deserialize)] +/// struct FileContainer { +/// #[serde(with = "sys_util::with_as_descriptor")] +/// file: File, +/// } +/// ``` +pub mod with_as_descriptor { + use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor}; + use serde::{Deserializer, Serializer}; + + pub fn serialize<S: Serializer>( + rd: &dyn AsRawDescriptor, + se: S, + ) -> std::result::Result<S::Ok, S::Error> { + super::serialize_descriptor(&rd.as_raw_descriptor(), se) + } + + pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error> + where + D: Deserializer<'de>, + T: FromRawDescriptor, + { + super::deserialize_descriptor(de) + .map(IntoRawDescriptor::into_raw_descriptor) + .map(|rd| unsafe { T::from_raw_descriptor(rd) }) + } +} + +/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when +/// the `#[serde(with = "with_as_descriptor")]` trait is infeasible, such as for a field with type +/// `Option<File>`. 
+#[derive(Serialize, Deserialize)] +#[serde(transparent)] +pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File); + +impl fmt::Debug for FileSerdeWrapper { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From<File> for FileSerdeWrapper { + fn from(file: File) -> Self { + FileSerdeWrapper(file) + } +} + +impl From<FileSerdeWrapper> for File { + fn from(f: FileSerdeWrapper) -> File { + f.0 + } +} + +impl Deref for FileSerdeWrapper { + type Target = File; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for FileSerdeWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +#[cfg(test)] +mod tests { + use crate::{ + deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, + FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors, + }; + + use std::collections::HashMap; + use std::fs::File; + use std::mem::ManuallyDrop; + use std::os::unix::io::AsRawFd; + + use serde::{de::DeserializeOwned, Deserialize, Serialize}; + use tempfile::tempfile; + + fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T { + let mut safe_descriptors = descriptors + .iter() + .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) })) + .collect(); + + let res = + deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors) + .unwrap(); + + assert!(safe_descriptors.iter().all(|v| v.is_none())); + + res + } + + #[test] + fn raw() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct RawContainer { + #[serde(with = "with_raw_descriptor")] + rd: RawDescriptor, + } + // Specifically chosen to not overlap a real descriptor to avoid having to allocate any + // descriptors for this test. 
+ let fake_rd = 5_123_457 as _; + let v = RawContainer { rd: fake_rd }; + let v_serialize = SerializeDescriptors::new(&v); + let json = serde_json::to_string(&v_serialize).unwrap(); + let descriptors = v_serialize.into_descriptors(); + let res = deserialize(&json, &descriptors); + assert_eq!(v, res); + } + + #[test] + fn file() { + #[derive(Serialize, Deserialize)] + struct FileContainer { + #[serde(with = "with_as_descriptor")] + file: File, + } + + let v = FileContainer { + file: tempfile().unwrap(), + }; + let v_serialize = SerializeDescriptors::new(&v); + let json = serde_json::to_string(&v_serialize).unwrap(); + let descriptors = v_serialize.into_descriptors(); + let v = ManuallyDrop::new(v); + let res: FileContainer = deserialize(&json, &descriptors); + assert_eq!(v.file.as_raw_fd(), res.file.as_raw_fd()); + } + + #[test] + fn option() { + #[derive(Serialize, Deserialize)] + struct TestOption { + a: Option<FileSerdeWrapper>, + b: Option<FileSerdeWrapper>, + } + + let v = TestOption { + a: None, + b: Some(tempfile().unwrap().into()), + }; + let v_serialize = SerializeDescriptors::new(&v); + let json = serde_json::to_string(&v_serialize).unwrap(); + let descriptors = v_serialize.into_descriptors(); + let v = ManuallyDrop::new(v); + let res: TestOption = deserialize(&json, &descriptors); + assert!(res.a.is_none()); + assert!(res.b.is_some()); + assert_eq!( + v.b.as_ref().unwrap().as_raw_fd(), + res.b.unwrap().as_raw_fd() + ); + } + + #[test] + fn map() { + let mut v: HashMap<String, FileSerdeWrapper> = HashMap::new(); + v.insert("a".into(), tempfile().unwrap().into()); + v.insert("b".into(), tempfile().unwrap().into()); + v.insert("c".into(), tempfile().unwrap().into()); + let v_serialize = SerializeDescriptors::new(&v); + let json = serde_json::to_string(&v_serialize).unwrap(); + let descriptors = v_serialize.into_descriptors(); + // Prevent the files in `v` from dropping while allowing the HashMap itself to drop. 
It is + // done this way to prevent a double close of the files (which should reside in `res`) + // without triggering the leak sanitizer on `v`'s HashMap heap memory. + let v: HashMap<_, _> = v + .into_iter() + .map(|(k, v)| (k, ManuallyDrop::new(v))) + .collect(); + let res: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors); + + assert_eq!(v.len(), res.len()); + for (k, v) in v.iter() { + assert_eq!(res.get(k).unwrap().as_raw_fd(), v.as_raw_fd()); + } + } +} diff --git a/sys_util/src/errno.rs b/sys_util/src/errno.rs index 0475442fb..63af14a04 100644 --- a/sys_util/src/errno.rs +++ b/sys_util/src/errno.rs @@ -6,9 +6,12 @@ use std::fmt::{self, Display}; use std::io; use std::result; +use serde::{Deserialize, Serialize}; + /// An error number, retrieved from errno (man 3 errno), set by a libc /// function that returned an error. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)] +#[serde(transparent)] pub struct Error(i32); pub type Result<T> = result::Result<T, Error>; diff --git a/sys_util/src/eventfd.rs b/sys_util/src/eventfd.rs index 3897118e6..29fa34469 100644 --- a/sys_util/src/eventfd.rs +++ b/sys_util/src/eventfd.rs @@ -8,18 +8,20 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use std::ptr; use std::time::Duration; -use libc::{c_void, dup, eventfd, read, write, POLLIN}; +use libc::{c_void, eventfd, read, write, POLLIN}; +use serde::{Deserialize, Serialize}; use crate::{ - errno_result, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, Result, - SafeDescriptor, + duration_to_timespec, errno_result, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, + RawDescriptor, Result, SafeDescriptor, }; /// A safe wrapper around a Linux eventfd (man 2 eventfd). /// /// An eventfd is useful because it is sendable across processes and can be used for signaling in /// and out of the KVM API. They can also be polled like any other file descriptor. 
-#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(transparent)] pub struct EventFd { event_handle: SafeDescriptor, } @@ -94,12 +96,7 @@ impl EventFd { events: POLLIN, revents: 0, }; - // Safe because we are zero-initializing a struct with only primitive member fields. - let mut timeoutspec: libc::timespec = unsafe { mem::zeroed() }; - timeoutspec.tv_sec = timeout.as_secs() as libc::time_t; - // nsec always fits in i32 because subsec_nanos is defined to be less than one billion. - let nsec = timeout.subsec_nanos() as i32; - timeoutspec.tv_nsec = libc::c_long::from(nsec); + let timeoutspec: libc::timespec = duration_to_timespec(timeout); // Safe because this only modifies |pfd| and we check the return value let ret = unsafe { libc::ppoll( @@ -137,16 +134,9 @@ impl EventFd { /// Clones this EventFd, internally creating a new file descriptor. The new EventFd will share /// the same underlying count within the kernel. pub fn try_clone(&self) -> Result<EventFd> { - // This is safe because we made this fd and properly check that it returns without error. - let ret = unsafe { dup(self.as_raw_descriptor()) }; - if ret < 0 { - return errno_result(); - } - // This is safe because we checked ret for success and know the kernel gave us an fd that we - // own. 
- Ok(EventFd { - event_handle: unsafe { SafeDescriptor::from_raw_descriptor(ret) }, - }) + self.event_handle + .try_clone() + .map(|event_handle| EventFd { event_handle }) } } diff --git a/sys_util/src/fork.rs b/sys_util/src/fork.rs index a1e9b132d..7f4a25640 100644 --- a/sys_util/src/fork.rs +++ b/sys_util/src/fork.rs @@ -8,8 +8,7 @@ use std::path::Path; use std::process; use std::result; -use libc::{c_long, pid_t, syscall, CLONE_NEWPID, CLONE_NEWUSER, SIGCHLD}; -use syscall_defines::linux::LinuxSyscall::SYS_clone; +use libc::{c_long, pid_t, syscall, SYS_clone, CLONE_NEWPID, CLONE_NEWUSER, SIGCHLD}; use crate::errno_result; diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs index 87af85292..43bdcc4d1 100644 --- a/sys_util/src/lib.rs +++ b/sys_util/src/lib.rs @@ -22,6 +22,7 @@ pub mod syslog; mod capabilities; mod clock; mod descriptor; +mod descriptor_reflection; mod errno; mod eventfd; mod external_mapping; @@ -33,8 +34,11 @@ pub mod net; mod passwd; mod poll; mod priority; +pub mod rand; mod raw_fd; pub mod sched; +pub mod scoped_path; +pub mod scoped_signal_handler; mod seek_hole; mod shm; pub mod signal; @@ -43,6 +47,7 @@ mod sock_ctrl_msg; mod struct_util; mod terminal; mod timerfd; +pub mod vsock; mod write_zeroes; pub use crate::alloc::LayoutAllocation; @@ -61,6 +66,7 @@ pub use crate::poll::*; pub use crate::priority::*; pub use crate::raw_fd::*; pub use crate::sched::*; +pub use crate::scoped_signal_handler::*; pub use crate::shm::*; pub use crate::signal::*; pub use crate::signalfd::*; @@ -68,6 +74,10 @@ pub use crate::sock_ctrl_msg::*; pub use crate::struct_util::*; pub use crate::terminal::*; pub use crate::timerfd::*; +pub use descriptor_reflection::{ + deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper, + SerializeDescriptors, +}; pub use poll_token_derive::*; pub use crate::external_mapping::Error as ExternalMappingError; @@ -84,16 +94,21 @@ pub use crate::write_zeroes::{PunchHole, WriteZeroes, 
WriteZeroesAt}; use std::cell::Cell; use std::ffi::CStr; use std::fs::{remove_file, File}; +use std::mem; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::UnixDatagram; use std::ptr; +use std::time::Duration; use libc::{ - c_int, c_long, fcntl, gid_t, kill, pid_t, pipe2, syscall, sysconf, uid_t, waitpid, F_GETFL, + c_int, c_long, fcntl, pipe2, syscall, sysconf, waitpid, SYS_getpid, SYS_gettid, F_GETFL, F_SETFL, O_CLOEXEC, SIGKILL, WNOHANG, _SC_IOV_MAX, _SC_PAGESIZE, }; -use syscall_defines::linux::LinuxSyscall::SYS_getpid; +/// Re-export libc types that are part of the API. +pub type Pid = libc::pid_t; +pub type Uid = libc::uid_t; +pub type Gid = libc::gid_t; /// Used to mark types as !Sync. pub type UnsyncMarker = std::marker::PhantomData<Cell<usize>>; @@ -121,28 +136,58 @@ pub fn round_up_to_page_size(v: usize) -> usize { /// This bypasses `libc`'s caching `getpid(2)` wrapper which can be invalid if a raw clone was used /// elsewhere. #[inline(always)] -pub fn getpid() -> pid_t { +pub fn getpid() -> Pid { // Safe because this syscall can never fail and we give it a valid syscall number. - unsafe { syscall(SYS_getpid as c_long) as pid_t } + unsafe { syscall(SYS_getpid as c_long) as Pid } +} + +/// Safe wrapper for the gettid Linux systemcall. +pub fn gettid() -> Pid { + // Calling the gettid() sycall is always safe. + unsafe { syscall(SYS_gettid as c_long) as Pid } +} + +/// Safe wrapper for `getsid(2)`. +pub fn getsid(pid: Option<Pid>) -> Result<Pid> { + // Calling the getsid() sycall is always safe. + let ret = unsafe { libc::getsid(pid.unwrap_or(0)) } as Pid; + + if ret < 0 { + errno_result() + } else { + Ok(ret) + } +} + +/// Wrapper for `setsid(2)`. +pub fn setsid() -> Result<Pid> { + // Safe because the return code is checked. + let ret = unsafe { libc::setsid() as Pid }; + + if ret < 0 { + errno_result() + } else { + Ok(ret) + } } /// Safe wrapper for `geteuid(2)`. 
#[inline(always)] -pub fn geteuid() -> uid_t { +pub fn geteuid() -> Uid { // trivially safe unsafe { libc::geteuid() } } /// Safe wrapper for `getegid(2)`. #[inline(always)] -pub fn getegid() -> gid_t { +pub fn getegid() -> Gid { // trivially safe unsafe { libc::getegid() } } /// Safe wrapper for chown(2). #[inline(always)] -pub fn chown(path: &CStr, uid: uid_t, gid: gid_t) -> Result<()> { +pub fn chown(path: &CStr, uid: Uid, gid: Gid) -> Result<()> { // Safe since we pass in a valid string pointer and check the return value. let ret = unsafe { libc::chown(path.as_ptr(), uid, gid) }; @@ -256,7 +301,7 @@ pub fn fallocate( /// } /// } /// ``` -pub fn reap_child() -> Result<pid_t> { +pub fn reap_child() -> Result<Pid> { // Safe because we pass in no memory, prevent blocking with WNOHANG, and check for error. let ret = unsafe { waitpid(-1, ptr::null_mut(), WNOHANG) }; if ret == -1 { @@ -271,13 +316,9 @@ pub fn reap_child() -> Result<pid_t> { /// On success, this kills all processes in the current process group, including the current /// process, meaning this will not return. This is equivalent to a call to `kill(0, SIGKILL)`. pub fn kill_process_group() -> Result<()> { - let ret = unsafe { kill(0, SIGKILL) }; - if ret == -1 { - errno_result() - } else { - // Kill succeeded, so this process never reaches here. - unreachable!(); - } + unsafe { kill(0, SIGKILL) }?; + // Kill succeeded, so this process never reaches here. + unreachable!(); } /// Spawns a pipe pair where the first pipe is the read end and the second pipe is the write end. @@ -435,6 +476,23 @@ pub fn clear_fd_flags(fd: RawFd, clear_flags: c_int) -> Result<()> { set_fd_flags(fd, start_flags & !clear_flags) } +/// Return a timespec filed with the specified Duration `duration`. +pub fn duration_to_timespec(duration: Duration) -> libc::timespec { + // Safe because we are zero-initializing a struct with only primitive member fields. 
+ let mut ts: libc::timespec = unsafe { mem::zeroed() }; + + ts.tv_sec = duration.as_secs() as libc::time_t; + // nsec always fits in i32 because subsec_nanos is defined to be less than one billion. + let nsec = duration.subsec_nanos() as i32; + ts.tv_nsec = libc::c_long::from(nsec); + ts +} + +/// Return the maximum Duration that can be used with libc::timespec. +pub fn max_timeout() -> Duration { + Duration::new(libc::time_t::max_value() as u64, 999999999) +} + #[cfg(test)] mod tests { use std::io::Write; diff --git a/sys_util/src/linux/syslog.rs b/sys_util/src/linux/syslog.rs index 179e2bf53..1ec876304 100644 --- a/sys_util/src/linux/syslog.rs +++ b/sys_util/src/linux/syslog.rs @@ -148,7 +148,7 @@ fn send_buf(socket: &UnixDatagram, buf: &[u8]) { const SEND_RETRY: usize = 2; for _ in 0..SEND_RETRY { - match socket.send(&buf[..]) { + match socket.send(buf) { Ok(_) => break, Err(e) => match e.kind() { ErrorKind::ConnectionRefused diff --git a/sys_util/src/mmap.rs b/sys_util/src/mmap.rs index e7f1239d3..28ad07f57 100644 --- a/sys_util/src/mmap.rs +++ b/sys_util/src/mmap.rs @@ -59,7 +59,7 @@ impl Display for Error { "requested memory range spans past the end of the region: offset={} count={} region_size={}", offset, count, region_size, ), - SystemCallFailed(e) => write!(f, "mmap system call failed: {}", e), + SystemCallFailed(e) => write!(f, "mmap related system call failed: {}", e), ReadToMemory(e) => write!(f, "failed to read from file to memory: {}", e), RemoveMappingIsUnsupported => write!(f, "`remove_mapping` is unsupported"), WriteFromMemory(e) => write!(f, "failed to write from memory to file: {}", e), @@ -108,9 +108,9 @@ impl From<c_int> for Protection { } } -impl Into<c_int> for Protection { - fn into(self) -> c_int { - self.0 +impl From<Protection> for c_int { + fn from(p: Protection) -> c_int { + p.0 } } @@ -411,6 +411,32 @@ impl MemoryMapping { }) } + /// Madvise the kernel to use Huge Pages for this mapping. 
+ pub fn use_hugepages(&self) -> Result<()> { + const SZ_2M: usize = 2 * 1024 * 1024; + + // THP uses 2M pages, so use THP only on mappings that are at least + // 2M in size. + if self.size() < SZ_2M { + return Ok(()); + } + + // This is safe because we call madvise with a valid address and size, and we check the + // return value. + let ret = unsafe { + libc::madvise( + self.as_ptr() as *mut libc::c_void, + self.size(), + libc::MADV_HUGEPAGE, + ) + }; + if ret == -1 { + Err(Error::SystemCallFailed(errno::Error::last())) + } else { + Ok(()) + } + } + /// Calls msync with MS_SYNC on the mapping. pub fn msync(&self) -> Result<()> { // This is safe since we use the exact address and length of a known @@ -918,8 +944,9 @@ impl Drop for MemoryMappingArena { #[cfg(test)] mod tests { use super::*; + use crate::Descriptor; use data_model::{VolatileMemory, VolatileMemoryError}; - use std::os::unix::io::FromRawFd; + use tempfile::tempfile; #[test] fn basic_map() { @@ -939,7 +966,7 @@ mod tests { #[test] fn map_invalid_fd() { - let fd = unsafe { std::fs::File::from_raw_fd(-1) }; + let fd = Descriptor(-1); let res = MemoryMapping::from_fd(&fd, 1024).unwrap_err(); if let Error::SystemCallFailed(e) = res { assert_eq!(e.errno(), libc::EBADF); @@ -999,7 +1026,7 @@ mod tests { #[test] fn from_fd_offset_invalid() { - let fd = unsafe { std::fs::File::from_raw_fd(-1) }; + let fd = tempfile().unwrap(); let res = MemoryMapping::from_fd_offset(&fd, 4096, (libc::off_t::max_value() as u64) + 1) .unwrap_err(); match res { diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs index 8aaf04845..b59fb0f21 100644 --- a/sys_util/src/net.rs +++ b/sys_util/src/net.rs @@ -5,22 +5,282 @@ use std::ffi::OsString; use std::fs::remove_file; use std::io; -use std::mem; +use std::mem::{self, size_of}; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs}; use std::ops::Deref; use std::os::unix::{ ffi::{OsStrExt, OsStringExt}, - io::{AsRawFd, FromRawFd, RawFd}, + 
io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; use std::path::Path; use std::path::PathBuf; use std::ptr::null_mut; use std::time::Duration; -use libc::{recvfrom, MSG_PEEK, MSG_TRUNC}; +use libc::{ + c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, + socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM, +}; +use serde::{Deserialize, Serialize}; use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT}; use crate::{AsRawDescriptor, RawDescriptor}; +/// Assist in handling both IP version 4 and IP version 6. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum InetVersion { + V4, + V6, +} + +impl InetVersion { + pub fn from_sockaddr(s: &SocketAddr) -> Self { + match s { + SocketAddr::V4(_) => InetVersion::V4, + SocketAddr::V6(_) => InetVersion::V6, + } + } +} + +impl From<InetVersion> for sa_family_t { + fn from(v: InetVersion) -> sa_family_t { + match v { + InetVersion::V4 => AF_INET as sa_family_t, + InetVersion::V6 => AF_INET6 as sa_family_t, + } + } +} + +fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in { + sockaddr_in { + sin_family: AF_INET as sa_family_t, + sin_port: s.port().to_be(), + sin_addr: in_addr { + s_addr: u32::from_ne_bytes(s.ip().octets()), + }, + sin_zero: [0; 8], + } +} + +fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 { + sockaddr_in6 { + sin6_family: AF_INET6 as sa_family_t, + sin6_port: s.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: in6_addr { + s6_addr: s.ip().octets(), + }, + sin6_scope_id: 0, + } +} + +/// A TCP socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either TcpStream or +/// TcpListener. 
+#[derive(Debug)] +pub struct TcpSocket { + inet_version: InetVersion, + fd: RawFd, +} + +impl TcpSocket { + pub fn new(inet_version: InetVersion) -> io::Result<Self> { + let fd = unsafe { + libc::socket( + Into::<sa_family_t>::into(inet_version) as c_int, + SOCK_STREAM | SOCK_CLOEXEC, + 0, + ) + }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(TcpSocket { inet_version, fd }) + } + } + + pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::<sockaddr_in6>() as socklen_t, + ) + } + } + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::connect( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. 
+ unsafe { + libc::connect( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::<sockaddr_in>() as socklen_t, + ) + } + } + }; + + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpStream. + Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) }) + } + } + + pub fn listen(self) -> io::Result<TcpListener> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + Err(listen_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpListener. + Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) }) + } + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result<u16> { + match self.inet_version { + InetVersion::V4 => { + let mut sin = sockaddr_in { + sin_family: 0, + sin_port: 0, + sin_addr: in_addr { s_addr: 0 }, + sin_zero: [0; 8], + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_in>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin as *mut sockaddr_in as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. 
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in>()); + + Ok(u16::from_be(sin.sin_port)) + } + } + InetVersion::V6 => { + let mut sin6 = sockaddr_in6 { + sin6_family: 0, + sin6_port: 0, + sin6_flowinfo: 0, + sin6_addr: in6_addr { s6_addr: [0; 16] }, + sin6_scope_id: 0, + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_in6>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin6 as *mut sockaddr_in6 as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::<sockaddr_in>()); + + Ok(u16::from_be(sin6.sin6_port)) + } + } + } + } +} + +impl IntoRawFd for TcpSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for TcpSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + // Offset of sun_path in structure sockaddr_un. fn sun_path_offset() -> usize { // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer. @@ -72,7 +332,9 @@ fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc:: } /// A Unix `SOCK_SEQPACKET` socket point to given `path` +#[derive(Serialize, Deserialize)] pub struct UnixSeqpacket { + #[serde(with = "crate::with_raw_descriptor")] fd: RawFd, } @@ -134,12 +396,12 @@ impl UnixSeqpacket { /// Clone the underlying FD. pub fn try_clone(&self) -> io::Result<Self> { - // Calling `dup` is safe as the kernel doesn't touch any user memory it the process. 
- let new_fd = unsafe { libc::dup(self.fd) }; - if new_fd < 0 { + // Safe because this doesn't modify any memory and we check the return value. + let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) }; + if fd < 0 { Err(io::Error::last_os_error()) } else { - Ok(UnixSeqpacket { fd: new_fd }) + Ok(Self { fd }) } } @@ -157,12 +419,21 @@ impl UnixSeqpacket { /// Gets the number of bytes in the next packet. This blocks as if `recv` were called, /// respecting the blocking and timeout settings of the underlying socket. pub fn next_packet_size(&self) -> io::Result<usize> { + #[cfg(not(debug_assertions))] + let buf = null_mut(); + // Work around for qemu's syscall translation which will reject null pointers in recvfrom. + // This only matters for running the unit tests for a non-native architecture. See the + // upstream thread for the qemu fix: + // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html + #[cfg(debug_assertions)] + let buf = &mut 0 as *mut _ as *mut _; + // This form of recvfrom doesn't modify any data because all null pointers are used. We only // use the return value and check for errors on an FD owned by this structure. let ret = unsafe { recvfrom( self.fd, - null_mut(), + buf, 0, MSG_TRUNC | MSG_PEEK, null_mut(), diff --git a/sys_util/src/rand.rs b/sys_util/src/rand.rs new file mode 100644 index 000000000..7c7799397 --- /dev/null +++ b/sys_util/src/rand.rs @@ -0,0 +1,114 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Rust implementation of functionality parallel to libchrome's base/rand_util.h. + +use std::thread::sleep; +use std::time::Duration; + +use libc::{c_uint, c_void}; + +use crate::{ + errno::{errno_result, Result}, + handle_eintr_errno, +}; + +/// How long to wait before calling getrandom again if it does not return +/// enough bytes. 
+const POLL_INTERVAL: Duration = Duration::from_millis(50); + +/// Represents whether or not the random bytes are pulled from the source of +/// /dev/random or /dev/urandom. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Source { + // This is the default and uses the same source as /dev/urandom. + Pseudorandom, + // This uses the same source as /dev/random and may be. + Random, +} + +impl Default for Source { + fn default() -> Self { + Source::Pseudorandom + } +} + +impl Source { + fn to_getrandom_flags(&self) -> c_uint { + match self { + Source::Random => libc::GRND_RANDOM, + Source::Pseudorandom => 0, + } + } +} + +/// Fills `output` completely with random bytes from the specified `source`. +pub fn rand_bytes(mut output: &mut [u8], source: Source) -> Result<()> { + if output.is_empty() { + return Ok(()); + } + + loop { + // Safe because output is mutable and the writes are limited by output.len(). + let bytes = handle_eintr_errno!(unsafe { + libc::getrandom( + output.as_mut_ptr() as *mut c_void, + output.len(), + source.to_getrandom_flags(), + ) + }); + + if bytes < 0 { + return errno_result(); + } + if bytes as usize == output.len() { + return Ok(()); + } + + // Wait for more entropy and try again for the remaining bytes. + sleep(POLL_INTERVAL); + output = &mut output[bytes as usize..]; + } +} + +/// Allocates a vector of length `len` filled with random bytes from the +/// specified `source`. +pub fn rand_vec(len: usize, source: Source) -> Result<Vec<u8>> { + let mut rand = Vec::with_capacity(len); + if len == 0 { + return Ok(rand); + } + + // Safe because rand will either be initialized by getrandom or dropped. 
+ unsafe { rand.set_len(len) }; + rand_bytes(rand.as_mut_slice(), source)?; + Ok(rand) +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_SIZE: usize = 64; + + #[test] + fn randbytes_success() { + let mut rand = vec![0u8; TEST_SIZE]; + rand_bytes(&mut rand, Source::Pseudorandom).unwrap(); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } + + #[test] + fn randvec_success() { + let rand = rand_vec(TEST_SIZE, Source::Pseudorandom).unwrap(); + assert_eq!(rand.len(), TEST_SIZE); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } + + #[test] + fn sourcerandom_success() { + let rand = rand_vec(TEST_SIZE, Source::Random).unwrap(); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } +} diff --git a/sys_util/src/scoped_path.rs b/sys_util/src/scoped_path.rs new file mode 100644 index 000000000..745137eae --- /dev/null +++ b/sys_util/src/scoped_path.rs @@ -0,0 +1,138 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::env::{current_exe, temp_dir}; +use std::fs::{create_dir_all, remove_dir_all}; +use std::io::Result; +use std::ops::Deref; +use std::path::{Path, PathBuf}; +use std::thread::panicking; + +use crate::{getpid, gettid}; + +/// Returns a stable path based on the label, pid, and tid. If the label isn't provided the +/// current_exe is used instead. +pub fn get_temp_path(label: Option<&str>) -> PathBuf { + if let Some(label) = label { + temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid())) + } else { + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap(), + )) + } +} + +/// Automatically deletes the path it contains when it goes out of scope unless it is a test and +/// drop is called after a panic!. +/// +/// This is particularly useful for creating temporary directories for use with tests. 
+pub struct ScopedPath<P: AsRef<Path>>(P); + +impl<P: AsRef<Path>> ScopedPath<P> { + pub fn create(p: P) -> Result<Self> { + create_dir_all(p.as_ref())?; + Ok(ScopedPath(p)) + } +} + +impl<P: AsRef<Path>> AsRef<Path> for ScopedPath<P> { + fn as_ref(&self) -> &Path { + self.0.as_ref() + } +} + +impl<P: AsRef<Path>> Deref for ScopedPath<P> { + type Target = Path; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl<P: AsRef<Path>> Drop for ScopedPath<P> { + fn drop(&mut self) { + // Leave the files on a failed test run for debugging. + if panicking() && cfg!(test) { + eprintln!("NOTE: Not removing {}", self.display()); + return; + } + if let Err(e) = remove_dir_all(&**self) { + eprintln!("Failed to remove {}: {}", self.display(), e); + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + use std::panic::catch_unwind; + + #[test] + fn gettemppath() { + assert_ne!("", get_temp_path(None).to_string_lossy()); + assert!(get_temp_path(None).starts_with(temp_dir())); + assert_eq!( + get_temp_path(None), + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + )) + ); + assert_ne!( + get_temp_path(Some("label")), + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + )) + ); + } + + #[test] + fn scopedpath_exists() { + let tmp_path = get_temp_path(None); + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_notexists() { + let tmp_path = get_temp_path(None); + { + let _scoped_path = ScopedPath(&tmp_path); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_panic() { + let tmp_path = get_temp_path(None); + assert!(catch_unwind(|| { + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + panic!() + } + }) + .is_err()); + assert!(tmp_path.exists()); + remove_dir_all(&tmp_path).unwrap(); 
+ } +} diff --git a/sys_util/src/scoped_signal_handler.rs b/sys_util/src/scoped_signal_handler.rs new file mode 100644 index 000000000..76286161c --- /dev/null +++ b/sys_util/src/scoped_signal_handler.rs @@ -0,0 +1,421 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Provides a struct for registering signal handlers that get cleared on drop. + +use std::convert::TryFrom; +use std::fmt::{self, Display}; +use std::io::{Cursor, Write}; +use std::panic::catch_unwind; +use std::result; + +use libc::{c_int, c_void, STDERR_FILENO}; + +use crate::errno; +use crate::signal::{ + clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal, + Signal, +}; + +#[derive(Debug)] +pub enum Error { + /// Sigaction failed. + Sigaction(Signal, errno::Error), + /// Failed to check if signal has the default signal handler. + HasDefaultSignalHandler(Signal, errno::Error), + /// Failed to register a signal handler. + RegisterSignalHandler(Signal, errno::Error), + /// Signal already has a handler. + HandlerAlreadySet(Signal), + /// Already waiting for interrupt. + AlreadyWaiting, + /// Failed to wait for signal. 
+ WaitForSignal(errno::Error), +} + +impl Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::Error::*; + + match self { + Sigaction(s, e) => write!(f, "sigaction failed for {0:?}: {1}", s, e), + HasDefaultSignalHandler(s, e) => { + write!(f, "failed to check the signal handler for {0:?}: {1}", s, e) + } + RegisterSignalHandler(s, e) => write!( + f, + "failed to register a signal handler for {0:?}: {1}", + s, e + ), + HandlerAlreadySet(s) => write!(f, "signal handler already set for {0:?}", s), + AlreadyWaiting => write!(f, "already waiting for interrupt."), + WaitForSignal(e) => write!(f, "wait_for_signal failed: {0}", e), + } + } +} + +pub type Result<T> = result::Result<T, Error>; + +/// The interface used by Scoped Signal handler. +/// +/// # Safety +/// The implementation of handle_signal needs to be async signal-safe. +/// +/// NOTE: panics are caught when possible because a panic inside ffi is undefined behavior. +pub unsafe trait SignalHandler { + /// A function that is called to handle the passed signal. + fn handle_signal(signal: Signal); +} + +/// Wrap the handler with an extern "C" function. +extern "C" fn call_handler<H: SignalHandler>(signum: c_int) { + // Make an effort to surface an error. + if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_err() { + // Note the following cannot be used: + // eprintln! - uses std::io which has locks that may be held. + // format! - uses the allocator which enforces mutual exclusion. + + // Get the debug representation of signum. + let signal: Signal; + let signal_debug: &dyn fmt::Debug = match Signal::try_from(signum) { + Ok(s) => { + signal = s; + &signal as &dyn fmt::Debug + } + Err(_) => &signum as &dyn fmt::Debug, + }; + + // Buffer the output, so a single call to write can be used. + // The message accounts for 29 chars, that leaves 35 for the string representation of the + // signal which is more than enough. 
+ let mut buffer = [0u8; 64]; + let mut cursor = Cursor::new(buffer.as_mut()); + if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() { + let len = cursor.position() as usize; + // Safe in the sense that buffer is owned and the length is checked. This may print in + // the middle of an existing write, but that is considered better than dropping the + // error. + unsafe { + libc::write( + STDERR_FILENO, + cursor.get_ref().as_ptr() as *const c_void, + len, + ) + }; + } else { + // This should never happen, but write an error message just in case. + const ERROR_DROPPED: &str = "Error dropped by signal handler."; + let bytes = ERROR_DROPPED.as_bytes(); + unsafe { libc::write(STDERR_FILENO, bytes.as_ptr() as *const c_void, bytes.len()) }; + } + } +} + +/// Represents a signal handler that is registered with a set of signals that unregistered when the +/// struct goes out of scope. Prefer a signalfd based solution before using this. +pub struct ScopedSignalHandler { + signals: Vec<Signal>, +} + +impl ScopedSignalHandler { + /// Attempts to register `handler` with the provided `signals`. It will fail if there is already + /// an existing handler on any of `signals`. + /// + /// # Safety + /// This is safe if H::handle_signal is async-signal safe. + pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> { + let mut scoped_handler = ScopedSignalHandler { + signals: Vec::with_capacity(signals.len()), + }; + for &signal in signals { + if !has_default_signal_handler((signal).into()) + .map_err(|err| Error::HasDefaultSignalHandler(signal, err))? + { + return Err(Error::HandlerAlreadySet(signal)); + } + // Requires an async-safe callback. + unsafe { + register_signal_handler((signal).into(), call_handler::<H>) + .map_err(|err| Error::RegisterSignalHandler(signal, err))? + }; + scoped_handler.signals.push(signal); + } + Ok(scoped_handler) + } +} + +/// Clears the signal handler for any of the associated signals. 
+impl Drop for ScopedSignalHandler { + fn drop(&mut self) { + for signal in &self.signals { + if let Err(err) = clear_signal_handler((*signal).into()) { + eprintln!("Error: failed to clear signal handler: {:?}", err); + } + } + } +} + +/// A signal handler that does nothing. +/// +/// This is useful in cases where wait_for_signal is used since it will never trigger if the signal +/// is blocked and the default handler may have undesired effects like terminating the process. +pub struct EmptySignalHandler; +/// # Safety +/// Safe because handle_signal is async-signal safe. +unsafe impl SignalHandler for EmptySignalHandler { + fn handle_signal(_: Signal) {} +} + +/// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an +/// interactive terminal. +/// +/// Note: if you are using a multi-threaded application you need to block SIGINT on all other +/// threads or they may receive the signal instead of the desired thread. +pub fn wait_for_interrupt() -> Result<()> { + // Register a signal handler if there is not one already so the thread is not killed. 
+ let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]); + if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) { + ret?; + } + + match wait_for_signal(&[Signal::Interrupt.into()], None) { + Ok(_) => Ok(()), + Err(err) => Err(Error::WaitForSignal(err)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::fs::File; + use std::io::{BufRead, BufReader}; + use std::mem::zeroed; + use std::ptr::{null, null_mut}; + use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; + use std::sync::{Arc, Mutex, MutexGuard, Once}; + use std::thread::{sleep, spawn}; + use std::time::{Duration, Instant}; + + use libc::sigaction; + + use crate::{gettid, kill, Pid}; + + const TEST_SIGNAL: Signal = Signal::User1; + const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2]; + + static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0); + + /// Only allows one test case to execute at a time. + fn get_mutex() -> MutexGuard<'static, ()> { + static INIT: Once = Once::new(); + static mut VAL: Option<Arc<Mutex<()>>> = None; + + INIT.call_once(|| { + let val = Some(Arc::new(Mutex::new(()))); + // Safe because the mutation is protected by the Once. + unsafe { VAL = val } + }); + + // Safe mutation only happens in the Once. + unsafe { VAL.as_ref() }.unwrap().lock().unwrap() + } + + fn reset_counter() { + TEST_SIGNAL_COUNTER.swap(0, Ordering::SeqCst); + } + + fn get_sigaction(signal: Signal) -> Result<sigaction> { + // Safe because sigaction is owned and expected to be initialized ot zeros. + let mut sigact: sigaction = unsafe { zeroed() }; + + if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 { + Err(Error::Sigaction(signal, errno::Error::last())) + } else { + Ok(sigact) + } + } + + /// Safety: + /// This is only safe if the signal handler set in sigaction is safe. 
+ unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> { + if sigaction(signal.into(), &sigact, null_mut()) < 0 { + Err(Error::Sigaction(signal, errno::Error::last())) + } else { + Ok(sigact) + } + } + + /// Safety: + /// Safe if the signal handler for Signal::User1 is safe. + unsafe fn send_test_signal() { + kill(gettid(), Signal::User1.into()).unwrap() + } + + macro_rules! assert_counter_eq { + ($compare_to:expr) => {{ + let expected: usize = $compare_to; + let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst); + if got != expected { + panic!( + "wrong signal counter value: got {}; expected {}", + got, expected + ); + } + }}; + } + + struct TestHandler; + + /// # Safety + /// Safe because handle_signal is async-signal safe. + unsafe impl SignalHandler for TestHandler { + fn handle_signal(signal: Signal) { + if TEST_SIGNAL == signal { + TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst); + } + } + } + + #[test] + fn scopedsignalhandler_success() { + // Prevent other test cases from running concurrently since the signal + // handlers are shared for the process. + let _guard = get_mutex(); + + reset_counter(); + assert_counter_eq!(0); + + assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap(); + assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + + // Safe because test_handler is safe. + unsafe { send_test_signal() }; + + // Give the handler time to run in case it is on a different thread. + for _ in 1..40 { + if TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) > 0 { + break; + } + sleep(Duration::from_millis(250)); + } + + assert_counter_eq!(1); + + drop(handler); + assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + } + + #[test] + fn scopedsignalhandler_handleralreadyset() { + // Prevent other test cases from running concurrently since the signal + // handlers are shared for the process. 
+ let _guard = get_mutex(); + + reset_counter(); + assert_counter_eq!(0); + + assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + // Safe because TestHandler is async-signal safe. + let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap(); + assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + + // Safe because TestHandler is async-signal safe. + assert!(matches!( + ScopedSignalHandler::new::<TestHandler>(&TEST_SIGNALS), + Err(Error::HandlerAlreadySet(Signal::User1)) + )); + + assert_counter_eq!(0); + drop(handler); + assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap()); + } + + /// Stores the thread used by WaitForInterruptHandler. + static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0); + /// Forwards SIGINT to the appropriate thread. + struct WaitForInterruptHandler; + + /// # Safety + /// Safe because handle_signal is async-signal safe. + unsafe impl SignalHandler for WaitForInterruptHandler { + fn handle_signal(_: Signal) { + let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst); + // If the thread ID is set and executed on the wrong thread, forward the signal. + if tid != 0 && gettid() != tid { + // Safe because the handler is safe and the target thread id is expecting the signal. + unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap(); + } + } + } + + /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or in + /// D (disk sleep). 
+ fn thread_is_sleeping(tid: Pid) -> result::Result<bool, errno::Error> { + const PREFIX: &str = "State:"; + let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?); + let mut line = String::new(); + loop { + let count = status_reader.read_line(&mut line)?; + if count == 0 { + return Err(errno::Error::new(libc::EIO)); + } + if line.starts_with(PREFIX) { + return Ok(matches!( + line[PREFIX.len()..].trim_start().chars().next(), + Some('S') | Some('D') + )); + } + line.clear(); + } + } + + /// Wait for a process to block either in a sleeping or disk sleep state. + fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), errno::Error> { + let start = Instant::now(); + loop { + if thread_is_sleeping(tid)? { + return Ok(()); + } + if start.elapsed() > timeout { + return Err(errno::Error::new(libc::EAGAIN)); + } + sleep(Duration::from_millis(50)); + } + } + + #[test] + fn waitforinterrupt_success() { + // Prevent other test cases from running concurrently since the signal + // handlers are shared for the process. + let _guard = get_mutex(); + + let to_restore = get_sigaction(Signal::Interrupt).unwrap(); + clear_signal_handler(Signal::Interrupt.into()).unwrap(); + // Safe because TestHandler is async-signal safe. + let handler = + ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap(); + + let tid = gettid(); + WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst); + + let join_handle = spawn(move || -> result::Result<(), errno::Error> { + // Wait unitl the thread is ready to receive the signal. + wait_for_thread_to_sleep(tid, Duration::from_secs(10)).unwrap(); + + // Safe because the SIGINT handler is safe. + unsafe { kill(tid, Signal::Interrupt.into()) } + }); + let wait_ret = wait_for_interrupt(); + let join_ret = join_handle.join(); + + drop(handler); + // Safe because we are restoring the previous SIGINT handler. 
+ unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap(); + + wait_ret.unwrap(); + join_ret.unwrap().unwrap(); + } +} diff --git a/sys_util/src/shm.rs b/sys_util/src/shm.rs index d9a583990..927734aef 100644 --- a/sys_util/src/shm.rs +++ b/sys_util/src/shm.rs @@ -5,19 +5,21 @@ use std::ffi::{CStr, CString}; use std::fs::{read_link, File}; use std::io::{self, Read, Seek, SeekFrom, Write}; -use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; use libc::{ - self, c_char, c_int, c_long, c_uint, close, fcntl, ftruncate64, off64_t, syscall, EINVAL, - F_ADD_SEALS, F_GET_SEALS, F_SEAL_GROW, F_SEAL_SEAL, F_SEAL_SHRINK, F_SEAL_WRITE, - MFD_ALLOW_SEALING, + self, c_char, c_int, c_long, c_uint, close, fcntl, ftruncate64, off64_t, syscall, + SYS_memfd_create, EINVAL, F_ADD_SEALS, F_GET_SEALS, F_SEAL_GROW, F_SEAL_SEAL, F_SEAL_SHRINK, + F_SEAL_WRITE, MFD_ALLOW_SEALING, }; -use syscall_defines::linux::LinuxSyscall::SYS_memfd_create; +use serde::{Deserialize, Serialize}; use crate::{errno, errno_result, Result}; /// A shared memory file descriptor and its size. 
+#[derive(Serialize, Deserialize)] pub struct SharedMemory { + #[serde(with = "crate::with_as_descriptor")] fd: File, size: u64, } @@ -267,9 +269,15 @@ impl AsRawFd for &SharedMemory { } } -impl Into<File> for SharedMemory { - fn into(self) -> File { - self.fd +impl IntoRawFd for SharedMemory { + fn into_raw_fd(self) -> RawFd { + self.fd.into_raw_fd() + } +} + +impl From<SharedMemory> for File { + fn from(s: SharedMemory) -> File { + s.fd } } diff --git a/sys_util/src/signal.rs b/sys_util/src/signal.rs index 8f5bc6a92..1191dccb5 100644 --- a/sys_util/src/signal.rs +++ b/sys_util/src/signal.rs @@ -4,20 +4,27 @@ use libc::{ c_int, pthread_kill, pthread_sigmask, pthread_t, sigaction, sigaddset, sigemptyset, siginfo_t, - sigismember, sigpending, sigset_t, sigtimedwait, timespec, EAGAIN, EINTR, EINVAL, SA_RESTART, - SIG_BLOCK, SIG_UNBLOCK, + sigismember, sigpending, sigset_t, sigtimedwait, sigwait, timespec, waitpid, EAGAIN, EINTR, + EINVAL, SA_RESTART, SIG_BLOCK, SIG_DFL, SIG_UNBLOCK, WNOHANG, }; use std::cmp::Ordering; +use std::convert::TryFrom; use std::fmt::{self, Display}; use std::io; use std::mem; use std::os::unix::thread::JoinHandleExt; +use std::process::Child; use std::ptr::{null, null_mut}; use std::result; use std::thread::JoinHandle; +use std::time::{Duration, Instant}; -use crate::{errno, errno_result}; +use crate::{duration_to_timespec, errno, errno_result, getsid, Pid, Result}; +use std::ops::{Deref, DerefMut}; + +const POLL_RATE: Duration = Duration::from_millis(50); +const DEFAULT_KILL_TIMEOUT: Duration = Duration::from_secs(5); #[derive(Debug)] pub enum Error { @@ -39,6 +46,20 @@ pub enum Error { ClearGetPending(errno::Error), /// Failed to check if given signal is in the set of pending signals. ClearCheckPending(errno::Error), + /// Failed to send signal to pid. + Kill(errno::Error), + /// Failed to get session id. + GetSid(errno::Error), + /// Failed to wait for signal. + WaitForSignal(errno::Error), + /// Failed to wait for pid. 
+ WaitPid(errno::Error), + /// Timeout reached. + TimedOut, + /// Failed to convert signum to Signal. + UnrecognizedSignum(i32), + /// Converted signum greater than SIGRTMAX. + RtSignumGreaterThanMax(Signal), } impl Display for Error { @@ -67,7 +88,184 @@ impl Display for Error { "failed to check whether given signal is in the pending set: {}", e, ), + Kill(e) => write!(f, "failed to send signal: {}", e), + GetSid(e) => write!(f, "failed to get session id: {}", e), + WaitForSignal(e) => write!(f, "failed to wait for signal: {}", e), + WaitPid(e) => write!(f, "failed to wait for process: {}", e), + TimedOut => write!(f, "timeout reached."), + UnrecognizedSignum(signum) => write!(f, "unrecoginized signal number: {}", signum), + RtSignumGreaterThanMax(signal) => { + write!(f, "got RT signal greater than max: {:?}", signal) + } + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[repr(i32)] +pub enum Signal { + Abort = libc::SIGABRT, + Alarm = libc::SIGALRM, + Bus = libc::SIGBUS, + Child = libc::SIGCHLD, + Continue = libc::SIGCONT, + ExceededFileSize = libc::SIGXFSZ, + FloatingPointException = libc::SIGFPE, + HangUp = libc::SIGHUP, + IllegalInstruction = libc::SIGILL, + Interrupt = libc::SIGINT, + Io = libc::SIGIO, + Kill = libc::SIGKILL, + Pipe = libc::SIGPIPE, + Power = libc::SIGPWR, + Profile = libc::SIGPROF, + Quit = libc::SIGQUIT, + SegmentationViolation = libc::SIGSEGV, + StackFault = libc::SIGSTKFLT, + Stop = libc::SIGSTOP, + Sys = libc::SIGSYS, + Trap = libc::SIGTRAP, + Terminate = libc::SIGTERM, + TtyIn = libc::SIGTTIN, + TtyOut = libc::SIGTTOU, + TtyStop = libc::SIGTSTP, + Urgent = libc::SIGURG, + User1 = libc::SIGUSR1, + User2 = libc::SIGUSR2, + VtAlarm = libc::SIGVTALRM, + Winch = libc::SIGWINCH, + Xcpu = libc::SIGXCPU, + // Rt signal numbers are be adjusted in the conversion to integer. + Rt0 = libc::SIGSYS + 1, + Rt1, + Rt2, + Rt3, + Rt4, + Rt5, + Rt6, + Rt7, + // Only 8 are guaranteed by POSIX, Linux has 32, but only 29 or 30 are usable. 
+ Rt8, + Rt9, + Rt10, + Rt11, + Rt12, + Rt13, + Rt14, + Rt15, + Rt16, + Rt17, + Rt18, + Rt19, + Rt20, + Rt21, + Rt22, + Rt23, + Rt24, + Rt25, + Rt26, + Rt27, + Rt28, + Rt29, + Rt30, + Rt31, +} + +impl From<Signal> for c_int { + fn from(signal: Signal) -> c_int { + let num = signal as libc::c_int; + if num >= Signal::Rt0 as libc::c_int { + return num - (Signal::Rt0 as libc::c_int) + SIGRTMIN(); } + num + } +} + +impl TryFrom<c_int> for Signal { + type Error = Error; + + fn try_from(value: c_int) -> result::Result<Self, Self::Error> { + use Signal::*; + + Ok(match value { + libc::SIGABRT => Abort, + libc::SIGALRM => Alarm, + libc::SIGBUS => Bus, + libc::SIGCHLD => Child, + libc::SIGCONT => Continue, + libc::SIGXFSZ => ExceededFileSize, + libc::SIGFPE => FloatingPointException, + libc::SIGHUP => HangUp, + libc::SIGILL => IllegalInstruction, + libc::SIGINT => Interrupt, + libc::SIGIO => Io, + libc::SIGKILL => Kill, + libc::SIGPIPE => Pipe, + libc::SIGPWR => Power, + libc::SIGPROF => Profile, + libc::SIGQUIT => Quit, + libc::SIGSEGV => SegmentationViolation, + libc::SIGSTKFLT => StackFault, + libc::SIGSTOP => Stop, + libc::SIGSYS => Sys, + libc::SIGTRAP => Trap, + libc::SIGTERM => Terminate, + libc::SIGTTIN => TtyIn, + libc::SIGTTOU => TtyOut, + libc::SIGTSTP => TtyStop, + libc::SIGURG => Urgent, + libc::SIGUSR1 => User1, + libc::SIGUSR2 => User2, + libc::SIGVTALRM => VtAlarm, + libc::SIGWINCH => Winch, + libc::SIGXCPU => Xcpu, + _ => { + if value < SIGRTMIN() { + return Err(Error::UnrecognizedSignum(value)); + } + let signal = match value - SIGRTMIN() { + 0 => Rt0, + 1 => Rt1, + 2 => Rt2, + 3 => Rt3, + 4 => Rt4, + 5 => Rt5, + 6 => Rt6, + 7 => Rt7, + 8 => Rt8, + 9 => Rt9, + 10 => Rt10, + 11 => Rt11, + 12 => Rt12, + 13 => Rt13, + 14 => Rt14, + 15 => Rt15, + 16 => Rt16, + 17 => Rt17, + 18 => Rt18, + 19 => Rt19, + 20 => Rt20, + 21 => Rt21, + 22 => Rt22, + 23 => Rt23, + 24 => Rt24, + 25 => Rt25, + 26 => Rt26, + 27 => Rt27, + 28 => Rt28, + 29 => Rt29, + 30 => Rt30, + 31 => 
Rt31, + _ => { + return Err(Error::UnrecognizedSignum(value)); + } + }; + if value > SIGRTMAX() { + return Err(Error::RtSignumGreaterThanMax(signal)); + } + signal + } + }) } } @@ -101,7 +299,10 @@ fn valid_rt_signal_num(num: c_int) -> bool { /// /// This is considered unsafe because the given handler will be called asynchronously, interrupting /// whatever the thread was doing and therefore must only do async-signal-safe operations. -pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> errno::Result<()> { +pub unsafe fn register_signal_handler( + num: c_int, + handler: extern "C" fn(c_int), +) -> errno::Result<()> { let mut sigact: sigaction = mem::zeroed(); sigact.sa_flags = SA_RESTART; sigact.sa_sigaction = handler as *const () as usize; @@ -114,6 +315,36 @@ pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> e Ok(()) } +/// Resets the signal handler of signum `num` back to the default. +pub fn clear_signal_handler(num: c_int) -> errno::Result<()> { + // Safe because sigaction is owned and expected to be initialized ot zeros. + let mut sigact: sigaction = unsafe { mem::zeroed() }; + sigact.sa_flags = SA_RESTART; + sigact.sa_sigaction = SIG_DFL; + + // Safe because sigact is owned, and this is restoring the default signal handler. + let ret = unsafe { sigaction(num, &sigact, null_mut()) }; + if ret < 0 { + return errno_result(); + } + + Ok(()) +} + +/// Returns true if the signal handler for signum `num` is the default. +pub fn has_default_signal_handler(num: c_int) -> errno::Result<bool> { + // Safe because sigaction is owned and expected to be initialized ot zeros. + let mut sigact: sigaction = unsafe { mem::zeroed() }; + + // Safe because sigact is owned, and this is just querying the existing state. 
+ let ret = unsafe { sigaction(num, null(), &mut sigact) }; + if ret < 0 { + return errno_result(); + } + + Ok(sigact.sa_sigaction == SIG_DFL) +} + /// Registers `handler` as the signal handler for the real-time signal with signum `num`. /// /// The value of `num` must be within [`SIGRTMIN`, `SIGRTMAX`] range. @@ -124,7 +355,7 @@ pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> e /// whatever the thread was doing and therefore must only do async-signal-safe operations. pub unsafe fn register_rt_signal_handler( num: c_int, - handler: extern "C" fn(), + handler: extern "C" fn(c_int), ) -> errno::Result<()> { if !valid_rt_signal_num(num) { return Err(errno::Error::new(EINVAL)); @@ -157,6 +388,34 @@ pub fn create_sigset(signals: &[c_int]) -> errno::Result<sigset_t> { Ok(sigset) } +/// Wait for signal before continuing. The signal number of the consumed signal is returned on +/// success. EAGAIN means the timeout was reached. +pub fn wait_for_signal(signals: &[c_int], timeout: Option<Duration>) -> errno::Result<c_int> { + let sigset = create_sigset(signals)?; + + match timeout { + Some(timeout) => { + let ts = duration_to_timespec(timeout); + // Safe - return value is checked. + let ret = handle_eintr_errno!(unsafe { sigtimedwait(&sigset, null_mut(), &ts) }); + if ret < 0 { + errno_result() + } else { + Ok(ret) + } + } + None => { + let mut ret: c_int = 0; + let err = handle_eintr_rc!(unsafe { sigwait(&sigset, &mut ret as *mut c_int) }); + if err != 0 { + Err(errno::Error::new(err)) + } else { + Ok(ret) + } + } + } +} + /// Retrieves the signal mask of the current thread as a vector of c_ints. pub fn get_blocked_signals() -> SignalResult<Vec<c_int>> { let mut mask = Vec::new(); @@ -266,6 +525,20 @@ pub fn clear_signal(num: c_int) -> SignalResult<()> { Ok(()) } +/// # Safety +/// This is marked unsafe because it allows signals to be sent to arbitrary PIDs. Sending some +/// signals may lead to undefined behavior. 
Also, the return codes of the child processes need to be +/// reaped to avoid leaking zombie processes. +pub unsafe fn kill(pid: Pid, signum: c_int) -> Result<()> { + let ret = libc::kill(pid, signum); + + if ret != 0 { + errno_result() + } else { + Ok(()) + } +} + /// Trait for threads that can be signalled via `pthread_kill`. /// /// Note that this is only useful for signals between SIGRTMIN and SIGRTMAX because these are @@ -300,3 +573,156 @@ unsafe impl<T> Killable for JoinHandle<T> { self.as_pthread_t() } } + +/// Treat some errno's as Ok(()). +macro_rules! ok_if { + ($result:expr, $($errno:pat)|+) => {{ + let res = $result; + match res { + Ok(_) => Ok(()), + Err(err) => { + if matches!(err.errno(), $($errno)|+) { + Ok(()) + } else { + Err(err) + } + } + } + }} +} + +/// Terminates and reaps a child process. If the child process is a group leader, its children will +/// be terminated and reaped as well. After the given timeout, the child process and any relevant +/// children are killed (i.e. sent SIGKILL). +pub fn kill_tree(child: &mut Child, terminate_timeout: Duration) -> SignalResult<()> { + let target = { + let pid = child.id() as Pid; + if getsid(Some(pid)).map_err(Error::GetSid)? == pid { + -pid + } else { + pid + } + }; + + // Safe because target is a child process (or group) and behavior of SIGTERM is defined. + ok_if!(unsafe { kill(target, libc::SIGTERM) }, libc::ESRCH).map_err(Error::Kill)?; + + // Reap the direct child first in case it waits for its descendants, afterward reap any + // remaining group members. + let start = Instant::now(); + let mut child_running = true; + loop { + // Wait for the direct child to exit before reaping any process group members. + if child_running { + if child + .try_wait() + .map_err(|e| Error::WaitPid(errno::Error::from(e)))? + .is_some() + { + child_running = false; + // Skip the timeout check because waitpid(..., WNOHANG) will not block. 
+ continue; + } + } else { + // Safe because target is a child process (or group), WNOHANG is used, and the return + // value is checked. + let ret = unsafe { waitpid(target, null_mut(), WNOHANG) }; + match ret { + -1 => { + let err = errno::Error::last(); + if err.errno() == libc::ECHILD { + // No group members to wait on. + break; + } + return Err(Error::WaitPid(err)); + } + 0 => {} + // If a process was reaped, skip the timeout check in case there are more. + _ => continue, + }; + } + + // Check for a timeout. + let elapsed = start.elapsed(); + if elapsed > terminate_timeout { + // Safe because target is a child process (or group) and behavior of SIGKILL is defined. + ok_if!(unsafe { kill(target, libc::SIGKILL) }, libc::ESRCH).map_err(Error::Kill)?; + return Err(Error::TimedOut); + } + + // Wait a SIGCHLD or until either the remaining time or a poll interval elapses. + ok_if!( + wait_for_signal( + &[libc::SIGCHLD], + Some(POLL_RATE.min(terminate_timeout - elapsed)) + ), + libc::EAGAIN | libc::EINTR + ) + .map_err(Error::WaitForSignal)? + } + + Ok(()) +} + +/// Wraps a Child process, and calls kill_tree for its process group to clean +/// it up when dropped. +pub struct KillOnDrop { + process: Child, + timeout: Duration, +} + +impl KillOnDrop { + /// Get the timeout. See timeout_mut() for more details. + pub fn timeout(&self) -> Duration { + self.timeout + } + + /// Change the timeout for how long child processes are waited for before + /// the process group is forcibly killed. 
+ pub fn timeout_mut(&mut self) -> &mut Duration { + &mut self.timeout + } +} + +impl From<Child> for KillOnDrop { + fn from(process: Child) -> Self { + KillOnDrop { + process, + timeout: DEFAULT_KILL_TIMEOUT, + } + } +} + +impl AsRef<Child> for KillOnDrop { + fn as_ref(&self) -> &Child { + &self.process + } +} + +impl AsMut<Child> for KillOnDrop { + fn as_mut(&mut self) -> &mut Child { + &mut self.process + } +} + +impl Deref for KillOnDrop { + type Target = Child; + + fn deref(&self) -> &Self::Target { + &self.process + } +} + +impl DerefMut for KillOnDrop { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.process + } +} + +impl Drop for KillOnDrop { + fn drop(&mut self) { + if let Err(err) = kill_tree(&mut self.process, self.timeout) { + eprintln!("failed to kill child process group: {}", err); + } + } +} diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs index a9da8bac4..4bdfdc71b 100644 --- a/sys_util/src/sock_ctrl_msg.rs +++ b/sys_util/src/sock_ctrl_msg.rs @@ -105,11 +105,11 @@ impl CmsgBuffer { } } -fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> { +fn raw_sendmsg<D: AsIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> { let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len()); let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - let iovec = IntoIobuf::as_iobufs(out_data); + let iovec = AsIobuf::as_iobuf_slice(out_data); let mut msg = msghdr { msg_name: null_mut(), @@ -155,19 +155,26 @@ fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Re } fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { - let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len()); - let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); - - let mut iovec = iovec { + let iovec = iovec { iov_base: in_data.as_mut_ptr() as *mut c_void, iov_len: in_data.len(), }; + 
raw_recvmsg_iovecs(fd, &mut [iovec], in_fds) +} + +fn raw_recvmsg_iovecs( + fd: RawFd, + iovecs: &mut [iovec], + in_fds: &mut [RawFd], +) -> Result<(usize, usize)> { + let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); let mut msg = msghdr { msg_name: null_mut(), msg_namelen: 0, - msg_iov: &mut iovec as *mut iovec, - msg_iovlen: 1, + msg_iov: iovecs.as_mut_ptr() as *mut iovec, + msg_iovlen: iovecs.len(), msg_control: null_mut(), msg_controllen: 0, msg_flags: 0, @@ -232,7 +239,7 @@ pub trait ScmSocket { /// /// * `buf` - A buffer of data to send on the `socket`. /// * `fd` - A file descriptors to be sent. - fn send_with_fd<D: IntoIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> { + fn send_with_fd<D: AsIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> { self.send_with_fds(buf, &[fd]) } @@ -244,10 +251,35 @@ pub trait ScmSocket { /// /// * `buf` - A buffer of data to send on the `socket`. /// * `fds` - A list of file descriptors to be sent. - fn send_with_fds<D: IntoIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> { + fn send_with_fds<D: AsIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> { raw_sendmsg(self.socket_fd(), buf, fd) } + /// Sends the given data and file descriptor over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `bufs` - A slice of slices of data to send on the `socket`. + /// * `fd` - A file descriptors to be sent. + fn send_bufs_with_fd(&self, bufs: &[&[u8]], fd: RawFd) -> Result<usize> { + self.send_bufs_with_fds(bufs, &[fd]) + } + + /// Sends the given data and file descriptors over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `bufs` - A slice of slices of data to send on the `socket`. + /// * `fds` - A list of file descriptors to be sent. 
+ fn send_bufs_with_fds(&self, bufs: &[&[u8]], fd: &[RawFd]) -> Result<usize> { + let slices: Vec<IoSlice> = bufs.iter().map(|&b| IoSlice::new(b)).collect(); + raw_sendmsg(self.socket_fd(), &slices, fd) + } + /// Receives data and potentially a file descriptor from the socket. /// /// On success, returns the number of bytes and an optional file descriptor. @@ -284,6 +316,27 @@ pub trait ScmSocket { fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> { raw_recvmsg(self.socket_fd(), buf, fds) } + + /// Receives data and file descriptors from the socket. + /// + /// On success, returns the number of bytes and file descriptors received as a tuple + /// `(bytes count, files count)`. + /// + /// # Arguments + /// + /// * `ioves` - A slice of iovecs to store received data. + /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the + /// number of valid file descriptors is indicated by the second element of the + /// returned tuple. The caller owns these file descriptors, but they will not be + /// closed on drop like a `File`-like type would be. It is recommended that each valid + /// file descriptor gets wrapped in a drop type that closes it after this returns. + fn recv_iovecs_with_fds( + &self, + iovecs: &mut [iovec], + fds: &mut [RawFd], + ) -> Result<(usize, usize)> { + raw_recvmsg_iovecs(self.socket_fd(), iovecs, fds) + } } impl ScmSocket for UnixDatagram { @@ -309,25 +362,26 @@ impl ScmSocket for UnixSeqpacket { /// /// This trait is unsafe because interfaces that use this trait depend on the base pointer and size /// being accurate. -pub unsafe trait IntoIobuf: Sized { +pub unsafe trait AsIobuf: Sized { /// Returns a `iovec` that describes a contiguous region of memory. - fn into_iobuf(&self) -> iovec; + fn as_iobuf(&self) -> iovec; /// Returns a slice of `iovec`s that each describe a contiguous region of memory. 
- fn as_iobufs(bufs: &[Self]) -> &[iovec]; + #[allow(clippy::wrong_self_convention)] + fn as_iobuf_slice(bufs: &[Self]) -> &[iovec]; } // Safe because there are no other mutable references to the memory described by `IoSlice` and it is // guaranteed to be ABI-compatible with `iovec`. -unsafe impl<'a> IntoIobuf for IoSlice<'a> { - fn into_iobuf(&self) -> iovec { +unsafe impl<'a> AsIobuf for IoSlice<'a> { + fn as_iobuf(&self) -> iovec { iovec { iov_base: self.as_ptr() as *mut c_void, iov_len: self.len(), } } - fn as_iobufs(bufs: &[Self]) -> &[iovec] { + fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] { // Safe because `IoSlice` is guaranteed to be ABI-compatible with `iovec`. unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) } } @@ -335,15 +389,15 @@ unsafe impl<'a> IntoIobuf for IoSlice<'a> { // Safe because there are no other references to the memory described by `IoSliceMut` and it is // guaranteed to be ABI-compatible with `iovec`. -unsafe impl<'a> IntoIobuf for IoSliceMut<'a> { - fn into_iobuf(&self) -> iovec { +unsafe impl<'a> AsIobuf for IoSliceMut<'a> { + fn as_iobuf(&self) -> iovec { iovec { iov_base: self.as_ptr() as *mut c_void, iov_len: self.len(), } } - fn as_iobufs(bufs: &[Self]) -> &[iovec] { + fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] { // Safe because `IoSliceMut` is guaranteed to be ABI-compatible with `iovec`. unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) } } @@ -351,12 +405,12 @@ unsafe impl<'a> IntoIobuf for IoSliceMut<'a> { // Safe because volatile slices are only ever accessed with other volatile interfaces and the // pointer and size are guaranteed to be accurate. 
-unsafe impl<'a> IntoIobuf for VolatileSlice<'a> { - fn into_iobuf(&self) -> iovec { +unsafe impl<'a> AsIobuf for VolatileSlice<'a> { + fn as_iobuf(&self) -> iovec { *self.as_iobuf() } - fn as_iobufs(bufs: &[Self]) -> &[iovec] { + fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] { VolatileSlice::as_iobufs(bufs) } } @@ -416,7 +470,8 @@ mod tests { fn send_recv_no_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let ioslice = IoSlice::new([1u8, 1, 2, 21, 34, 55].as_ref()); + let send_buf = [1u8, 1, 2, 21, 34, 55]; + let ioslice = IoSlice::new(&send_buf); let write_count = s1 .send_with_fds(&[ioslice], &[]) .expect("failed to send data"); @@ -432,6 +487,19 @@ mod tests { assert_eq!(read_count, 6); assert_eq!(file_count, 0); assert_eq!(buf, [1, 1, 2, 21, 34, 55]); + + let write_count = s1 + .send_bufs_with_fds(&[&send_buf[..]], &[]) + .expect("failed to send data"); + + assert_eq!(write_count, 6); + let (read_count, file_count) = s2 + .recv_with_fds(&mut buf[..], &mut files) + .expect("failed to recv data"); + + assert_eq!(read_count, 6); + assert_eq!(file_count, 0); + assert_eq!(buf, [1, 1, 2, 21, 34, 55]); } #[test] diff --git a/sys_util/src/vsock.rs b/sys_util/src/vsock.rs new file mode 100644 index 000000000..a51dfdbfc --- /dev/null +++ b/sys_util/src/vsock.rs @@ -0,0 +1,495 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Support for virtual sockets. +use std::fmt; +use std::io; +use std::mem::{self, size_of}; +use std::num::ParseIntError; +use std::os::raw::{c_uchar, c_uint, c_ushort}; +use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd}; +use std::result; +use std::str::FromStr; + +use libc::{ + self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, + VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, +}; + +// The domain for vsock sockets. 
+const AF_VSOCK: sa_family_t = 40; + +// Vsock loopback address. +const VMADDR_CID_LOCAL: c_uint = 1; + +/// Vsock equivalent of binding on port 0. Binds to a random port. +pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value(); + +// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly +// from linux/vm_sockets.h. +const PADDING: usize = size_of::<sockaddr>() + - size_of::<sa_family_t>() + - size_of::<c_ushort>() + - (2 * size_of::<c_uint>()); + +#[repr(C)] +#[derive(Default)] +struct sockaddr_vm { + svm_family: sa_family_t, + svm_reserved1: c_ushort, + svm_port: c_uint, + svm_cid: c_uint, + svm_zero: [c_uchar; PADDING], +} + +#[derive(Debug)] +pub struct AddrParseError; + +impl fmt::Display for AddrParseError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "failed to parse vsock address") + } +} + +/// The vsock equivalent of an IP address. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub enum VsockCid { + /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint. + Any, + /// An address that refers to the bare-metal machine that serves as the hypervisor. + Hypervisor, + /// The loopback address. + Local, + /// The parent machine. It may not be the hypervisor for nested VMs. + Host, + /// An assigned CID that serves as the address for VSOCK. 
+ Cid(c_uint), +} + +impl fmt::Display for VsockCid { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match &self { + VsockCid::Any => write!(fmt, "Any"), + VsockCid::Hypervisor => write!(fmt, "Hypervisor"), + VsockCid::Local => write!(fmt, "Local"), + VsockCid::Host => write!(fmt, "Host"), + VsockCid::Cid(c) => write!(fmt, "'{}'", c), + } + } +} + +impl From<c_uint> for VsockCid { + fn from(c: c_uint) -> Self { + match c { + VMADDR_CID_ANY => VsockCid::Any, + VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor, + VMADDR_CID_LOCAL => VsockCid::Local, + VMADDR_CID_HOST => VsockCid::Host, + _ => VsockCid::Cid(c), + } + } +} + +impl FromStr for VsockCid { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + let c: c_uint = s.parse()?; + Ok(c.into()) + } +} + +impl From<VsockCid> for c_uint { + fn from(cid: VsockCid) -> c_uint { + match cid { + VsockCid::Any => VMADDR_CID_ANY, + VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR, + VsockCid::Local => VMADDR_CID_LOCAL, + VsockCid::Host => VMADDR_CID_HOST, + VsockCid::Cid(c) => c, + } + } +} + +/// An address associated with a virtual socket. 
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub struct SocketAddr { + pub cid: VsockCid, + pub port: c_uint, +} + +pub trait ToSocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>; +} + +impl ToSocketAddr for SocketAddr { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + Ok(*self) + } +} + +impl ToSocketAddr for str { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + self.parse() + } +} + +impl ToSocketAddr for (VsockCid, c_uint) { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + let (cid, port) = *self; + Ok(SocketAddr { cid, port }) + } +} + +impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T { + fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> { + (**self).to_socket_addr() + } +} + +impl FromStr for SocketAddr { + type Err = AddrParseError; + + /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form + /// "vsock:cid:port". + fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> { + let components: Vec<&str> = s.split(':').collect(); + if components.len() != 3 || components[0] != "vsock" { + return Err(AddrParseError); + } + + Ok(SocketAddr { + cid: components[1].parse().map_err(|_| AddrParseError)?, + port: components[2].parse().map_err(|_| AddrParseError)?, + }) + } +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}:{}", self.cid, self.port) + } +} + +/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the +/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets. 
+unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> { + let flags = libc::fcntl(fd, F_GETFL, 0); + if flags < 0 { + return Err(io::Error::last_os_error()); + } + + let flags = if nonblocking { + flags | O_NONBLOCK + } else { + flags & !O_NONBLOCK + }; + + let ret = libc::fcntl(fd, F_SETFL, flags); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +/// A virtual socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either VsockStream or +/// VsockListener. +#[derive(Debug)] +pub struct VsockSocket { + fd: RawFd, +} + +impl VsockSocket { + pub fn new() -> io::Result<Self> { + let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(VsockSocket { fd }) + } + } + + pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + // The compiler should optimize this out since these are both compile-time constants. + assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>()); + + let svm = sockaddr_vm { + svm_family: AF_VSOCK, + svm_cid: sockaddr.cid.into(), + svm_port: sockaddr.port, + ..Default::default() + }; + + // Safe because this doesn't modify any memory and we check the return value. 
+ let ret = unsafe { + libc::bind( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + let svm = sockaddr_vm { + svm_family: AF_VSOCK, + svm_cid: sockaddr.cid.into(), + svm_port: sockaddr.port, + ..Default::default() + }; + + // Safe because this just connects a vsock socket, and the return value is checked. + let ret = unsafe { + libc::connect( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::<sockaddr_vm>() as socklen_t, + ) + }; + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + Ok(VsockStream { sock: self }) + } + } + + pub fn listen(self) -> io::Result<VsockListener> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + return Err(listen_err); + } + Ok(VsockListener { sock: self }) + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result<u32> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::<sockaddr_vm>() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. 
+ assert_eq!(addrlen as usize, size_of::<sockaddr_vm>()); + + Ok(svm.svm_port) + } + } + + pub fn try_clone(&self) -> io::Result<Self> { + // Safe because this doesn't modify any memory and we check the return value. + let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) }; + if dup_fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { fd: dup_fd }) + } + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + // Safe because the fd is valid and owned by this stream. + unsafe { set_nonblocking(self.fd, nonblocking) } + } +} + +impl IntoRawFd for VsockSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for VsockSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for VsockSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + +/// A virtual stream socket. +#[derive(Debug)] +pub struct VsockStream { + sock: VsockSocket, +} + +impl VsockStream { + pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> { + let sock = VsockSocket::new()?; + sock.connect(addr) + } + + /// Returns the port that this stream is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + pub fn try_clone(&self) -> io::Result<VsockStream> { + self.sock.try_clone().map(|f| VsockStream { sock: f }) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl io::Read for VsockStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { + // Safe because this will only modify the contents of |buf| and we check the return value. 
+ let ret = unsafe { + libc::read( + self.sock.as_raw_fd(), + buf as *mut [u8] as *mut c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } +} + +impl io::Write for VsockStream { + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::write( + self.sock.as_raw_fd(), + buf as *const [u8] as *const c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } + + fn flush(&mut self) -> io::Result<()> { + // No buffered data so nothing to do. + Ok(()) + } +} + +impl AsRawFd for VsockStream { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +impl IntoRawFd for VsockStream { + fn into_raw_fd(self) -> RawFd { + self.sock.into_raw_fd() + } +} + +/// Represents a virtual socket server. +#[derive(Debug)] +pub struct VsockListener { + sock: VsockSocket, +} + +impl VsockListener { + /// Creates a new `VsockListener` bound to the specified port on the current virtual socket + /// endpoint. + pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> { + let mut sock = VsockSocket::new()?; + sock.bind(addr)?; + sock.listen() + } + + /// Returns the port that this listener is bound to. + pub fn local_port(&self) -> io::Result<u32> { + self.sock.local_port() + } + + /// Accepts a new incoming connection on this listener. Blocks the calling thread until a + /// new connection is established. When established, returns the corresponding `VsockStream` + /// and the remote peer's address. + pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because this will only modify |svm| and we check the return value. 
+ let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t; + let fd = unsafe { + libc::accept4( + self.sock.as_raw_fd(), + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut socklen as *mut socklen_t, + libc::SOCK_CLOEXEC, + ) + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + if svm.svm_family != AF_VSOCK { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected address family: {}", svm.svm_family), + )); + } + + Ok(( + VsockStream { + sock: VsockSocket { fd }, + }, + SocketAddr { + cid: svm.svm_cid.into(), + port: svm.svm_port, + }, + )) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl AsRawFd for VsockListener { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} diff --git a/usb_util/src/device.rs b/usb_util/src/device.rs index b30d54139..47cad2130 100644 --- a/usb_util/src/device.rs +++ b/usb_util/src/device.rs @@ -7,7 +7,7 @@ use crate::{ ControlRequestDataPhaseTransferDirection, ControlRequestRecipient, ControlRequestType, DeviceDescriptor, DeviceDescriptorTree, Error, Result, StandardControlRequest, }; -use base::{handle_eintr_errno, IoctlNr}; +use base::{handle_eintr_errno, AsRawDescriptor, IoctlNr, RawDescriptor}; use data_model::vec_with_array_field; use libc::{EAGAIN, ENODEV, ENOENT}; use std::convert::TryInto; @@ -320,6 +320,12 @@ impl Device { } } +impl AsRawDescriptor for Device { + fn as_raw_descriptor(&self) -> RawDescriptor { + self.fd.as_raw_descriptor() + } +} + impl Transfer { fn urb(&self) -> &usb_sys::usbdevfs_urb { // self.urb is a Vec created with `vec_with_array_field`; the first entry is diff --git a/vfio_sys/src/lib.rs b/vfio_sys/src/lib.rs index 667feb553..b85233866 100644 --- a/vfio_sys/src/lib.rs +++ b/vfio_sys/src/lib.rs @@ -8,7 +8,9 @@ use base::ioctl_io_nr; +pub mod plat; pub mod vfio; +pub use crate::plat::*; pub use crate::vfio::*; ioctl_io_nr!(VFIO_GET_API_VERSION, 
VFIO_TYPE, VFIO_BASE); @@ -37,3 +39,9 @@ ioctl_io_nr!(VFIO_IOMMU_MAP_DMA, VFIO_TYPE, VFIO_BASE + 13); ioctl_io_nr!(VFIO_IOMMU_UNMAP_DMA, VFIO_TYPE, VFIO_BASE + 14); ioctl_io_nr!(VFIO_IOMMU_ENABLE, VFIO_TYPE, VFIO_BASE + 15); ioctl_io_nr!(VFIO_IOMMU_DISABLE, VFIO_TYPE, VFIO_BASE + 16); + +ioctl_io_nr!( + PLAT_IRQ_FORWARD_SET, + PLAT_IRQ_FORWARD_TYPE, + PLAT_IRQ_FORWARD_BASE +); diff --git a/vfio_sys/src/plat.rs b/vfio_sys/src/plat.rs new file mode 100644 index 000000000..7a2e633d1 --- /dev/null +++ b/vfio_sys/src/plat.rs @@ -0,0 +1,133 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/* + * automatically generated by rust-bindgen 0.56.0 + * bindgen --constified-enum '*' --with-derive-default --no-doc-comments --no-layout-tests + */ + +#[repr(C)] +#[derive(Default)] +pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]); +impl<T> __IncompleteArrayField<T> { + #[inline] + pub const fn new() -> Self { + __IncompleteArrayField(::std::marker::PhantomData, []) + } + #[inline] + pub fn as_ptr(&self) -> *const T { + self as *const _ as *const T + } + #[inline] + pub fn as_mut_ptr(&mut self) -> *mut T { + self as *mut _ as *mut T + } + #[inline] + pub unsafe fn as_slice(&self, len: usize) -> &[T] { + ::std::slice::from_raw_parts(self.as_ptr(), len) + } + #[inline] + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { + ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) + } +} +impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> { + fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + fmt.write_str("__IncompleteArrayField") + } +} +pub const _IOC_NRBITS: u32 = 8; +pub const _IOC_TYPEBITS: u32 = 8; +pub const _IOC_SIZEBITS: u32 = 14; +pub const _IOC_DIRBITS: u32 = 2; +pub const _IOC_NRMASK: u32 = 255; +pub const _IOC_TYPEMASK: u32 = 255; +pub const _IOC_SIZEMASK: u32 = 16383; +pub 
const _IOC_DIRMASK: u32 = 3; +pub const _IOC_NRSHIFT: u32 = 0; +pub const _IOC_TYPESHIFT: u32 = 8; +pub const _IOC_SIZESHIFT: u32 = 16; +pub const _IOC_DIRSHIFT: u32 = 30; +pub const _IOC_NONE: u32 = 0; +pub const _IOC_WRITE: u32 = 1; +pub const _IOC_READ: u32 = 2; +pub const IOC_IN: u32 = 1073741824; +pub const IOC_OUT: u32 = 2147483648; +pub const IOC_INOUT: u32 = 3221225472; +pub const IOCSIZE_MASK: u32 = 1073676288; +pub const IOCSIZE_SHIFT: u32 = 16; +pub const __BITS_PER_LONG: u32 = 64; +pub const __FD_SETSIZE: u32 = 1024; +pub const PLAT_IRQ_FORWARD_API_VERSION: u32 = 0; +pub const PLAT_IRQ_FORWARD_TYPE: u32 = 59; +pub const PLAT_IRQ_FORWARD_BASE: u32 = 100; +pub const PLAT_IRQ_FORWARD_SET_LEVEL_TRIGGER_EVENTFD: u32 = 1; +pub const PLAT_IRQ_FORWARD_SET_LEVEL_UNMASK_EVENTFD: u32 = 2; +pub const PLAT_IRQ_FORWARD_SET_EDGE_TRIGGER: u32 = 4; +pub type __s8 = ::std::os::raw::c_schar; +pub type __u8 = ::std::os::raw::c_uchar; +pub type __s16 = ::std::os::raw::c_short; +pub type __u16 = ::std::os::raw::c_ushort; +pub type __s32 = ::std::os::raw::c_int; +pub type __u32 = ::std::os::raw::c_uint; +pub type __s64 = ::std::os::raw::c_longlong; +pub type __u64 = ::std::os::raw::c_ulonglong; +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct __kernel_fd_set { + pub fds_bits: [::std::os::raw::c_ulong; 16usize], +} +pub type __kernel_sighandler_t = + ::std::option::Option<unsafe extern "C" fn(arg1: ::std::os::raw::c_int)>; +pub type __kernel_key_t = ::std::os::raw::c_int; +pub type __kernel_mqd_t = ::std::os::raw::c_int; +pub type __kernel_old_uid_t = ::std::os::raw::c_ushort; +pub type __kernel_old_gid_t = ::std::os::raw::c_ushort; +pub type __kernel_old_dev_t = ::std::os::raw::c_ulong; +pub type __kernel_long_t = ::std::os::raw::c_long; +pub type __kernel_ulong_t = ::std::os::raw::c_ulong; +pub type __kernel_ino_t = __kernel_ulong_t; +pub type __kernel_mode_t = ::std::os::raw::c_uint; +pub type __kernel_pid_t = ::std::os::raw::c_int; +pub type 
__kernel_ipc_pid_t = ::std::os::raw::c_int; +pub type __kernel_uid_t = ::std::os::raw::c_uint; +pub type __kernel_gid_t = ::std::os::raw::c_uint; +pub type __kernel_suseconds_t = __kernel_long_t; +pub type __kernel_daddr_t = ::std::os::raw::c_int; +pub type __kernel_uid32_t = ::std::os::raw::c_uint; +pub type __kernel_gid32_t = ::std::os::raw::c_uint; +pub type __kernel_size_t = __kernel_ulong_t; +pub type __kernel_ssize_t = __kernel_long_t; +pub type __kernel_ptrdiff_t = __kernel_long_t; +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct __kernel_fsid_t { + pub val: [::std::os::raw::c_int; 2usize], +} +pub type __kernel_off_t = __kernel_long_t; +pub type __kernel_loff_t = ::std::os::raw::c_longlong; +pub type __kernel_time_t = __kernel_long_t; +pub type __kernel_clock_t = __kernel_long_t; +pub type __kernel_timer_t = ::std::os::raw::c_int; +pub type __kernel_clockid_t = ::std::os::raw::c_int; +pub type __kernel_caddr_t = *mut ::std::os::raw::c_char; +pub type __kernel_uid16_t = ::std::os::raw::c_ushort; +pub type __kernel_gid16_t = ::std::os::raw::c_ushort; +pub type __le16 = __u16; +pub type __be16 = __u16; +pub type __le32 = __u32; +pub type __be32 = __u32; +pub type __le64 = __u64; +pub type __be64 = __u64; +pub type __sum16 = __u16; +pub type __wsum = __u32; +#[repr(C)] +#[derive(Debug, Default)] +pub struct plat_irq_forward_set { + pub argsz: __u32, + pub action_flags: __u32, + pub irq_number_host: __u32, + pub count: __u32, + pub eventfd: __IncompleteArrayField<__u8>, +} diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs index 1af9af6a9..618c5ac47 100644 --- a/vhost/src/lib.rs +++ b/vhost/src/lib.rs @@ -132,7 +132,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { let _ = self .mem() - .with_regions::<_, ()>(|index, guest_addr, size, host_addr, _| { + .with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| { vhost_regions[index] = virtio_sys::vhost_memory_region { guest_phys_addr: guest_addr.offset() as u64, memory_size: size 
as u64, @@ -341,7 +341,7 @@ mod tests { use crate::net::fakes::FakeNet; use net_util::fakes::FakeTap; - use std::result; + use std::{path::PathBuf, result}; use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError}; fn create_guest_memory() -> result::Result<GuestMemory, GuestMemoryError> { @@ -361,7 +361,7 @@ mod tests { fn create_fake_vhost_net() -> FakeNet<FakeTap> { let gm = create_guest_memory().unwrap(); - FakeNet::<FakeTap>::new(&gm).unwrap() + FakeNet::<FakeTap>::new(&PathBuf::from(""), &gm).unwrap() } #[test] diff --git a/vhost/src/net.rs b/vhost/src/net.rs index df6de95ce..d0e57e710 100644 --- a/vhost/src/net.rs +++ b/vhost/src/net.rs @@ -3,17 +3,18 @@ // found in the LICENSE file. use net_util::TapT; -use std::fs::{File, OpenOptions}; use std::marker::PhantomData; use std::os::unix::fs::OpenOptionsExt; +use std::{ + fs::{File, OpenOptions}, + path::PathBuf, +}; use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor}; use vm_memory::GuestMemory; use super::{ioctl_result, Error, Result, Vhost}; -static DEVICE: &str = "/dev/vhost-net"; - /// Handle to run VHOST_NET ioctls. /// /// This provides a simple wrapper around a VHOST_NET file descriptor and @@ -28,7 +29,7 @@ pub struct Net<T> { pub trait NetT<T: TapT>: Vhost + AsRawDescriptor + Send + Sized { /// Create a new NetT instance - fn new(mem: &GuestMemory) -> Result<Self>; + fn new(vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<Self>; /// Set the tap file descriptor that will serve as the VHOST_NET backend. /// This will start the vhost worker for the given queue. @@ -47,13 +48,13 @@ where /// /// # Arguments /// * `mem` - Guest memory mapping. 
- fn new(mem: &GuestMemory) -> Result<Net<T>> { + fn new(vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<Net<T>> { Ok(Net::<T> { descriptor: OpenOptions::new() .read(true) .write(true) .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) - .open(DEVICE) + .open(vhost_net_device_path) .map_err(Error::VhostOpen)?, mem: mem.clone(), phantom: PhantomData, @@ -117,7 +118,7 @@ pub mod fakes { where T: TapT, { - fn new(mem: &GuestMemory) -> Result<FakeNet<T>> { + fn new(_vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<FakeNet<T>> { Ok(FakeNet::<T> { descriptor: OpenOptions::new() .read(true) diff --git a/vhost/src/vsock.rs b/vhost/src/vsock.rs index 95fbb72e5..fb9795137 100644 --- a/vhost/src/vsock.rs +++ b/vhost/src/vsock.rs @@ -2,8 +2,11 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use std::fs::{File, OpenOptions}; use std::os::unix::fs::OpenOptionsExt; +use std::{ + fs::{File, OpenOptions}, + path::PathBuf, +}; use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor}; use virtio_sys::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; @@ -11,8 +14,6 @@ use vm_memory::GuestMemory; use super::{ioctl_result, Error, Result, Vhost}; -static DEVICE: &str = "/dev/vhost-vsock"; - /// Handle for running VHOST_VSOCK ioctls. pub struct Vsock { descriptor: File, @@ -21,13 +22,13 @@ pub struct Vsock { impl Vsock { /// Open a handle to a new VHOST_VSOCK instance. 
- pub fn new(mem: &GuestMemory) -> Result<Vsock> { + pub fn new(vhost_vsock_device_path: &PathBuf, mem: &GuestMemory) -> Result<Vsock> { Ok(Vsock { descriptor: OpenOptions::new() .read(true) .write(true) .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) - .open(DEVICE) + .open(vhost_vsock_device_path) .map_err(Error::VhostOpen)?, mem: mem.clone(), }) diff --git a/vm_control/Android.bp b/vm_control/Android.bp index 85cc738b0..18325c0b3 100644 --- a/vm_control/Android.bp +++ b/vm_control/Android.bp @@ -1,5 +1,4 @@ -// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb. -// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this +// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace. package { // See: http://go/android-license-faq @@ -17,17 +16,6 @@ rust_library { crate_name: "vm_control", srcs: ["src/lib.rs"], edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - ], - }, - }, rustlibs: [ "libbase_rust", "libdata_model", @@ -49,17 +37,6 @@ rust_defaults { test_suites: ["general-tests"], auto_gen_config: true, edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - ], - rustlibs: [ - "libgdbstub", - ], - }, - }, rustlibs: [ "libbase_rust", "libdata_model", @@ -76,6 +53,7 @@ rust_defaults { rust_test_host { name: "vm_control_host_test_src_lib", defaults: ["vm_control_defaults"], + shared_libs: ["libgfxstream_backend"], test_options: { unit_test: true, }, @@ -109,11 +87,8 @@ rust_test { // ../tempfile/src/lib.rs // ../vm_memory/src/lib.rs // async-task-4.0.3 "default,std" -// async-trait-0.1.48 -// autocfg-1.0.1 +// async-trait-0.1.45 // base-0.1.0 -// cfg-if-0.1.10 -// cfg-if-1.0.0 // downcast-rs-1.2.0 "default,std" // 
futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std" // futures-channel-0.3.13 "alloc,futures-sink,sink,std" @@ -124,12 +99,8 @@ rust_test { // futures-sink-0.3.13 "alloc,std" // futures-task-0.3.13 "alloc,std" // futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std" -// gdbstub-0.4.4 "alloc,default,std" -// libc-0.2.88 "default,std" -// log-0.4.14 -// managed-0.8.0 "alloc" +// libc-0.2.87 "default,std" // memchr-2.3.4 "default,std" -// num-traits-0.2.14 // paste-1.0.4 // pin-project-lite-0.2.6 // pin-utils-0.1.0 @@ -137,10 +108,10 @@ rust_test { // proc-macro-nested-0.1.7 // proc-macro2-1.0.24 "default,proc-macro" // quote-1.0.9 "default,proc-macro" -// serde-1.0.124 "default,derive,serde_derive,std" -// serde_derive-1.0.124 "default" +// serde-1.0.123 "default,derive,serde_derive,std" +// serde_derive-1.0.123 "default" // slab-0.4.2 -// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" +// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" // thiserror-1.0.24 // thiserror-impl-1.0.24 // unicode-xid-0.2.1 "default" diff --git a/vm_control/Cargo.toml b/vm_control/Cargo.toml index 3c59d9874..27b1066b1 100644 --- a/vm_control/Cargo.toml +++ b/vm_control/Cargo.toml @@ -8,13 +8,13 @@ edition = "2018" gdb = ["gdbstub"] [dependencies] +base = { path = "../base" } data_model = { path = "../data_model" } gdbstub = { version = "0.4.0", optional = true } hypervisor = { path = "../hypervisor" } libc = "*" -msg_socket = { path = "../msg_socket" } resources = { path = "../resources" } rutabaga_gfx = { path = "../rutabaga_gfx"} +serde = { version = "1", features = [ "derive" ] } sync = { path = "../sync" } -base = { path = "../base" } vm_memory = { path = "../vm_memory" } diff --git a/vm_control/src/client.rs b/vm_control/src/client.rs new file mode 100644 index 
000000000..3fff3e32f --- /dev/null +++ b/vm_control/src/client.rs @@ -0,0 +1,211 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +use crate::*; +use base::{info, net::UnixSeqpacket, validate_raw_descriptor, RawDescriptor, Tube}; + +use std::fs::OpenOptions; +use std::num::ParseIntError; +use std::path::{Path, PathBuf}; + +enum ModifyBatError { + BatControlErr(BatControlResult), +} + +impl fmt::Display for ModifyBatError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::ModifyBatError::*; + + match self { + BatControlErr(e) => write!(f, "{}", e), + } + } +} + +pub enum ModifyUsbError { + ArgMissing(&'static str), + ArgParse(&'static str, String), + ArgParseInt(&'static str, String, ParseIntError), + FailedDescriptorValidate(base::Error), + PathDoesNotExist(PathBuf), + SocketFailed, + UnexpectedResponse(VmResponse), + UnknownCommand(String), + UsbControl(UsbControlResult), +} + +impl std::fmt::Display for ModifyUsbError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + use self::ModifyUsbError::*; + + match self { + ArgMissing(a) => write!(f, "argument missing: {}", a), + ArgParse(name, value) => { + write!(f, "failed to parse argument {} value `{}`", name, value) + } + ArgParseInt(name, value, e) => write!( + f, + "failed to parse integer argument {} value `{}`: {}", + name, value, e + ), + FailedDescriptorValidate(e) => write!(f, "failed to validate file descriptor: {}", e), + PathDoesNotExist(p) => write!(f, "path `{}` does not exist", p.display()), + SocketFailed => write!(f, "socket failed"), + UnexpectedResponse(r) => write!(f, "unexpected response: {}", r), + UnknownCommand(c) => write!(f, "unknown command: `{}`", c), + UsbControl(e) => write!(f, "{}", e), + } + } +} + +pub type ModifyUsbResult<T> = std::result::Result<T, ModifyUsbError>; + +fn raw_descriptor_from_path(path: &Path) -> 
ModifyUsbResult<RawDescriptor> { + if !path.exists() { + return Err(ModifyUsbError::PathDoesNotExist(path.to_owned())); + } + let raw_descriptor = path + .file_name() + .and_then(|fd_osstr| fd_osstr.to_str()) + .map_or( + Err(ModifyUsbError::ArgParse( + "USB_DEVICE_PATH", + path.to_string_lossy().into_owned(), + )), + |fd_str| { + fd_str.parse::<libc::c_int>().map_err(|e| { + ModifyUsbError::ArgParseInt("USB_DEVICE_PATH", fd_str.to_owned(), e) + }) + }, + )?; + validate_raw_descriptor(raw_descriptor).map_err(ModifyUsbError::FailedDescriptorValidate) +} + +pub type VmsRequestResult = std::result::Result<(), ()>; + +pub fn vms_request(request: &VmRequest, socket_path: &Path) -> VmsRequestResult { + let response = handle_request(request, socket_path)?; + info!("request response was {}", response); + Ok(()) +} + +pub fn do_usb_attach( + socket_path: &Path, + bus: u8, + addr: u8, + vid: u16, + pid: u16, + dev_path: &Path, +) -> ModifyUsbResult<UsbControlResult> { + let usb_file: File = if dev_path.parent() == Some(Path::new("/proc/self/fd")) { + // Special case '/proc/self/fd/*' paths. The FD is already open, just use it. + // Safe because we will validate |raw_fd|. + unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&dev_path)?) } + } else { + OpenOptions::new() + .read(true) + .write(true) + .open(&dev_path) + .map_err(|_| ModifyUsbError::UsbControl(UsbControlResult::FailedToOpenDevice))? 
+ }; + + let request = VmRequest::UsbCommand(UsbControlCommand::AttachDevice { + bus, + addr, + vid, + pid, + file: usb_file, + }); + let response = + handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?; + match response { + VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), + r => Err(ModifyUsbError::UnexpectedResponse(r)), + } +} + +pub fn do_usb_detach(socket_path: &Path, port: u8) -> ModifyUsbResult<UsbControlResult> { + let request = VmRequest::UsbCommand(UsbControlCommand::DetachDevice { port }); + let response = + handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?; + match response { + VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), + r => Err(ModifyUsbError::UnexpectedResponse(r)), + } +} + +pub fn do_usb_list(socket_path: &Path) -> ModifyUsbResult<UsbControlResult> { + let mut ports: [u8; USB_CONTROL_MAX_PORTS] = Default::default(); + for (index, port) in ports.iter_mut().enumerate() { + *port = index as u8 + } + let request = VmRequest::UsbCommand(UsbControlCommand::ListDevice { ports }); + let response = + handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?; + match response { + VmResponse::UsbResponse(usb_resp) => Ok(usb_resp), + r => Err(ModifyUsbError::UnexpectedResponse(r)), + } +} + +pub type DoModifyBatteryResult = std::result::Result<(), ()>; + +pub fn do_modify_battery( + socket_path: &Path, + battery_type: &str, + property: &str, + target: &str, +) -> DoModifyBatteryResult { + let response = match battery_type.parse::<BatteryType>() { + Ok(type_) => match BatControlCommand::new(property.to_string(), target.to_string()) { + Ok(cmd) => { + let request = VmRequest::BatCommand(type_, cmd); + Ok(handle_request(&request, socket_path)?) 
+ } + Err(e) => Err(ModifyBatError::BatControlErr(e)), + }, + Err(e) => Err(ModifyBatError::BatControlErr(e)), + }; + + match response { + Ok(response) => { + println!("{}", response); + Ok(()) + } + Err(e) => { + println!("error {}", e); + Err(()) + } + } +} + +pub type HandleRequestResult = std::result::Result<VmResponse, ()>; + +pub fn handle_request(request: &VmRequest, socket_path: &Path) -> HandleRequestResult { + match UnixSeqpacket::connect(&socket_path) { + Ok(s) => { + let socket = Tube::new(s); + if let Err(e) = socket.send(request) { + error!( + "failed to send request to socket at '{:?}': {}", + socket_path, e + ); + return Err(()); + } + match socket.recv() { + Ok(response) => Ok(response), + Err(e) => { + error!( + "failed to send request to socket at '{:?}': {}", + socket_path, e + ); + Err(()) + } + } + } + Err(e) => { + error!("failed to connect to socket at '{:?}': {}", socket_path, e); + Err(()) + } + } +} diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index ce236284e..3222edf09 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -13,37 +13,41 @@ #[cfg(all(target_arch = "x86_64", feature = "gdb"))] pub mod gdb; +pub mod client; + use std::fmt::{self, Display}; use std::fs::File; -use std::mem::ManuallyDrop; use std::os::raw::c_int; use std::result::Result as StdResult; use std::str::FromStr; use std::sync::Arc; use libc::{EINVAL, EIO, ENODEV}; +use serde::{Deserialize, Serialize}; use base::{ - error, AsRawDescriptor, Error as SysError, Event, ExternalMapping, Fd, FromRawDescriptor, - IntoRawDescriptor, MappedRegion, MemoryMappingArena, MemoryMappingBuilder, MmapError, - Protection, RawDescriptor, Result, SafeDescriptor, + error, with_as_descriptor, AsRawDescriptor, Error as SysError, Event, ExternalMapping, Fd, + FromRawDescriptor, IntoRawDescriptor, MappedRegion, MemoryMappingArena, MemoryMappingBuilder, + MemoryMappingBuilderUnix, MmapError, Protection, Result, SafeDescriptor, SharedMemory, Tube, }; use 
hypervisor::{IrqRoute, IrqSource, Vm}; -use msg_socket::{MsgError, MsgOnSocket, MsgReceiver, MsgResult, MsgSender, MsgSocket}; use resources::{Alloc, MmioType, SystemAllocator}; -use rutabaga_gfx::{DrmFormat, ImageAllocationInfo, RutabagaGralloc, RutabagaGrallocFlags}; +use rutabaga_gfx::{ + DrmFormat, ImageAllocationInfo, RutabagaGralloc, RutabagaGrallocFlags, RutabagaHandle, + VulkanInfo, +}; use sync::Mutex; use vm_memory::GuestAddress; /// Struct that describes the offset and stride of a plane located in GPU memory. -#[derive(Clone, Copy, Debug, PartialEq, Default, MsgOnSocket)] +#[derive(Clone, Copy, Debug, PartialEq, Default, Serialize, Deserialize)] pub struct GpuMemoryPlaneDesc { pub stride: u32, pub offset: u32, } /// Struct that describes a GPU memory allocation that consists of up to 3 planes. -#[derive(Clone, Copy, Debug, Default, MsgOnSocket)] +#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct GpuMemoryDesc { pub planes: [GpuMemoryPlaneDesc; 3], } @@ -60,66 +64,6 @@ pub enum VcpuControl { RunState(VmRunMode), } -/// A file descriptor either borrowed or owned by this. -#[derive(Debug)] -pub enum MaybeOwnedDescriptor { - /// Owned by this enum variant, and will be destructed automatically if not moved out. - Owned(SafeDescriptor), - /// A file descriptor borrwed by this enum. - Borrowed(RawDescriptor), -} - -impl AsRawDescriptor for MaybeOwnedDescriptor { - fn as_raw_descriptor(&self) -> RawDescriptor { - match self { - MaybeOwnedDescriptor::Owned(f) => f.as_raw_descriptor(), - MaybeOwnedDescriptor::Borrowed(descriptor) => *descriptor, - } - } -} - -impl AsRawDescriptor for &MaybeOwnedDescriptor { - fn as_raw_descriptor(&self) -> RawDescriptor { - match self { - MaybeOwnedDescriptor::Owned(f) => f.as_raw_descriptor(), - MaybeOwnedDescriptor::Borrowed(descriptor) => *descriptor, - } - } -} - -// When sent, it could be owned or borrowed. On the receiver end, it always owned. 
-impl MsgOnSocket for MaybeOwnedDescriptor { - fn uses_descriptor() -> bool { - true - } - fn fixed_size() -> Option<usize> { - Some(0) - } - fn descriptor_count(&self) -> usize { - 1usize - } - unsafe fn read_from_buffer( - buffer: &[u8], - descriptors: &[RawDescriptor], - ) -> MsgResult<(Self, usize)> { - let (file, size) = File::read_from_buffer(buffer, descriptors)?; - let safe_descriptor = SafeDescriptor::from_raw_descriptor(file.into_raw_descriptor()); - Ok((MaybeOwnedDescriptor::Owned(safe_descriptor), size)) - } - fn write_to_buffer( - &self, - _buffer: &mut [u8], - descriptors: &mut [RawDescriptor], - ) -> MsgResult<usize> { - if descriptors.is_empty() { - return Err(MsgError::WrongDescriptorBufferSize); - } - - descriptors[0] = self.as_raw_descriptor(); - Ok(1) - } -} - /// Mode of execution for the VM. #[derive(Debug, Clone, PartialEq)] pub enum VmRunMode { @@ -159,7 +103,7 @@ impl Default for VmRunMode { /// require adding a big dependency for a single const. pub const USB_CONTROL_MAX_PORTS: usize = 16; -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BalloonControlCommand { /// Set the size of the VM's balloon. Adjust { @@ -169,7 +113,7 @@ pub enum BalloonControlCommand { } // BalloonStats holds stats returned from the stats_queue. -#[derive(Default, MsgOnSocket, Debug)] +#[derive(Default, Serialize, Deserialize, Debug)] pub struct BalloonStats { pub swap_in: Option<u64>, pub swap_out: Option<u64>, @@ -220,7 +164,7 @@ impl Display for BalloonStats { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BalloonControlResult { Stats { stats: BalloonStats, @@ -228,7 +172,7 @@ pub enum BalloonControlResult { }, } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum DiskControlCommand { /// Resize a disk to `new_size` in bytes. 
Resize { new_size: u64 }, @@ -244,20 +188,21 @@ impl Display for DiskControlCommand { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum DiskControlResult { Ok, Err(SysError), } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum UsbControlCommand { AttachDevice { bus: u8, addr: u8, vid: u16, pid: u16, - descriptor: Option<MaybeOwnedDescriptor>, + #[serde(with = "with_as_descriptor")] + file: File, }, DetachDevice { port: u8, @@ -267,7 +212,7 @@ pub enum UsbControlCommand { }, } -#[derive(MsgOnSocket, Copy, Clone, Debug, Default)] +#[derive(Serialize, Deserialize, Copy, Clone, Debug, Default)] pub struct UsbControlAttachedDevice { pub port: u8, pub vendor_id: u16, @@ -275,12 +220,12 @@ pub struct UsbControlAttachedDevice { } impl UsbControlAttachedDevice { - fn valid(self) -> bool { + pub fn valid(self) -> bool { self.port != 0 } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum UsbControlResult { Ok { port: u8 }, NoAvailablePort, @@ -295,7 +240,7 @@ impl Display for UsbControlResult { use self::UsbControlResult::*; match self { - Ok { port } => write!(f, "ok {}", port), + UsbControlResult::Ok { port } => write!(f, "ok {}", port), NoAvailablePort => write!(f, "no_available_port"), NoSuchDevice => write!(f, "no_such_device"), NoSuchPort => write!(f, "no_such_port"), @@ -311,16 +256,27 @@ impl Display for UsbControlResult { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize)] pub enum VmMemoryRequest { /// Register shared memory represented by the given descriptor into guest address space. /// The response variant is `VmResponse::RegisterMemory`. - RegisterMemory(MaybeOwnedDescriptor, usize), + RegisterMemory(SharedMemory), /// Similiar to `VmMemoryRequest::RegisterMemory`, but doesn't allocate new address space. /// Useful for cases where the address space is already allocated (PCI regions). 
- RegisterFdAtPciBarOffset(Alloc, MaybeOwnedDescriptor, usize, u64), + RegisterFdAtPciBarOffset(Alloc, SafeDescriptor, usize, u64), /// Similar to RegisterFdAtPciBarOffset, but is for buffers in the current address space. RegisterHostPointerAtPciBarOffset(Alloc, u64), + /// Similiar to `RegisterFdAtPciBarOffset`, but uses Vulkano to map the resource instead of + /// the mmap system call. + RegisterVulkanMemoryAtPciBarOffset { + alloc: Alloc, + descriptor: SafeDescriptor, + handle_type: u32, + memory_idx: u32, + physical_device_idx: u32, + offset: u64, + size: u64, + }, /// Unregister the given memory slot that was previously registered with `RegisterMemory*`. UnregisterMemory(MemSlot), /// Allocate GPU buffer of a given size/format and register the memory into guest address space. @@ -332,7 +288,7 @@ pub enum VmMemoryRequest { }, /// Register mmaped memory into the hypervisor's EPT. RegisterMmapMemory { - descriptor: MaybeOwnedDescriptor, + descriptor: SafeDescriptor, size: usize, offset: u64, gpa: u64, @@ -350,16 +306,16 @@ impl VmMemoryRequest { /// `VmMemoryResponse` with the intended purpose of sending the response back over the socket /// that received this `VmMemoryResponse`. 
pub fn execute( - &self, + self, vm: &mut impl Vm, sys_allocator: &mut SystemAllocator, map_request: Arc<Mutex<Option<ExternalMapping>>>, gralloc: &mut RutabagaGralloc, ) -> VmMemoryResponse { use self::VmMemoryRequest::*; - match *self { - RegisterMemory(ref descriptor, size) => { - match register_memory(vm, sys_allocator, descriptor, size, None) { + match self { + RegisterMemory(ref shm) => { + match register_memory(vm, sys_allocator, shm, shm.size() as usize, None) { Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot }, Err(e) => VmMemoryResponse::Err(e), } @@ -381,7 +337,39 @@ impl VmMemoryRequest { .ok_or_else(|| VmMemoryResponse::Err(SysError::new(EINVAL))) .unwrap(); - match register_memory_hva(vm, sys_allocator, Box::new(mem), (alloc, offset)) { + match register_host_pointer(vm, sys_allocator, Box::new(mem), (alloc, offset)) { + Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot }, + Err(e) => VmMemoryResponse::Err(e), + } + } + RegisterVulkanMemoryAtPciBarOffset { + alloc, + descriptor, + handle_type, + memory_idx, + physical_device_idx, + offset, + size, + } => { + let mapped_region = match gralloc.import_and_map( + RutabagaHandle { + os_handle: descriptor, + handle_type, + }, + VulkanInfo { + memory_idx, + physical_device_idx, + }, + size, + ) { + Ok(mapped_region) => mapped_region, + Err(e) => { + error!("gralloc failed to import and map: {}", e); + return VmMemoryResponse::Err(SysError::new(EINVAL)); + } + }; + + match register_host_pointer(vm, sys_allocator, mapped_region, (alloc, offset)) { Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot }, Err(e) => VmMemoryResponse::Err(e), } @@ -438,11 +426,11 @@ impl VmMemoryRequest { Ok((pfn, slot)) => VmMemoryResponse::AllocateAndRegisterGpuMemory { // Safe because ownership is transferred to SafeDescriptor via // into_raw_descriptor - descriptor: MaybeOwnedDescriptor::Owned(unsafe { + descriptor: unsafe { SafeDescriptor::from_raw_descriptor( 
handle.os_handle.into_raw_descriptor(), ) - }), + }, pfn, slot, desc, @@ -473,7 +461,7 @@ impl VmMemoryRequest { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmMemoryResponse { /// The request to register memory into guest address space was successfully done at page frame /// number `pfn` and memory slot number `slot`. @@ -484,7 +472,7 @@ pub enum VmMemoryResponse { /// The request to allocate and register GPU memory into guest address space was successfully /// done at page frame number `pfn` and memory slot number `slot` for buffer with `desc`. AllocateAndRegisterGpuMemory { - descriptor: MaybeOwnedDescriptor, + descriptor: SafeDescriptor, pfn: u64, slot: MemSlot, desc: GpuMemoryDesc, @@ -493,10 +481,10 @@ pub enum VmMemoryResponse { Err(SysError), } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmIrqRequest { /// Allocate one gsi, and associate gsi to irqfd with register_irqfd() - AllocateOneMsi { irqfd: MaybeOwnedDescriptor }, + AllocateOneMsi { irqfd: Event }, /// Add one msi route entry into the IRQ chip. AddMsiRoute { gsi: u32, @@ -530,19 +518,7 @@ impl VmIrqRequest { match *self { AllocateOneMsi { ref irqfd } => { if let Some(irq_num) = sys_allocator.allocate_irq() { - // Because of the limitation of `MaybeOwnedDescriptor` not fitting into - // `register_irqfd` which expects an `&Event`, we use the unsafe `from_raw_fd` - // to assume that the descriptor given is an `Event`, and we ignore the - // ownership question using `ManuallyDrop`. This is safe because `ManuallyDrop` - // prevents any Drop implementation from triggering on `irqfd` which already has - // an owner, and the `Event` methods are never called. The underlying descriptor - // is merely passed to the kernel which doesn't care about ownership and deals - // with incorrect FDs, in the case of bugs on our part. 
- let evt = unsafe { - ManuallyDrop::new(Event::from_raw_descriptor(irqfd.as_raw_descriptor())) - }; - - match set_up_irq(IrqSetup::Event(irq_num, &evt)) { + match set_up_irq(IrqSetup::Event(irq_num, &irqfd)) { Ok(_) => VmIrqResponse::AllocateOneMsi { gsi: irq_num }, Err(e) => VmIrqResponse::Err(e), } @@ -571,14 +547,14 @@ impl VmIrqRequest { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmIrqResponse { AllocateOneMsi { gsi: u32 }, Ok, Err(SysError), } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmMsyncRequest { /// Flush the content of a memory mapping to its backing file. /// `slot` selects the arena (as returned by `Vm::add_mmap_arena`). @@ -591,7 +567,7 @@ pub enum VmMsyncRequest { }, } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmMsyncResponse { Ok, Err(SysError), @@ -617,7 +593,7 @@ impl VmMsyncRequest { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BatControlResult { Ok, NoBatDevice, @@ -644,7 +620,7 @@ impl Display for BatControlResult { } } -#[derive(MsgOnSocket, Copy, Clone, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq)] pub enum BatteryType { Goldfish, } @@ -666,7 +642,7 @@ impl FromStr for BatteryType { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BatProperty { Status, Health, @@ -690,7 +666,7 @@ impl FromStr for BatProperty { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BatStatus { Unknown, Charging, @@ -733,7 +709,7 @@ impl From<BatStatus> for u32 { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BatHealth { Unknown, Good, @@ -773,7 +749,7 @@ impl From<BatHealth> for u32 { } } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum BatControlCommand { SetStatus(BatStatus), SetHealth(BatHealth), @@ -810,10 +786,10 @@ 
impl BatControlCommand { /// Used for VM to control battery properties. pub struct BatControl { pub type_: BatteryType, - pub control_socket: BatControlRequestSocket, + pub control_tube: Tube, } -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum FsMappingRequest { /// Create an anonymous memory mapping that spans the entire region described by `Alloc`. AllocateSharedMemoryRegion(Alloc), @@ -823,7 +799,7 @@ pub enum FsMappingRequest { /// `AllocateSharedMemoryRegion` request. slot: u32, /// The file descriptor that should be mapped. - fd: MaybeOwnedDescriptor, + fd: SafeDescriptor, /// The size of the mapping. size: usize, /// The offset into the file from where the mapping should start. @@ -918,37 +894,10 @@ impl FsMappingRequest { } } } - -pub type BalloonControlRequestSocket = MsgSocket<BalloonControlCommand, BalloonControlResult>; -pub type BalloonControlResponseSocket = MsgSocket<BalloonControlResult, BalloonControlCommand>; - -pub type BatControlRequestSocket = MsgSocket<BatControlCommand, BatControlResult>; -pub type BatControlResponseSocket = MsgSocket<BatControlResult, BatControlCommand>; - -pub type DiskControlRequestSocket = MsgSocket<DiskControlCommand, DiskControlResult>; -pub type DiskControlResponseSocket = MsgSocket<DiskControlResult, DiskControlCommand>; - -pub type FsMappingRequestSocket = MsgSocket<FsMappingRequest, VmResponse>; -pub type FsMappingResponseSocket = MsgSocket<VmResponse, FsMappingRequest>; - -pub type UsbControlSocket = MsgSocket<UsbControlCommand, UsbControlResult>; - -pub type VmMemoryControlRequestSocket = MsgSocket<VmMemoryRequest, VmMemoryResponse>; -pub type VmMemoryControlResponseSocket = MsgSocket<VmMemoryResponse, VmMemoryRequest>; - -pub type VmIrqRequestSocket = MsgSocket<VmIrqRequest, VmIrqResponse>; -pub type VmIrqResponseSocket = MsgSocket<VmIrqResponse, VmIrqRequest>; - -pub type VmMsyncRequestSocket = MsgSocket<VmMsyncRequest, VmMsyncResponse>; -pub type VmMsyncResponseSocket = 
MsgSocket<VmMsyncResponse, VmMsyncRequest>; - -pub type VmControlRequestSocket = MsgSocket<VmRequest, VmResponse>; -pub type VmControlResponseSocket = MsgSocket<VmResponse, VmRequest>; - /// A request to the main process to perform some operation on the VM. /// /// Unless otherwise noted, each request should expect a `VmResponse::Ok` to be received on success. -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmRequest { /// Break the VM's run loop and exit. Exit, @@ -1005,7 +954,7 @@ fn register_memory( Ok((addr >> 12, slot)) } -fn register_memory_hva( +fn register_host_pointer( vm: &mut impl Vm, allocator: &mut SystemAllocator, mem: Box<dyn MappedRegion>, @@ -1029,9 +978,9 @@ impl VmRequest { pub fn execute( &self, run_mode: &mut Option<VmRunMode>, - balloon_host_socket: &BalloonControlRequestSocket, - disk_host_sockets: &[DiskControlRequestSocket], - usb_control_socket: &UsbControlSocket, + balloon_host_tube: &Tube, + disk_host_tubes: &[Tube], + usb_control_tube: &Tube, bat_control: &mut Option<BatControl>, ) -> VmResponse { match *self { @@ -1048,14 +997,14 @@ impl VmRequest { VmResponse::Ok } VmRequest::BalloonCommand(BalloonControlCommand::Adjust { num_bytes }) => { - match balloon_host_socket.send(&BalloonControlCommand::Adjust { num_bytes }) { + match balloon_host_tube.send(&BalloonControlCommand::Adjust { num_bytes }) { Ok(_) => VmResponse::Ok, Err(_) => VmResponse::Err(SysError::last()), } } VmRequest::BalloonCommand(BalloonControlCommand::Stats) => { - match balloon_host_socket.send(&BalloonControlCommand::Stats {}) { - Ok(_) => match balloon_host_socket.recv() { + match balloon_host_tube.send(&BalloonControlCommand::Stats {}) { + Ok(_) => match balloon_host_tube.recv() { Ok(BalloonControlResult::Stats { stats, balloon_actual, @@ -1076,7 +1025,7 @@ impl VmRequest { ref command, } => { // Forward the request to the block device process via its control socket. 
- if let Some(sock) = disk_host_sockets.get(disk_index) { + if let Some(sock) = disk_host_tubes.get(disk_index) { if let Err(e) = sock.send(command) { error!("disk socket send failed: {}", e); VmResponse::Err(SysError::new(EINVAL)) @@ -1095,12 +1044,12 @@ impl VmRequest { } } VmRequest::UsbCommand(ref cmd) => { - let res = usb_control_socket.send(cmd); + let res = usb_control_tube.send(cmd); if let Err(e) = res { error!("fail to send command to usb control socket: {}", e); return VmResponse::Err(SysError::new(EIO)); } - match usb_control_socket.recv() { + match usb_control_tube.recv() { Ok(response) => VmResponse::UsbResponse(response), Err(e) => { error!("fail to recv command from usb control socket: {}", e); @@ -1116,13 +1065,13 @@ impl VmRequest { return VmResponse::Err(SysError::new(EINVAL)); } - let res = battery.control_socket.send(cmd); + let res = battery.control_tube.send(cmd); if let Err(e) = res { error!("fail to send command to bat control socket: {}", e); return VmResponse::Err(SysError::new(EIO)); } - match battery.control_socket.recv() { + match battery.control_tube.recv() { Ok(response) => VmResponse::BatResponse(response), Err(e) => { error!("fail to recv command from bat control socket: {}", e); @@ -1140,7 +1089,7 @@ impl VmRequest { /// Indication of success or failure of a `VmRequest`. /// /// Success is usually indicated `VmResponse::Ok` unless there is data associated with the response. -#[derive(MsgOnSocket, Debug)] +#[derive(Serialize, Deserialize, Debug)] pub enum VmResponse { /// Indicates the request was executed successfully. Ok, @@ -1152,7 +1101,7 @@ pub enum VmResponse { /// The request to allocate and register GPU memory into guest address space was successfully /// done at page frame number `pfn` and memory slot number `slot` for buffer with `desc`. 
AllocateAndRegisterGpuMemory { - descriptor: MaybeOwnedDescriptor, + descriptor: SafeDescriptor, pfn: u64, slot: u32, desc: GpuMemoryDesc, @@ -1185,7 +1134,7 @@ impl Display for VmResponse { "gpu memory allocated and registered to page frame number {:#x} and memory slot {}", pfn, slot ), - BalloonStats { + VmResponse::BalloonStats { stats, balloon_actual, } => write!( diff --git a/vm_memory/Cargo.toml b/vm_memory/Cargo.toml index df599aacc..631ad9b84 100644 --- a/vm_memory/Cargo.toml +++ b/vm_memory/Cargo.toml @@ -10,6 +10,6 @@ cros_async = { path = "../cros_async" } # provided by ebuild data_model = { path = "../data_model" } # provided by ebuild libc = "*" base = { path = "../base" } # provided by ebuild -syscall_defines = { path = "../syscall_defines" } # provided by ebuild +bitflags = "1" [workspace] diff --git a/vm_memory/src/guest_memory.rs b/vm_memory/src/guest_memory.rs index b160968b8..62ba134e7 100644 --- a/vm_memory/src/guest_memory.rs +++ b/vm_memory/src/guest_memory.rs @@ -14,17 +14,21 @@ use std::sync::Arc; use crate::guest_address::GuestAddress; use base::{pagesize, Error as SysError}; use base::{ - AsRawDescriptor, MappedRegion, MemfdSeals, MemoryMapping, MemoryMappingBuilder, - MemoryMappingUnix, MmapError, RawDescriptor, SharedMemory, SharedMemoryUnix, + AsRawDescriptor, AsRawDescriptors, MappedRegion, MemfdSeals, MemoryMapping, + MemoryMappingBuilder, MemoryMappingUnix, MmapError, RawDescriptor, SharedMemory, + SharedMemoryUnix, }; use cros_async::{mem, BackingMemory}; use data_model::volatile_memory::*; use data_model::DataInit; +use bitflags::bitflags; + #[derive(Debug)] pub enum Error { DescriptorChainOverflow, InvalidGuestAddress(GuestAddress), + InvalidOffset(u64), MemoryAccess(GuestAddress, MmapError), MemoryMappingFailed(MmapError), MemoryRegionOverlap, @@ -51,6 +55,7 @@ impl Display for Error { "the combined length of all the buffers in a DescriptorChain is too large" ), InvalidGuestAddress(addr) => write!(f, "invalid guest address {}", 
addr), + InvalidOffset(addr) => write!(f, "invalid offset {}", addr), MemoryAccess(addr, e) => { write!(f, "invalid guest memory access at addr={}: {}", addr, e) } @@ -82,10 +87,17 @@ impl Display for Error { } } +bitflags! { + pub struct MemoryPolicy: u32 { + const USE_HUGEPAGES = 1; + } +} + struct MemoryRegion { mapping: MemoryMapping, guest_base: GuestAddress, - memfd_offset: u64, + shm_offset: u64, + shm: Arc<SharedMemory>, } impl MemoryRegion { @@ -108,24 +120,20 @@ impl MemoryRegion { #[derive(Clone)] pub struct GuestMemory { regions: Arc<[MemoryRegion]>, - shm: Arc<SharedMemory>, -} - -impl AsRawDescriptor for GuestMemory { - fn as_raw_descriptor(&self) -> RawDescriptor { - self.shm.as_raw_descriptor() - } } -impl AsRef<SharedMemory> for GuestMemory { - fn as_ref(&self) -> &SharedMemory { - &self.shm +impl AsRawDescriptors for GuestMemory { + fn as_raw_descriptors(&self) -> Vec<RawDescriptor> { + self.regions + .iter() + .map(|r| r.shm.as_raw_descriptor()) + .collect() } } impl GuestMemory { /// Creates backing shm for GuestMemory regions - fn create_memfd(ranges: &[(GuestAddress, u64)]) -> Result<SharedMemory> { + fn create_shm(ranges: &[(GuestAddress, u64)]) -> Result<SharedMemory> { let mut aligned_size = 0; let pg_size = pagesize(); for range in ranges { @@ -154,7 +162,7 @@ impl GuestMemory { pub fn new(ranges: &[(GuestAddress, u64)]) -> Result<GuestMemory> { // Create shm - let shm = GuestMemory::create_memfd(ranges)?; + let shm = Arc::new(GuestMemory::create_shm(ranges)?); // Create memory regions let mut regions = Vec::<MemoryRegion>::new(); let mut offset = 0; @@ -173,14 +181,15 @@ impl GuestMemory { let size = usize::try_from(range.1).map_err(|_| Error::MemoryRegionTooLarge(range.1))?; let mapping = MemoryMappingBuilder::new(size) - .from_descriptor(&shm) + .from_shared_memory(shm.as_ref()) .offset(offset) .build() .map_err(Error::MemoryMappingFailed)?; regions.push(MemoryRegion { mapping, guest_base: range.0, - memfd_offset: offset, + shm_offset: 
offset, + shm: Arc::clone(&shm), }); offset += size as u64; @@ -188,7 +197,6 @@ impl GuestMemory { Ok(GuestMemory { regions: Arc::from(regions), - shm: Arc::new(shm), }) } @@ -252,13 +260,27 @@ impl GuestMemory { /// Madvise away the address range in the host that is associated with the given guest range. pub fn remove_range(&self, addr: GuestAddress, count: u64) -> Result<()> { - self.do_in_region(addr, move |mapping, offset| { + self.do_in_region(addr, move |mapping, offset, _| { mapping .remove_range(offset, count as usize) .map_err(|e| Error::MemoryAccess(addr, e)) }) } + /// Handles guest memory policy hints/advices. + pub fn set_memory_policy(&self, mem_policy: MemoryPolicy) { + if mem_policy.contains(MemoryPolicy::USE_HUGEPAGES) { + for (_, region) in self.regions.iter().enumerate() { + let ret = region.mapping.use_hugepages(); + + match ret { + Err(err) => println!("Failed to enable HUGEPAGE for mapping {}", err), + Ok(_) => (), + } + } + } + } + /// Perform the specified action on each region's addresses. 
/// /// Callback is called with arguments: @@ -266,10 +288,11 @@ impl GuestMemory { /// * guest_addr : GuestAddress /// * size: usize /// * host_addr: usize - /// * memfd_offset: usize + /// * shm: SharedMemory backing for the given region + /// * shm_offset: usize pub fn with_regions<F, E>(&self, mut cb: F) -> result::Result<(), E> where - F: FnMut(usize, GuestAddress, usize, usize, u64) -> result::Result<(), E>, + F: FnMut(usize, GuestAddress, usize, usize, &SharedMemory, u64) -> result::Result<(), E>, { for (index, region) in self.regions.iter().enumerate() { cb( @@ -277,7 +300,8 @@ impl GuestMemory { region.start(), region.mapping.size(), region.mapping.as_ptr() as usize, - region.memfd_offset, + region.shm.as_ref(), + region.shm_offset, )?; } Ok(()) @@ -303,7 +327,7 @@ impl GuestMemory { /// # } /// ``` pub fn write_at_addr(&self, buf: &[u8], guest_addr: GuestAddress) -> Result<usize> { - self.do_in_region(guest_addr, move |mapping, offset| { + self.do_in_region(guest_addr, move |mapping, offset, _| { mapping .write_slice(buf, offset) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -362,7 +386,7 @@ impl GuestMemory { /// # } /// ``` pub fn read_at_addr(&self, buf: &mut [u8], guest_addr: GuestAddress) -> Result<usize> { - self.do_in_region(guest_addr, move |mapping, offset| { + self.do_in_region(guest_addr, move |mapping, offset, _| { mapping .read_slice(buf, offset) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -422,7 +446,7 @@ impl GuestMemory { /// # } /// ``` pub fn read_obj_from_addr<T: DataInit>(&self, guest_addr: GuestAddress) -> Result<T> { - self.do_in_region(guest_addr, |mapping, offset| { + self.do_in_region(guest_addr, |mapping, offset, _| { mapping .read_obj(offset) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -446,7 +470,7 @@ impl GuestMemory { /// # } /// ``` pub fn write_obj_at_addr<T: DataInit>(&self, val: T, guest_addr: GuestAddress) -> Result<()> { - self.do_in_region(guest_addr, move |mapping, offset| { + 
self.do_in_region(guest_addr, move |mapping, offset, _| { mapping .write_obj(val, offset) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -543,7 +567,7 @@ impl GuestMemory { src: &dyn AsRawDescriptor, count: usize, ) -> Result<()> { - self.do_in_region(guest_addr, move |mapping, offset| { + self.do_in_region(guest_addr, move |mapping, offset, _| { mapping .read_to_memory(offset, src, count) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -581,7 +605,7 @@ impl GuestMemory { dst: &dyn AsRawDescriptor, count: usize, ) -> Result<()> { - self.do_in_region(guest_addr, move |mapping, offset| { + self.do_in_region(guest_addr, move |mapping, offset, _| { mapping .write_from_memory(offset, dst, count) .map_err(|e| Error::MemoryAccess(guest_addr, e)) @@ -609,16 +633,42 @@ impl GuestMemory { /// # } /// ``` pub fn get_host_address(&self, guest_addr: GuestAddress) -> Result<*const u8> { - self.do_in_region(guest_addr, |mapping, offset| { + self.do_in_region(guest_addr, |mapping, offset, _| { // This is safe; `do_in_region` already checks that offset is in // bounds. Ok(unsafe { mapping.as_ptr().add(offset) } as *const u8) }) } + /// Returns a reference to the SharedMemory region that backs the given address. + pub fn shm_region(&self, guest_addr: GuestAddress) -> Result<&SharedMemory> { + self.regions + .iter() + .find(|region| region.contains(guest_addr)) + .ok_or(Error::InvalidGuestAddress(guest_addr)) + .map(|region| region.shm.as_ref()) + } + + /// Returns the region that contains the memory at `offset` from the base of guest memory. + pub fn offset_region(&self, offset: u64) -> Result<&SharedMemory> { + self.shm_region( + self.checked_offset(self.regions[0].guest_base, offset) + .ok_or(Error::InvalidOffset(offset))?, + ) + } + + /// Loops over all guest memory regions of `self`, and performs the callback function `F` in + /// the target region that contains `guest_addr`. 
The callback function `F` takes in: + /// + /// (i) the memory mapping associated with the target region. + /// (ii) the relative offset from the start of the target region to `guest_addr`. + /// (iii) the absolute offset from the start of the memory mapping to the target region. + /// + /// If no target region is found, an error is returned. The callback function `F` may return + /// an Ok(`T`) on success or a `GuestMemoryError` on failure. pub fn do_in_region<F, T>(&self, guest_addr: GuestAddress, cb: F) -> Result<T> where - F: FnOnce(&MemoryMapping, usize) -> Result<T>, + F: FnOnce(&MemoryMapping, usize, u64) -> Result<T>, { self.regions .iter() @@ -628,11 +678,12 @@ impl GuestMemory { cb( ®ion.mapping, guest_addr.offset_from(region.start()) as usize, + region.shm_offset, ) }) } - /// Convert a GuestAddress into an offset within self.shm. + /// Convert a GuestAddress into an offset within the associated shm region. /// /// Due to potential gaps within GuestMemory, it is helpful to know the /// offset within the shm where a given address is found. This offset @@ -660,7 +711,7 @@ impl GuestMemory { .iter() .find(|region| region.contains(guest_addr)) .ok_or(Error::InvalidGuestAddress(guest_addr)) - .map(|region| region.memfd_offset + guest_addr.offset_from(region.start())) + .map(|region| region.shm_offset + guest_addr.offset_from(region.start())) } } @@ -800,7 +851,7 @@ mod tests { // Get the base address of the mapping for a GuestAddress. 
fn get_mapping(mem: &GuestMemory, addr: GuestAddress) -> Result<*const u8> { - mem.do_in_region(addr, |mapping, _| Ok(mapping.as_ptr() as *const u8)) + mem.do_in_region(addr, |mapping, _, _| Ok(mapping.as_ptr() as *const u8)) } #[test] @@ -823,7 +874,7 @@ mod tests { } #[test] - fn memfd_offset() { + fn shm_offset() { if !kernel_has_memfd() { return; } @@ -839,10 +890,10 @@ mod tests { gm.write_obj_at_addr(0x0420u16, GuestAddress(0x10000)) .unwrap(); - let _ = gm.with_regions::<_, ()>(|index, _, size, _, memfd_offset| { + let _ = gm.with_regions::<_, ()>(|index, _, size, _, shm, shm_offset| { let mmap = MemoryMappingBuilder::new(size) - .from_descriptor(gm.as_ref()) - .offset(memfd_offset) + .from_shared_memory(shm) + .offset(shm_offset) .build() .unwrap(); diff --git a/x86_64/Android.bp b/x86_64/Android.bp index e7dc658ed..8adb07a85 100644 --- a/x86_64/Android.bp +++ b/x86_64/Android.bp @@ -1,5 +1,4 @@ -// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb. -// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this +// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace. 
package { // See: http://go/android-license-faq @@ -18,19 +17,6 @@ rust_library { crate_name: "x86_64", srcs: ["src/lib.rs"], edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - "msg_socket", - ], - rustlibs: [ - "libgdbstub", - "libmsg_socket", - ], - }, - }, rustlibs: [ "libacpi_tables", "libarch", @@ -68,18 +54,6 @@ rust_defaults { test_suites: ["general-tests"], auto_gen_config: true, edition: "2018", - target: { - linux_glibc_x86_64: { - features: [ - "gdb", - "gdbstub", - "msg_socket", - ], - rustlibs: [ - "libgdbstub", - ], - }, - }, rustlibs: [ "libacpi_tables", "libarch", @@ -133,7 +107,7 @@ rust_test { // ../../vm_tools/p9/src/lib.rs // ../../vm_tools/p9/wire_format_derive/wire_format_derive.rs // ../acpi_tables/src/lib.rs -// ../arch/src/lib.rs "gdb,gdbstub" +// ../arch/src/lib.rs // ../assertions/src/lib.rs // ../base/src/lib.rs // ../bit_field/bit_field_derive/bit_field_derive.rs @@ -172,11 +146,10 @@ rust_test { // ../vm_control/src/lib.rs // ../vm_memory/src/lib.rs // async-task-4.0.3 "default,std" -// async-trait-0.1.48 +// async-trait-0.1.45 // autocfg-1.0.1 // base-0.1.0 // bitflags-1.2.1 "default" -// cfg-if-0.1.10 // cfg-if-1.0.0 // downcast-rs-1.2.0 "default,std" // futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std" @@ -188,15 +161,12 @@ rust_test { // futures-sink-0.3.13 "alloc,std" // futures-task-0.3.13 "alloc,std" // futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std" -// gdbstub-0.4.4 "alloc,default,std" // getrandom-0.2.2 "std" // intrusive-collections-0.9.0 "alloc,default" -// libc-0.2.88 "default,std" +// libc-0.2.87 "default,std" // log-0.4.14 -// managed-0.8.0 "alloc" // memchr-2.3.4 "default,std" // memoffset-0.5.6 "default" -// num-traits-0.2.14 // paste-1.0.4 // pin-project-lite-0.2.6 // pin-utils-0.1.0 @@ -212,10 +182,10 @@ rust_test { // 
rand_core-0.6.2 "alloc,getrandom,std" // remain-0.2.2 // remove_dir_all-0.5.3 -// serde-1.0.124 "default,derive,serde_derive,std" -// serde_derive-1.0.124 "default" +// serde-1.0.123 "default,derive,serde_derive,std" +// serde_derive-1.0.123 "default" // slab-0.4.2 -// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" +// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut" // tempfile-3.2.0 // thiserror-1.0.24 // thiserror-impl-1.0.24 diff --git a/x86_64/Cargo.toml b/x86_64/Cargo.toml index fcfb0f15f..07fac9dc9 100644 --- a/x86_64/Cargo.toml +++ b/x86_64/Cargo.toml @@ -5,7 +5,7 @@ authors = ["The Chromium OS Authors"] edition = "2018" [features] -gdb = ["gdbstub", "msg_socket", "arch/gdb"] +gdb = ["gdbstub", "arch/gdb"] [dependencies] arch = { path = "../arch" } @@ -18,7 +18,6 @@ kernel_cmdline = { path = "../kernel_cmdline" } kernel_loader = { path = "../kernel_loader" } libc = "*" minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild -msg_socket = { path = "../msg_socket", optional = true } remain = "*" resources = { path = "../resources" } sync = { path = "../sync" } @@ -26,6 +25,3 @@ base = { path = "../base" } acpi_tables = {path = "../acpi_tables" } vm_control = { path = "../vm_control" } vm_memory = { path = "../vm_memory" } - -[dev-dependencies] -msg_socket = { path = "../msg_socket"} diff --git a/x86_64/src/cpuid.rs b/x86_64/src/cpuid.rs index d965a2f51..6b31d6012 100644 --- a/x86_64/src/cpuid.rs +++ b/x86_64/src/cpuid.rs @@ -50,7 +50,7 @@ fn filter_cpuid( cpuid: &mut hypervisor::CpuId, irq_chip: &dyn IrqChipX86_64, no_smt: bool, -) -> Result<()> { +) { let entries = &mut cpuid.cpu_id_entries; for entry in entries { @@ -142,8 +142,6 @@ fn filter_cpuid( _ => (), } } - - Ok(()) } /// Sets up the cpuid entries for the given vcpu. 
Can fail if there are too many CPUs specified or @@ -167,7 +165,7 @@ pub fn setup_cpuid( .get_supported_cpuid() .map_err(Error::GetSupportedCpusFailed)?; - filter_cpuid(vcpu_id, nrcpus, &mut cpuid, irq_chip, no_smt)?; + filter_cpuid(vcpu_id, nrcpus, &mut cpuid, irq_chip, no_smt); vcpu.set_cpuid(&cpuid) .map_err(Error::SetSupportedCpusFailed) @@ -211,7 +209,7 @@ mod tests { edx: 0, ..Default::default() }); - assert_eq!(Ok(()), filter_cpuid(1, 2, &mut cpuid, &irq_chip, false)); + filter_cpuid(1, 2, &mut cpuid, &irq_chip, false); let entries = &mut cpuid.cpu_id_entries; assert_eq!(entries[0].function, 0); diff --git a/x86_64/src/lib.rs b/x86_64/src/lib.rs index da1052d78..2de945ae8 100644 --- a/x86_64/src/lib.rs +++ b/x86_64/src/lib.rs @@ -349,13 +349,23 @@ fn arch_memory_regions(size: u64, bios_size: Option<u64>) -> Vec<(GuestAddress, impl arch::LinuxArch for X8664arch { type Error = Error; - fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>( + fn guest_memory_layout( + components: &VmComponents, + ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error> { + let bios_size = match &components.vm_image { + VmImage::Bios(bios_file) => Some(bios_file.metadata().map_err(Error::LoadBios)?.len()), + VmImage::Kernel(_) => None, + }; + Ok(arch_memory_regions(components.memory_size, bios_size)) + } + + fn build_vm<V, Vcpu, I, FD, FI, E1, E2>( mut components: VmComponents, serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>, serial_jail: Option<Minijail>, battery: (&Option<BatteryType>, Option<Minijail>), + mut vm: V, create_devices: FD, - create_vm: FV, create_irq_chip: FI, ) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error> where @@ -368,28 +378,18 @@ impl arch::LinuxArch for X8664arch { &mut SystemAllocator, &Event, ) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>, - FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>, - FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>, + FI: FnOnce(&V, /* 
vcpu_count: */ usize) -> std::result::Result<I, E2>, E1: StdError + 'static, E2: StdError + 'static, - E3: StdError + 'static, { if components.protected_vm != ProtectionType::Unprotected { return Err(Error::UnsupportedProtectionType); } - let bios_size = match components.vm_image { - VmImage::Bios(ref mut bios_file) => { - Some(bios_file.metadata().map_err(Error::LoadBios)?.len()) - } - VmImage::Kernel(_) => None, - }; - let has_bios = bios_size.is_some(); - let mem = Self::setup_memory(components.memory_size, bios_size)?; + let mem = vm.get_memory().clone(); let mut resources = Self::get_resource_allocator(&mem); let vcpu_count = components.vcpu_count; - let mut vm = create_vm(mem.clone()).map_err(|e| Error::CreateVm(Box::new(e)))?; let mut irq_chip = create_irq_chip(&vm, vcpu_count).map_err(|e| Error::CreateIrqChip(Box::new(e)))?; @@ -464,7 +464,8 @@ impl arch::LinuxArch for X8664arch { // Note that this puts the mptable at 0x9FC00 in guest physical memory. mptable::setup_mptable(&mem, vcpu_count as u8, pci_irqs).map_err(Error::SetupMptable)?; - smbios::setup_smbios(&mem).map_err(Error::SetupSmbios)?; + smbios::setup_smbios(&mem, components.dmi_path).map_err(Error::SetupSmbios)?; + // TODO (tjeznach) Write RSDP to bootconfig before writing to memory acpi::create_acpi_tables(&mem, vcpu_count as u8, X86_64_SCI_IRQ, acpi_dev_resource); @@ -523,7 +524,7 @@ impl arch::LinuxArch for X8664arch { vcpu_affinity: components.vcpu_affinity, no_smt: components.no_smt, irq_chip, - has_bios, + has_bios: matches!(components.vm_image, VmImage::Bios(_)), io_bus, mmio_bus, pid_debug_label_map, @@ -929,15 +930,6 @@ impl X8664arch { Ok(()) } - /// This creates a GuestMemory object for this VM - /// - /// * `mem_size` - Desired physical memory size in bytes for this VM - fn setup_memory(mem_size: u64, bios_size: Option<u64>) -> Result<GuestMemory> { - let arch_mem_regions = arch_memory_regions(mem_size, bios_size); - let mem = 
GuestMemory::new(&arch_mem_regions).map_err(Error::SetupGuestMemory)?; - Ok(mem) - } - /// This returns the start address of high mmio /// /// # Arguments @@ -1092,7 +1084,7 @@ impl X8664arch { let bat_control = if let Some(battery_type) = battery.0 { match battery_type { BatteryType::Goldfish => { - let control_socket = arch::add_goldfish_battery( + let control_tube = arch::add_goldfish_battery( &mut amls, battery.1, mmio_bus, @@ -1103,7 +1095,7 @@ impl X8664arch { .map_err(Error::CreateBatDevices)?; Some(BatControl { type_: BatteryType::Goldfish, - control_socket, + control_tube, }) } } diff --git a/x86_64/src/smbios.rs b/x86_64/src/smbios.rs index 1d6622ed0..6fd8774e1 100644 --- a/x86_64/src/smbios.rs +++ b/x86_64/src/smbios.rs @@ -7,6 +7,10 @@ use std::mem; use std::result; use std::slice; +use std::fs::OpenOptions; +use std::io::prelude::*; +use std::path::{Path, PathBuf}; + use data_model::DataInit; use vm_memory::{GuestAddress, GuestMemory}; @@ -22,6 +26,12 @@ pub enum Error { WriteSmbiosEp, /// Failure to write additional data to memory WriteData, + /// Failure while reading SMBIOS data file + IoFailed, + /// Incorrect or not readable host SMBIOS data + InvalidInput, + /// Invalid table entry point checksum + InvalidChecksum, } impl std::error::Error for Error {} @@ -36,6 +46,9 @@ impl Display for Error { Clear => "Failure while zeroing out the memory for the SMBIOS table", WriteSmbiosEp => "Failure to write SMBIOS entrypoint structure", WriteData => "Failure to write additional data to memory", + IoFailed => "Failure while reading SMBIOS data file", + InvalidInput => "Failure to read host SMBIOS data", + InvalidChecksum => "Failure to verify host SMBIOS entry checksum", }; write!(f, "SMBIOS error: {}", description) @@ -46,10 +59,14 @@ pub type Result<T> = result::Result<T, Error>; const SMBIOS_START: u64 = 0xf0000; // First possible location per the spec. +// Constants sourced from SMBIOS Spec 2.3.1. 
+const SM2_MAGIC_IDENT: &[u8; 4usize] = b"_SM_"; + // Constants sourced from SMBIOS Spec 3.2.0. const SM3_MAGIC_IDENT: &[u8; 5usize] = b"_SM3_"; const BIOS_INFORMATION: u8 = 0; const SYSTEM_INFORMATION: u8 = 1; +const END_OF_TABLE: u8 = 127; const PCI_SUPPORTED: u64 = 1 << 7; const IS_VIRTUAL_MACHINE: u8 = 1 << 4; @@ -65,6 +82,47 @@ fn compute_checksum<T: Copy>(v: &T) -> u8 { #[repr(packed)] #[derive(Default, Copy)] +pub struct Smbios23Intermediate { + pub signature: [u8; 5usize], + pub checksum: u8, + pub length: u16, + pub address: u32, + pub count: u16, + pub revision: u8, +} + +unsafe impl data_model::DataInit for Smbios23Intermediate {} + +impl Clone for Smbios23Intermediate { + fn clone(&self) -> Self { + *self + } +} + +#[repr(packed)] +#[derive(Default, Copy)] +pub struct Smbios23Entrypoint { + pub signature: [u8; 4usize], + pub checksum: u8, + pub length: u8, + pub majorver: u8, + pub minorver: u8, + pub max_size: u16, + pub revision: u8, + pub reserved: [u8; 5usize], + pub dmi: Smbios23Intermediate, +} + +unsafe impl data_model::DataInit for Smbios23Entrypoint {} + +impl Clone for Smbios23Entrypoint { + fn clone(&self) -> Self { + *self + } +} + +#[repr(packed)] +#[derive(Default, Copy)] pub struct Smbios30Entrypoint { pub signature: [u8; 5usize], pub checksum: u8, @@ -154,7 +212,83 @@ fn write_string(mem: &GuestMemory, val: &str, mut curptr: GuestAddress) -> Resul Ok(curptr) } -pub fn setup_smbios(mem: &GuestMemory) -> Result<()> { +fn setup_smbios_from_file(mem: &GuestMemory, path: &Path) -> Result<()> { + let mut sme_path = PathBuf::from(path); + sme_path.push("smbios_entry_point"); + let mut sme = Vec::new(); + OpenOptions::new() + .read(true) + .open(&sme_path) + .map_err(|_| Error::IoFailed)? + .read_to_end(&mut sme) + .map_err(|_| Error::IoFailed)?; + + let mut dmi_path = PathBuf::from(path); + dmi_path.push("DMI"); + let mut dmi = Vec::new(); + OpenOptions::new() + .read(true) + .open(&dmi_path) + .map_err(|_| Error::IoFailed)? 
+ .read_to_end(&mut dmi) + .map_err(|_| Error::IoFailed)?; + + // Try SMBIOS 3.0 format. + if sme.len() == mem::size_of::<Smbios30Entrypoint>() && sme.starts_with(SM3_MAGIC_IDENT) { + let mut smbios_ep = Smbios30Entrypoint::default(); + smbios_ep.as_mut_slice().copy_from_slice(&sme); + + let physptr = GuestAddress(SMBIOS_START) + .checked_add(mem::size_of::<Smbios30Entrypoint>() as u64) + .ok_or(Error::NotEnoughMemory)?; + + mem.write_at_addr(&dmi, physptr) + .map_err(|_| Error::NotEnoughMemory)?; + + // Update EP DMI location + smbios_ep.physptr = physptr.offset(); + smbios_ep.checksum = 0; + smbios_ep.checksum = compute_checksum(&smbios_ep); + + mem.write_obj_at_addr(smbios_ep, GuestAddress(SMBIOS_START)) + .map_err(|_| Error::NotEnoughMemory)?; + + return Ok(()); + } + + // Try SMBIOS 2.3 format. + if sme.len() == mem::size_of::<Smbios23Entrypoint>() && sme.starts_with(SM2_MAGIC_IDENT) { + let mut smbios_ep = Smbios23Entrypoint::default(); + smbios_ep.as_mut_slice().copy_from_slice(&sme); + + let physptr = GuestAddress(SMBIOS_START) + .checked_add(mem::size_of::<Smbios23Entrypoint>() as u64) + .ok_or(Error::NotEnoughMemory)?; + + mem.write_at_addr(&dmi, physptr) + .map_err(|_| Error::NotEnoughMemory)?; + + // Update EP DMI location + smbios_ep.dmi.address = physptr.offset() as u32; + smbios_ep.dmi.checksum = 0; + smbios_ep.dmi.checksum = compute_checksum(&smbios_ep.dmi); + smbios_ep.checksum = 0; + smbios_ep.checksum = compute_checksum(&smbios_ep); + + mem.write_obj_at_addr(smbios_ep, GuestAddress(SMBIOS_START)) + .map_err(|_| Error::WriteSmbiosEp)?; + + return Ok(()); + } + + Err(Error::InvalidInput) +} + +pub fn setup_smbios(mem: &GuestMemory, dmi_path: Option<PathBuf>) -> Result<()> { + if let Some(dmi_path) = dmi_path { + return setup_smbios_from_file(mem, &dmi_path); + } + let physptr = GuestAddress(SMBIOS_START) .checked_add(mem::size_of::<Smbios30Entrypoint>() as u64) .ok_or(Error::NotEnoughMemory)?; @@ -196,6 +330,18 @@ pub fn setup_smbios(mem: 
&GuestMemory) -> Result<()> { } { + handle += 1; + let smbios_sysinfo = SmbiosSysInfo { + typ: END_OF_TABLE, + length: mem::size_of::<SmbiosSysInfo>() as u8, + handle, + ..Default::default() + }; + curptr = write_and_incr(mem, smbios_sysinfo, curptr)?; + curptr = write_and_incr(mem, 0_u8, curptr)?; + } + + { let mut smbios_ep = Smbios30Entrypoint::default(); smbios_ep.signature = *SM3_MAGIC_IDENT; smbios_ep.length = mem::size_of::<Smbios30Entrypoint>() as u8; @@ -221,6 +367,11 @@ mod tests { #[test] fn struct_size() { assert_eq!( + mem::size_of::<Smbios23Entrypoint>(), + 0x1fusize, + concat!("Size of: ", stringify!(Smbios23Entrypoint)) + ); + assert_eq!( mem::size_of::<Smbios30Entrypoint>(), 0x18usize, concat!("Size of: ", stringify!(Smbios30Entrypoint)) @@ -241,7 +392,8 @@ mod tests { fn entrypoint_checksum() { let mem = GuestMemory::new(&[(GuestAddress(SMBIOS_START), 4096)]).unwrap(); - setup_smbios(&mem).unwrap(); + // Use default 3.0 SMBIOS format. + setup_smbios(&mem, None).unwrap(); let smbios_ep: Smbios30Entrypoint = mem.read_obj_from_addr(GuestAddress(SMBIOS_START)).unwrap(); diff --git a/x86_64/src/test_integration.rs b/x86_64/src/test_integration.rs index 6545c85a2..6ee54e1a3 100644 --- a/x86_64/src/test_integration.rs +++ b/x86_64/src/test_integration.rs @@ -12,13 +12,13 @@ use super::cpuid::setup_cpuid; use super::interrupts::set_lint; use super::regs::{setup_fpu, setup_msrs, setup_regs, setup_sregs}; use super::X8664arch; -use super::{acpi, bootparam, mptable, smbios}; +use super::{acpi, arch_memory_regions, bootparam, mptable, smbios}; use super::{ BOOT_STACK_POINTER, END_ADDR_BEFORE_32BITS, KERNEL_64BIT_ENTRY_OFFSET, KERNEL_START_OFFSET, X86_64_SCI_IRQ, ZERO_PAGE_OFFSET, }; -use base::Event; +use base::{Event, Tube}; use std::collections::BTreeMap; use std::ffi::CString; @@ -28,14 +28,9 @@ use sync::Mutex; use devices::PciConfigIo; -use vm_control::{ - DiskControlCommand, DiskControlResult, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse, - 
VmIrqResponseSocket, VmMemoryControlResponseSocket, VmMemoryRequest, VmMemoryResponse, -}; - -enum TaggedControlSocket { - VmMemory(VmMemoryControlResponseSocket), - VmIrq(VmIrqResponseSocket), +enum TaggedControlTube { + VmMemory(Tube), + VmIrq(Tube), } #[test] @@ -64,8 +59,8 @@ fn simple_kvm_split_irqchip_test() { let vm = KvmVm::new(&kvm, guest_mem).expect("failed to create kvm vm"); (kvm, vm) }, - |vm, vcpu_count, device_socket| { - KvmSplitIrqChip::new(vm, vcpu_count, device_socket) + |vm, vcpu_count, device_tube| { + KvmSplitIrqChip::new(vm, vcpu_count, device_tube, None) .expect("failed to create KvmSplitIrqChip") }, ); @@ -82,7 +77,7 @@ where Vcpu: VcpuX86_64 + 'static, I: IrqChipX86_64 + 'static, FV: FnOnce(GuestMemory) -> (H, V), - FI: FnOnce(V, /* vcpu_count: */ usize, VmIrqRequestSocket) -> I, + FI: FnOnce(V, /* vcpu_count: */ usize, Tube) -> I, { /* 0x0000000000000000: 67 89 18 mov dword ptr [eax], ebx @@ -100,38 +95,32 @@ where let write_addr = GuestAddress(0x4000); // guest mem is 400 pages - let guest_mem = X8664arch::setup_memory(memory_size, None).unwrap(); - // let guest_mem = GuestMemory::new(&[(GuestAddress(0), memory_size)]).unwrap(); + let arch_mem_regions = arch_memory_regions(memory_size, None); + let guest_mem = GuestMemory::new(&arch_mem_regions).unwrap(); + let mut resources = X8664arch::get_resource_allocator(&guest_mem); let (hyp, mut vm) = create_vm(guest_mem.clone()); - let (irqchip_socket, device_socket) = - msg_socket::pair::<VmIrqResponse, VmIrqRequest>().expect("failed to create irq socket"); + let (irqchip_tube, device_tube) = Tube::pair().expect("failed to create irq tube"); - let mut irq_chip = create_irq_chip( - vm.try_clone().expect("failed to clone vm"), - 1, - device_socket, - ); + let mut irq_chip = create_irq_chip(vm.try_clone().expect("failed to clone vm"), 1, device_tube); let mut mmio_bus = devices::Bus::new(); let exit_evt = Event::new().unwrap(); - let mut control_sockets = 
vec![TaggedControlSocket::VmIrq(irqchip_socket)]; + let mut control_tubes = vec![TaggedControlTube::VmIrq(irqchip_tube)]; // Create one control socket per disk. - let mut disk_device_sockets = Vec::new(); - let mut disk_host_sockets = Vec::new(); + let mut disk_device_tubes = Vec::new(); + let mut disk_host_tubes = Vec::new(); let disk_count = 0; for _ in 0..disk_count { - let (disk_host_socket, disk_device_socket) = - msg_socket::pair::<DiskControlCommand, DiskControlResult>().unwrap(); - disk_host_sockets.push(disk_host_socket); - disk_device_sockets.push(disk_device_socket); + let (disk_host_tube, disk_device_tube) = Tube::pair().unwrap(); + disk_host_tubes.push(disk_host_tube); + disk_device_tubes.push(disk_device_tube); } - let (gpu_host_socket, _gpu_device_socket) = - msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().unwrap(); + let (gpu_host_tube, _gpu_device_tube) = Tube::pair().unwrap(); - control_sockets.push(TaggedControlSocket::VmMemory(gpu_host_socket)); + control_tubes.push(TaggedControlTube::VmMemory(gpu_host_tube)); let devices = vec![]; @@ -212,7 +201,7 @@ where // Note that this puts the mptable at 0x9FC00 in guest physical memory. mptable::setup_mptable(&guest_mem, 1, pci_irqs).expect("failed to setup mptable"); - smbios::setup_smbios(&guest_mem).expect("failed to setup smbios"); + smbios::setup_smbios(&guest_mem, None).expect("failed to setup smbios"); acpi::create_acpi_tables(&guest_mem, 1, X86_64_SCI_IRQ, acpi_dev_resource.0); |