author     Jorge E. Moreira <jemoreira@google.com>  2021-04-21 20:32:53 -0700
committer  Jorge E. Moreira <jemoreira@google.com>  2021-04-21 20:32:53 -0700
commit     67c7636ec8364d9d92a96bceae6441ba01461eb2 (patch)
tree       93c4786948168cfe49e3f47fa3ed94924d3b2e99
parent     38a78326ce3cfe429a8344297abfb162d47e6204 (diff)
parent     5be4f273e87bc55dfc1ed3f4a6126f4d9f02e797 (diff)
download   crosvm-67c7636ec8364d9d92a96bceae6441ba01461eb2.tar.gz
Merge remote-tracking branch 'aosp/upstream-main'

Bug: 185155959
Test: locally with following change
Change-Id: I9580972149384e197e57abb09d480d8997f527e5
-rw-r--r--  .gitignore | 1
-rw-r--r--  Android.bp | 31
-rw-r--r--  Cargo.lock | 375
-rw-r--r--  Cargo.toml | 16
-rw-r--r--  OWNERS | 7
-rw-r--r--  aarch64/src/lib.rs | 66
-rw-r--r--  arch/Android.bp | 39
-rw-r--r--  arch/Cargo.toml | 1
-rw-r--r--  arch/src/fdt.rs | 4
-rw-r--r--  arch/src/lib.rs | 52
-rw-r--r--  arch/src/pstore.rs | 2
-rw-r--r--  base/Cargo.toml | 4
-rw-r--r--  base/src/event.rs | 5
-rw-r--r--  base/src/lib.rs | 4
-rw-r--r--  base/src/mmap.rs | 68
-rw-r--r--  base/src/shm.rs | 19
-rw-r--r--  base/src/tube.rs | 239
-rw-r--r--  base/src/wait_context.rs | 16
-rwxr-xr-x  bin/clippy | 4
-rw-r--r--  bit_field/bit_field_derive/bit_field_derive.rs | 2
-rwxr-xr-x  ci/crosvm_aarch64_builder/entrypoint | 9
-rw-r--r--  ci/crosvm_base/rust-toolchain | 2
-rwxr-xr-x  ci/crosvm_builder/entrypoint | 9
-rw-r--r--  ci/crosvm_test_vm/Dockerfile | 7
-rw-r--r--  ci/crosvm_test_vm/build/cloud_init_data.yaml | 26
-rwxr-xr-x  ci/crosvm_test_vm/runtime/sync_so | 20
-rw-r--r--  ci/image_tag | 2
-rwxr-xr-x  ci/kokoro/common.sh | 23
-rw-r--r--  ci/kokoro/manifest.xml | 24
-rwxr-xr-x  ci/kokoro/uprev | 30
-rw-r--r--  ci/test_runner.py | 142
-rw-r--r--  ci/vm_tools/README.md | 8
-rwxr-xr-x  ci/vm_tools/exec_binary_in_vm (renamed from ci/crosvm_test_vm/runtime/exec_file) | 12
-rwxr-xr-x  ci/vm_tools/sync_deps | 37
-rwxr-xr-x  ci/vm_tools/wait_for_vm (renamed from ci/crosvm_test_vm/runtime/wait_for_vm) | 0
-rwxr-xr-x  ci/vm_tools/wait_for_vm_with_timeout (renamed from ci/crosvm_test_vm/runtime/exec) | 2
-rw-r--r--  cros_async/Cargo.toml | 17
-rw-r--r--  cros_async/src/fd_executor.rs | 109
-rw-r--r--  cros_async/src/lib.rs | 1
-rw-r--r--  cros_async/src/poll_source.rs | 107
-rw-r--r--  cros_async/src/queue.rs | 36
-rw-r--r--  cros_async/src/sync.rs | 14
-rw-r--r--  cros_async/src/sync/blocking.rs | 192
-rw-r--r--  cros_async/src/sync/cv.rs | 1159
-rw-r--r--  cros_async/src/sync/mu.rs | 2289
-rw-r--r--  cros_async/src/sync/spin.rs | 277
-rw-r--r--  cros_async/src/sync/waiter.rs | 281
-rw-r--r--  cros_async/src/uring_executor.rs | 22
-rw-r--r--  crosvm_plugin/src/lib.rs | 8
-rw-r--r--  data_model/src/endian.rs | 24
-rw-r--r--  devices/Cargo.toml | 10
-rw-r--r--  devices/src/bat.rs | 27
-rw-r--r--  devices/src/bus.rs | 5
-rw-r--r--  devices/src/direct_io.rs | 68
-rw-r--r--  devices/src/direct_irq.rs | 115
-rw-r--r--  devices/src/irqchip/ioapic.rs | 205
-rw-r--r--  devices/src/irqchip/kvm/x86_64.rs | 79
-rw-r--r--  devices/src/lib.rs | 10
-rw-r--r--  devices/src/pci/ac97.rs | 37
-rw-r--r--  devices/src/pci/ac97_bus_master.rs | 36
-rw-r--r--  devices/src/pci/ac97_mixer.rs | 2
-rw-r--r--  devices/src/pci/ac97_regs.rs | 12
-rw-r--r--  devices/src/pci/msix.rs | 30
-rw-r--r--  devices/src/pci/pci_configuration.rs | 31
-rw-r--r--  devices/src/pci/pci_device.rs | 1
-rw-r--r--  devices/src/pci/pci_root.rs | 1
-rw-r--r--  devices/src/pci/vfio_pci.rs | 88
-rw-r--r--  devices/src/proxy.rs | 62
-rw-r--r--  devices/src/usb/host_backend/error.rs | 21
-rw-r--r--  devices/src/usb/host_backend/host_backend_device_provider.rs | 93
-rw-r--r--  devices/src/usb/host_backend/host_device.rs | 13
-rw-r--r--  devices/src/usb/xhci/device_slot.rs | 7
-rw-r--r--  devices/src/usb/xhci/xhci.rs | 31
-rw-r--r--  devices/src/usb/xhci/xhci_backend_device.rs | 5
-rw-r--r--  devices/src/usb/xhci/xhci_backend_device_provider.rs | 3
-rw-r--r--  devices/src/usb/xhci/xhci_controller.rs | 1
-rw-r--r--  devices/src/vfio.rs | 65
-rw-r--r--  devices/src/virtio/balloon.rs | 203
-rw-r--r--  devices/src/virtio/block.rs | 47
-rw-r--r--  devices/src/virtio/block_async.rs | 226
-rw-r--r--  devices/src/virtio/console.rs | 13
-rw-r--r--  devices/src/virtio/descriptor_utils.rs | 13
-rw-r--r--  devices/src/virtio/fs/caps.rs | 163
-rw-r--r--  devices/src/virtio/fs/mod.rs | 42
-rw-r--r--  devices/src/virtio/fs/passthrough.rs | 532
-rw-r--r--  devices/src/virtio/fs/worker.rs | 66
-rw-r--r--  devices/src/virtio/gpu/mod.rs | 105
-rw-r--r--  devices/src/virtio/gpu/protocol.rs | 24
-rw-r--r--  devices/src/virtio/gpu/udmabuf.rs | 266
-rw-r--r--  devices/src/virtio/gpu/udmabuf_bindings.rs | 74
-rw-r--r--  devices/src/virtio/gpu/virtio_gpu.rs | 99
-rw-r--r--  devices/src/virtio/input/event_source.rs | 10
-rw-r--r--  devices/src/virtio/input/mod.rs | 16
-rw-r--r--  devices/src/virtio/interrupt.rs | 89
-rw-r--r--  devices/src/virtio/mod.rs | 2
-rw-r--r--  devices/src/virtio/net.rs | 13
-rw-r--r--  devices/src/virtio/p9.rs | 17
-rw-r--r--  devices/src/virtio/pmem.rs | 44
-rw-r--r--  devices/src/virtio/queue.rs | 9
-rw-r--r--  devices/src/virtio/resource_bridge.rs | 40
-rw-r--r--  devices/src/virtio/rng.rs | 12
-rw-r--r--  devices/src/virtio/snd/constants.rs | 86
-rw-r--r--  devices/src/virtio/snd/layout.rs | 73
-rw-r--r--  devices/src/virtio/snd/vios_backend/mod.rs | 2
-rw-r--r--  devices/src/virtio/snd/vios_backend/shm_streams.rs | 129
-rw-r--r--  devices/src/virtio/snd/vios_backend/shm_vios.rs | 422
-rw-r--r--  devices/src/virtio/tpm.rs | 12
-rw-r--r--  devices/src/virtio/vhost/control_socket.rs | 25
-rw-r--r--  devices/src/virtio/vhost/mod.rs | 6
-rw-r--r--  devices/src/virtio/vhost/net.rs | 104
-rw-r--r--  devices/src/virtio/vhost/user/block.rs | 180
-rw-r--r--  devices/src/virtio/vhost/user/fs.rs | 192
-rw-r--r--  devices/src/virtio/vhost/user/handler.rs | 228
-rw-r--r--  devices/src/virtio/vhost/user/mod.rs | 101
-rw-r--r--  devices/src/virtio/vhost/user/net.rs | 185
-rw-r--r--  devices/src/virtio/vhost/user/worker.rs | 82
-rw-r--r--  devices/src/virtio/vhost/vsock.rs | 12
-rw-r--r--  devices/src/virtio/vhost/worker.rs | 36
-rw-r--r--  devices/src/virtio/video/decoder/backend/mod.rs | 1
-rw-r--r--  devices/src/virtio/video/decoder/backend/vda.rs | 2
-rw-r--r--  devices/src/virtio/video/decoder/mod.rs | 38
-rw-r--r--  devices/src/virtio/video/device.rs | 5
-rw-r--r--  devices/src/virtio/video/encoder/mod.rs | 16
-rw-r--r--  devices/src/virtio/video/format.rs | 2
-rw-r--r--  devices/src/virtio/video/mod.rs | 7
-rw-r--r--  devices/src/virtio/video/protocol.rs | 2
-rw-r--r--  devices/src/virtio/video/response.rs | 2
-rw-r--r--  devices/src/virtio/video/worker.rs | 14
-rw-r--r--  devices/src/virtio/virtio_pci_device.rs | 11
-rw-r--r--  devices/src/virtio/wl.rs | 125
-rw-r--r--  disk/src/composite.rs | 45
-rw-r--r--  disk/src/disk.rs | 109
-rw-r--r--  disk/src/qcow/mod.rs | 2
-rw-r--r--  docs/architecture.md | 4
-rw-r--r--  fuse/src/server.rs | 2
-rw-r--r--  fuzz/Cargo.toml | 6
-rw-r--r--  gpu_display/build.rs | 4
-rw-r--r--  gpu_display/src/gpu_display_stub.rs | 14
-rw-r--r--  gpu_display/src/gpu_display_wl.rs | 2
-rw-r--r--  gpu_display/src/gpu_display_x.rs | 16
-rw-r--r--  gpu_display/src/lib.rs | 1
-rw-r--r--  hypervisor/Cargo.toml | 2
-rw-r--r--  hypervisor/src/kvm/mod.rs | 48
-rw-r--r--  hypervisor/src/kvm/x86_64.rs | 5
-rw-r--r--  hypervisor/src/lib.rs | 7
-rw-r--r--  hypervisor/src/x86_64.rs | 8
-rw-r--r--  integration_tests/guest_under_test/Dockerfile | 4
-rw-r--r--  integration_tests/guest_under_test/Makefile | 54
-rw-r--r--  integration_tests/guest_under_test/PREBUILT_VERSION | 2
-rwxr-xr-x  integration_tests/guest_under_test/upload_prebuilts.sh | 41
-rwxr-xr-x  integration_tests/run | 11
-rw-r--r--  integration_tests/tests/boot.rs | 15
-rw-r--r--  integration_tests/tests/fixture.rs | 184
-rw-r--r--  io_uring/Cargo.toml | 9
-rw-r--r--  io_uring/src/syscalls.rs | 7
-rw-r--r--  io_uring/src/uring.rs | 12
-rw-r--r--  kvm/Cargo.toml | 1
-rw-r--r--  kvm/src/lib.rs | 21
-rw-r--r--  kvm/tests/dirty_log.rs | 4
-rw-r--r--  kvm/tests/read_only_memory.rs | 8
-rw-r--r--  libcrosvm_control/Cargo.toml | 13
-rw-r--r--  libcrosvm_control/src/lib.rs | 358
-rw-r--r--  linux_input_sys/src/lib.rs | 28
-rw-r--r--  msg_socket/Cargo.toml | 14
-rw-r--r--  msg_socket/msg_on_socket_derive/Cargo.toml | 15
-rw-r--r--  msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs | 911
-rw-r--r--  msg_socket/src/lib.rs | 232
-rw-r--r--  msg_socket/src/msg_on_socket.rs | 438
-rw-r--r--  msg_socket/src/msg_on_socket/slice.rs | 184
-rw-r--r--  msg_socket/src/msg_on_socket/tuple.rs | 205
-rw-r--r--  msg_socket/src/serializable_descriptors.rs | 85
-rw-r--r--  msg_socket/tests/enum.rs | 67
-rw-r--r--  msg_socket/tests/struct.rs | 38
-rw-r--r--  msg_socket/tests/tuple.rs | 22
-rw-r--r--  msg_socket/tests/unit.rs | 12
-rw-r--r--  resources/Cargo.toml | 2
-rw-r--r--  resources/src/lib.rs | 10
-rwxr-xr-x  run_tests | 21
-rw-r--r--  rutabaga_gfx/src/cross_domain/cross_domain.rs | 11
-rw-r--r--  rutabaga_gfx/src/generated/virgl_renderer_bindings.rs | 2
-rw-r--r--  rutabaga_gfx/src/gfxstream.rs | 3
-rw-r--r--  rutabaga_gfx/src/lib.rs | 1
-rw-r--r--  rutabaga_gfx/src/renderer_utils.rs | 2
-rw-r--r--  rutabaga_gfx/src/rutabaga_core.rs | 29
-rw-r--r--  rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs | 89
-rw-r--r--  rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs | 4
-rw-r--r--  rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs | 217
-rw-r--r--  rutabaga_gfx/src/rutabaga_utils.rs | 31
-rw-r--r--  rutabaga_gfx/src/virgl_renderer.rs | 31
-rw-r--r--  seccomp/aarch64/9p_device.policy | 5
-rw-r--r--  seccomp/aarch64/balloon_device.policy | 2
-rw-r--r--  seccomp/aarch64/battery.policy | 1
-rw-r--r--  seccomp/aarch64/block_device.policy | 2
-rw-r--r--  seccomp/aarch64/common_device.policy | 3
-rw-r--r--  seccomp/aarch64/cras_audio_device.policy | 1
-rw-r--r--  seccomp/aarch64/fs_device.policy | 3
-rw-r--r--  seccomp/aarch64/gpu_device.policy | 5
-rw-r--r--  seccomp/aarch64/input_device.policy | 2
-rw-r--r--  seccomp/aarch64/net_device.policy | 2
-rw-r--r--  seccomp/aarch64/null_audio_device.policy | 1
-rw-r--r--  seccomp/aarch64/pmem_device.policy | 1
-rw-r--r--  seccomp/aarch64/rng_device.policy | 1
-rw-r--r--  seccomp/aarch64/serial.policy | 1
-rw-r--r--  seccomp/aarch64/tpm_device.policy | 1
-rw-r--r--  seccomp/aarch64/vhost_net_device.policy | 2
-rw-r--r--  seccomp/aarch64/vhost_vsock_device.policy | 2
-rw-r--r--  seccomp/aarch64/vios_audio_device.policy | 2
-rw-r--r--  seccomp/aarch64/wl_device.policy | 4
-rw-r--r--  seccomp/aarch64/xhci.policy | 2
-rw-r--r--  seccomp/arm/9p_device.policy | 6
-rw-r--r--  seccomp/arm/balloon_device.policy | 2
-rw-r--r--  seccomp/arm/battery.policy | 1
-rw-r--r--  seccomp/arm/block_device.policy | 4
-rw-r--r--  seccomp/arm/common_device.policy | 9
-rw-r--r--  seccomp/arm/cras_audio_device.policy | 2
-rw-r--r--  seccomp/arm/fs_device.policy | 6
-rw-r--r--  seccomp/arm/gpu_device.policy | 10
-rw-r--r--  seccomp/arm/input_device.policy | 2
-rw-r--r--  seccomp/arm/net_device.policy | 1
-rw-r--r--  seccomp/arm/null_audio_device.policy | 2
-rw-r--r--  seccomp/arm/pmem_device.policy | 1
-rw-r--r--  seccomp/arm/rng_device.policy | 1
-rw-r--r--  seccomp/arm/serial.policy | 1
-rw-r--r--  seccomp/arm/tpm_device.policy | 7
-rw-r--r--  seccomp/arm/vhost_net_device.policy | 1
-rw-r--r--  seccomp/arm/vhost_vsock_device.policy | 1
-rw-r--r--  seccomp/arm/video_device.policy | 5
-rw-r--r--  seccomp/arm/vios_audio_device.policy | 3
-rw-r--r--  seccomp/arm/wl_device.policy | 4
-rw-r--r--  seccomp/arm/xhci.policy | 4
-rw-r--r--  seccomp/x86_64/9p_device.policy | 5
-rw-r--r--  seccomp/x86_64/balloon_device.policy | 2
-rw-r--r--  seccomp/x86_64/battery.policy | 2
-rw-r--r--  seccomp/x86_64/block_device.policy | 2
-rw-r--r--  seccomp/x86_64/common_device.policy | 3
-rw-r--r--  seccomp/x86_64/cras_audio_device.policy | 1
-rw-r--r--  seccomp/x86_64/fs_device.policy | 5
-rw-r--r--  seccomp/x86_64/gpu_device.policy | 5
-rw-r--r--  seccomp/x86_64/input_device.policy | 2
-rw-r--r--  seccomp/x86_64/net_device.policy | 1
-rw-r--r--  seccomp/x86_64/null_audio_device.policy | 2
-rw-r--r--  seccomp/x86_64/pmem_device.policy | 1
-rw-r--r--  seccomp/x86_64/rng_device.policy | 1
-rw-r--r--  seccomp/x86_64/serial.policy | 1
-rw-r--r--  seccomp/x86_64/tpm_device.policy | 1
-rw-r--r--  seccomp/x86_64/vfio_device.policy | 1
-rw-r--r--  seccomp/x86_64/vhost_net_device.policy | 1
-rw-r--r--  seccomp/x86_64/vhost_vsock_device.policy | 1
-rw-r--r--  seccomp/x86_64/video_device.policy | 4
-rw-r--r--  seccomp/x86_64/vios_audio_device.policy | 2
-rw-r--r--  seccomp/x86_64/wl_device.policy | 3
-rw-r--r--  seccomp/x86_64/xhci.policy | 3
-rw-r--r--  src/argument.rs | 2
-rw-r--r--  src/crosvm.rs | 49
-rw-r--r--  src/gdb.rs | 23
-rw-r--r--  src/linux.rs | 809
-rw-r--r--  src/main.rs | 610
-rw-r--r--  src/plugin/mod.rs | 13
-rw-r--r--  src/plugin/process.rs | 2
-rw-r--r--  sys_util/Cargo.toml | 3
-rw-r--r--  sys_util/src/descriptor.rs | 57
-rw-r--r--  sys_util/src/descriptor_reflection.rs | 541
-rw-r--r--  sys_util/src/errno.rs | 5
-rw-r--r--  sys_util/src/eventfd.rs | 30
-rw-r--r--  sys_util/src/fork.rs | 3
-rw-r--r--  sys_util/src/lib.rs | 88
-rw-r--r--  sys_util/src/linux/syslog.rs | 2
-rw-r--r--  sys_util/src/mmap.rs | 41
-rw-r--r--  sys_util/src/net.rs | 287
-rw-r--r--  sys_util/src/rand.rs | 114
-rw-r--r--  sys_util/src/scoped_path.rs | 138
-rw-r--r--  sys_util/src/scoped_signal_handler.rs | 421
-rw-r--r--  sys_util/src/shm.rs | 24
-rw-r--r--  sys_util/src/signal.rs | 436
-rw-r--r--  sys_util/src/sock_ctrl_msg.rs | 114
-rw-r--r--  sys_util/src/vsock.rs | 495
-rw-r--r--  usb_util/src/device.rs | 8
-rw-r--r--  vfio_sys/src/lib.rs | 8
-rw-r--r--  vfio_sys/src/plat.rs | 133
-rw-r--r--  vhost/src/lib.rs | 6
-rw-r--r--  vhost/src/net.rs | 15
-rw-r--r--  vhost/src/vsock.rs | 11
-rw-r--r--  vm_control/Android.bp | 43
-rw-r--r--  vm_control/Cargo.toml | 4
-rw-r--r--  vm_control/src/client.rs | 211
-rw-r--r--  vm_control/src/lib.rs | 273
-rw-r--r--  vm_memory/Cargo.toml | 2
-rw-r--r--  vm_memory/src/guest_memory.rs | 125
-rw-r--r--  x86_64/Android.bp | 44
-rw-r--r--  x86_64/Cargo.toml | 6
-rw-r--r--  x86_64/src/cpuid.rs | 8
-rw-r--r--  x86_64/src/lib.rs | 46
-rw-r--r--  x86_64/src/smbios.rs | 156
-rw-r--r--  x86_64/src/test_integration.rs | 55
294 files changed, 14679 insertions, 5797 deletions
diff --git a/.gitignore b/.gitignore
index 1acdbc7a0..bfe747156 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ target/
!/Cargo.lock
lcov.info
.idea
+.vscode
diff --git a/Android.bp b/Android.bp
index 3c8a620a8..5ea3df5d5 100644
--- a/Android.bp
+++ b/Android.bp
@@ -85,7 +85,6 @@ rust_binary {
name: "crosvm",
defaults: ["crosvm_defaults"],
host_supported: true,
- prefer_rlib: true,
crate_name: "crosvm",
srcs: ["src/main.rs"],
@@ -94,15 +93,7 @@ rust_binary {
relative_install_path: "aarch64-linux-bionic",
},
linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
relative_install_path: "x86_64-linux-gnu",
- rustlibs: [
- "libgdbstub",
- "libthiserror",
- ],
},
darwin: {
enabled: false,
@@ -183,18 +174,6 @@ rust_defaults {
rustlibs: ["libaarch64"],
},
},
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- "libthiserror",
- ],
- },
- },
features: [
"default",
],
@@ -342,16 +321,6 @@ rust_library {
"gfxstream",
],
},
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- "libthiserror",
- ],
- },
},
arch: {
x86_64: {
diff --git a/Cargo.lock b/Cargo.lock
index 9d9e76995..87cdedda2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -53,7 +53,6 @@ dependencies = [
"kernel_cmdline",
"libc",
"minijail",
- "msg_socket",
"power_monitor",
"resources",
"sync",
@@ -93,6 +92,12 @@ dependencies = [
[[package]]
name = "autocfg"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1d49d90015b3c36167a20fe2810c5cd875ad504b39cff3d4eae7977e6b7c1cb2"
+
+[[package]]
+name = "autocfg"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
@@ -104,8 +109,12 @@ dependencies = [
"cros_async",
"data_model",
"libc",
+ "serde",
+ "serde_json",
+ "smallvec",
"sync",
"sys_util",
+ "thiserror",
]
[[package]]
@@ -143,6 +152,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
[[package]]
+name = "cloudabi"
+version = "0.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ddfc5b9aa5d4507acaf872de71051dfd0e309860e88966e1051e462a077aac4f"
+dependencies = [
+ "bitflags",
+]
+
+[[package]]
name = "cras-sys"
version = "0.1.0"
dependencies = [
@@ -158,6 +176,7 @@ dependencies = [
"async-trait",
"data_model",
"futures",
+ "intrusive-collections",
"io_uring",
"libc",
"paste",
@@ -165,11 +184,17 @@ dependencies = [
"slab",
"sync",
"sys_util",
- "syscall_defines",
"thiserror",
]
[[package]]
+name = "cros_fuzz"
+version = "0.1.0"
+dependencies = [
+ "rand_core 0.4.2",
+]
+
+[[package]]
name = "crosvm"
version = "0.1.0"
dependencies = [
@@ -194,7 +219,6 @@ dependencies = [
"libc",
"libcras",
"minijail",
- "msg_socket",
"net_util",
"p9",
"protobuf",
@@ -213,6 +237,24 @@ dependencies = [
]
[[package]]
+name = "crosvm-fuzz"
+version = "0.0.1"
+dependencies = [
+ "base",
+ "cros_fuzz",
+ "data_model",
+ "devices",
+ "disk",
+ "fuse",
+ "kernel_loader",
+ "libc",
+ "rand",
+ "tempfile",
+ "usb_util",
+ "vm_memory",
+]
+
+[[package]]
name = "crosvm_plugin"
version = "0.17.0"
dependencies = [
@@ -261,13 +303,10 @@ dependencies = [
"hypervisor",
"kvm_sys",
"libc",
- "libchromeos",
"libcras",
"libvda",
"linux_input_sys",
"minijail",
- "msg_on_socket_derive",
- "msg_socket",
"net_sys",
"net_util",
"p9",
@@ -277,9 +316,10 @@ dependencies = [
"remain",
"resources",
"rutabaga_gfx",
+ "serde",
+ "smallvec",
"sync",
"sys_util",
- "syscall_defines",
"tempfile",
"thiserror",
"tpm2",
@@ -289,6 +329,7 @@ dependencies = [
"virtio_sys",
"vm_control",
"vm_memory",
+ "vmm_vhost",
]
[[package]]
@@ -324,6 +365,12 @@ dependencies = [
]
[[package]]
+name = "fuchsia-cprng"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06f77d526c1a601b7c4cdd98f54b5eaabffc14d5f2f0296febdc7f357c6d3ba"
+
+[[package]]
name = "fuse"
version = "0.1.0"
dependencies = [
@@ -343,7 +390,6 @@ checksum = "b6f16056ecbb57525ff698bb955162d0cd03bee84e6241c27ff75c08d8ca5987"
dependencies = [
"futures-channel",
"futures-core",
- "futures-executor",
"futures-io",
"futures-sink",
"futures-task",
@@ -367,35 +413,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79564c427afefab1dfb3298535b21eda083ef7935b4f0ecbfcb121f0aec10866"
[[package]]
-name = "futures-executor"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1e274736563f686a837a0568b478bdabfeaec2dca794b5649b04e2fe1627c231"
-dependencies = [
- "futures-core",
- "futures-task",
- "futures-util",
-]
-
-[[package]]
name = "futures-io"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e676577d229e70952ab25f3945795ba5b16d63ca794ca9d2c860e5595d20b5ff"
[[package]]
-name = "futures-macro"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52e7c56c15537adb4f76d0b7a76ad131cb4d2f4f32d3b0bcabcbe1c7c5e87764"
-dependencies = [
- "proc-macro-hack",
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
name = "futures-sink"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -416,13 +439,10 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-io",
- "futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-utils",
- "proc-macro-hack",
- "proc-macro-nested",
"slab",
]
@@ -471,7 +491,7 @@ dependencies = [
"kvm",
"kvm_sys",
"libc",
- "msg_socket",
+ "serde",
"sync",
"vm_memory",
]
@@ -505,10 +525,15 @@ dependencies = [
"libc",
"sync",
"sys_util",
- "syscall_defines",
]
[[package]]
+name = "itoa"
+version = "0.4.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736"
+
+[[package]]
name = "kernel_cmdline"
version = "0.1.0"
dependencies = [
@@ -533,7 +558,6 @@ dependencies = [
"data_model",
"kvm_sys",
"libc",
- "msg_socket",
"sync",
"vm_memory",
]
@@ -549,9 +573,9 @@ dependencies = [
[[package]]
name = "libc"
-version = "0.2.81"
+version = "0.2.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1482821306169ec4d07f6aca392a4681f66c75c9918aa49641a2595db64053cb"
+checksum = "9385f66bf6105b241aa65a61cb923ef20efc665cb9f9bb50ac2f0c4b7f378d41"
[[package]]
name = "libchromeos"
@@ -559,11 +583,12 @@ version = "0.1.0"
dependencies = [
"data_model",
"futures",
- "intrusive-collections",
"libc",
"log",
"protobuf",
+ "sys_util",
"thiserror",
+ "zeroize",
]
[[package]]
@@ -578,6 +603,15 @@ dependencies = [
]
[[package]]
+name = "libcrosvm_control"
+version = "0.1.0"
+dependencies = [
+ "base",
+ "libc",
+ "vm_control",
+]
+
+[[package]]
name = "libdbus-sys"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -631,7 +665,7 @@ version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "043175f069eda7b85febe4a74abbaeff828d9f8b448515d3151a14a3542811aa"
dependencies = [
- "autocfg",
+ "autocfg 1.0.1",
]
[[package]]
@@ -651,29 +685,6 @@ dependencies = [
]
[[package]]
-name = "msg_on_socket_derive"
-version = "0.1.0"
-dependencies = [
- "base",
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
-name = "msg_socket"
-version = "0.1.0"
-dependencies = [
- "base",
- "cros_async",
- "data_model",
- "futures",
- "libc",
- "msg_on_socket_derive",
- "sync",
-]
-
-[[package]]
name = "net_sys"
version = "0.1.0"
dependencies = [
@@ -696,7 +707,7 @@ version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611"
dependencies = [
- "autocfg",
+ "autocfg 1.0.1",
]
[[package]]
@@ -755,23 +766,6 @@ dependencies = [
]
[[package]]
-name = "proc-macro-hack"
-version = "0.5.11"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ecd45702f76d6d3c75a80564378ae228a85f0b59d2f3ed43c91b4a69eb2ebfc5"
-dependencies = [
- "proc-macro2",
- "quote",
- "syn",
-]
-
-[[package]]
-name = "proc-macro-nested"
-version = "0.1.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "369a6ed065f249a159e06c45752c780bda2fb53c995718f9e484d08daa9eb42e"
-
-[[package]]
name = "proc-macro2"
version = "1.0.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -845,10 +839,125 @@ dependencies = [
]
[[package]]
+name = "rand"
+version = "0.6.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6d71dacdc3c88c1fde3885a3be3fbab9f35724e6ce99467f7d9c5026132184ca"
+dependencies = [
+ "autocfg 0.1.7",
+ "libc",
+ "rand_chacha",
+ "rand_core 0.4.2",
+ "rand_hc",
+ "rand_isaac",
+ "rand_jitter",
+ "rand_os",
+ "rand_pcg",
+ "rand_xorshift",
+ "winapi",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "556d3a1ca6600bfcbab7c7c91ccb085ac7fbbcd70e008a98742e7847f4f7bcef"
+dependencies = [
+ "autocfg 0.1.7",
+ "rand_core 0.3.1",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7a6fdeb83b075e8266dcc8762c22776f6877a63111121f5f8c7411e5be7eed4b"
+dependencies = [
+ "rand_core 0.4.2",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c33a3c44ca05fa6f1807d8e6743f3824e8509beca625669633be0acbdf509dc"
+
+[[package]]
+name = "rand_hc"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b40677c7be09ae76218dc623efbf7b18e34bced3f38883af07bb75630a21bc4"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
+[[package]]
+name = "rand_isaac"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ded997c9d5f13925be2a6fd7e66bf1872597f759fd9dd93513dd7e92e5a5ee08"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
+[[package]]
name = "rand_ish"
version = "0.1.0"
[[package]]
+name = "rand_jitter"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1166d5c91dc97b88d1decc3285bb0a99ed84b05cfd0bc2341bdf2d43fc41e39b"
+dependencies = [
+ "libc",
+ "rand_core 0.4.2",
+ "winapi",
+]
+
+[[package]]
+name = "rand_os"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b75f676a1e053fc562eafbb47838d67c84801e38fc1ba459e8f180deabd5071"
+dependencies = [
+ "cloudabi",
+ "fuchsia-cprng",
+ "libc",
+ "rand_core 0.4.2",
+ "rdrand",
+ "winapi",
+]
+
+[[package]]
+name = "rand_pcg"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "abf9b09b01790cfe0364f52bf32995ea3c39f4d2dd011eac241d2914146d0b44"
+dependencies = [
+ "autocfg 0.1.7",
+ "rand_core 0.4.2",
+]
+
+[[package]]
+name = "rand_xorshift"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cbf7e9e623549b0e21f6e97cf8ecf247c1a8fd2e8a992ae265314300b2455d5c"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
+[[package]]
+name = "rdrand"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "678054eb77286b51581ba43620cc911abf02758c91f93f479767aed0f90458b2"
+dependencies = [
+ "rand_core 0.3.1",
+]
+
+[[package]]
name = "remain"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -865,7 +974,7 @@ version = "0.1.0"
dependencies = [
"base",
"libc",
- "msg_socket",
+ "serde",
]
[[package]]
@@ -879,6 +988,12 @@ dependencies = [
]
[[package]]
+name = "ryu"
+version = "1.0.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e"
+
+[[package]]
name = "serde"
version = "1.0.121"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -899,12 +1014,29 @@ dependencies = [
]
[[package]]
+name = "serde_json"
+version = "1.0.64"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
+[[package]]
name = "slab"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c111b5bd5695e56cffe5129854aa230b39c93a305372fdbb2668ca2394eea9f8"
[[package]]
+name = "smallvec"
+version = "1.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e"
+
+[[package]]
name = "syn"
version = "1.0.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -920,6 +1052,18 @@ name = "sync"
version = "0.1.0"
[[package]]
+name = "synstructure"
+version = "0.12.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b834f2d66f734cb897113e34aaff2f1ab4719ca946f9a7358dba8f8064148701"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "unicode-xid",
+]
+
+[[package]]
name = "sys_util"
version = "0.1.0"
dependencies = [
@@ -927,16 +1071,13 @@ dependencies = [
"data_model",
"libc",
"poll_token_derive",
+ "serde",
+ "serde_json",
"sync",
- "syscall_defines",
"tempfile",
]
[[package]]
-name = "syscall_defines"
-version = "0.1.0"
-
-[[package]]
name = "tempfile"
version = "3.0.7"
dependencies = [
@@ -1044,9 +1185,9 @@ dependencies = [
"gdbstub",
"hypervisor",
"libc",
- "msg_socket",
"resources",
"rutabaga_gfx",
+ "serde",
"sync",
"vm_memory",
]
@@ -1056,13 +1197,45 @@ name = "vm_memory"
version = "0.1.0"
dependencies = [
"base",
+ "bitflags",
"cros_async",
"data_model",
"libc",
- "syscall_defines",
]
[[package]]
+name = "vmm_vhost"
+version = "0.1.0"
+dependencies = [
+ "bitflags",
+ "libc",
+ "sys_util",
+ "tempfile",
+]
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
+[[package]]
name = "wire_format_derive"
version = "0.1.0"
dependencies = [
@@ -1087,10 +1260,30 @@ dependencies = [
"kernel_loader",
"libc",
"minijail",
- "msg_socket",
"remain",
"resources",
"sync",
"vm_control",
"vm_memory",
]
+
+[[package]]
+name = "zeroize"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "81a974bcdd357f0dca4d41677db03436324d45a4c9ed2d0b873a5a360ce41c36"
+dependencies = [
+ "zeroize_derive",
+]
+
+[[package]]
+name = "zeroize_derive"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c3f369ddb18862aba61aa49bf31e74d29f0f162dec753063200e1dc084345d16"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+ "synstructure",
+]
diff --git a/Cargo.toml b/Cargo.toml
index 5ff22bde9..8c969aca8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,7 @@ name = "crosvm"
version = "0.1.0"
authors = ["The Chromium OS Authors"]
edition = "2018"
+default-run = "crosvm"
[lib]
path = "src/crosvm.rs"
@@ -11,14 +12,21 @@ path = "src/crosvm.rs"
name = "crosvm"
path = "src/main.rs"
+[[bin]]
+name = "crosvm-direct"
+path = "src/main.rs"
+required-features = [ "direct" ]
+
[profile.release]
panic = 'abort'
overflow-checks = true
[workspace]
members = [
+ "fuzz",
"qcow_utils",
"integration_tests",
+ "libcrosvm_control",
]
exclude = [
"assertions",
@@ -28,7 +36,6 @@ exclude = [
"rand_ish",
"sync",
"sys_util",
- "syscall_defines",
"tempfile",
"vm_memory",
]
@@ -37,6 +44,7 @@ exclude = [
default = ["audio", "gpu"]
chromeos = ["base/chromeos"]
default-no-sandbox = []
+direct = ["devices/direct"]
audio = ["devices/audio"]
gpu = ["devices/gpu"]
plugin = ["protos/plugin", "crosvm_plugin", "kvm", "kvm_sys", "protobuf"]
@@ -70,10 +78,9 @@ kernel_cmdline = { path = "kernel_cmdline" }
kernel_loader = { path = "kernel_loader" }
kvm = { path = "kvm", optional = true }
kvm_sys = { path = "kvm_sys", optional = true }
-libc = "0.2.65"
+libc = "0.2.93"
libcras = "*"
minijail = "*" # provided by ebuild
-msg_socket = { path = "msg_socket" }
net_util = { path = "net_util" }
p9 = { path = "../vm_tools/p9" }
protobuf = { version = "2.3", optional = true }
@@ -102,13 +109,14 @@ base = "*"
assertions = { path = "assertions" }
audio_streams = { path = "../adhd/audio_streams" } # ignored by ebuild
base = { path = "base" }
+cros_fuzz = { path = "../../platform2/cros-fuzz" } # ignored by ebuild
data_model = { path = "data_model" }
libchromeos = { path = "../libchromeos-rs" } # ignored by ebuild
libcras = { path = "../adhd/cras/client/libcras" } # ignored by ebuild
minijail = { path = "../minijail/rust/minijail" } # ignored by ebuild
p9 = { path = "../vm_tools/p9" } # ignored by ebuild
sync = { path = "sync" }
-syscall_defines = { path = "syscall_defines" }
sys_util = { path = "sys_util" }
tempfile = { path = "tempfile" }
wire_format_derive = { path = "../vm_tools/p9/wire_format_derive" } # ignored by ebuild
+vmm_vhost = { path = "../rust/crates/vhost", features = ["vhost-user-master"] } # ignored by ebuild
diff --git a/OWNERS b/OWNERS
index 11582f140..6a326b06d 100644
--- a/OWNERS
+++ b/OWNERS
@@ -1,8 +1,5 @@
adelva@google.com
chirantan@google.com
dgreid@google.com
-smbarber@chromium.org
-zachr@chromium.org
-
-# So any team members can +2
-*
+dverkamp@google.com
+zachr@google.com
diff --git a/aarch64/src/lib.rs b/aarch64/src/lib.rs
index 4acef3dc2..4a3a35ecf 100644
--- a/aarch64/src/lib.rs
+++ b/aarch64/src/lib.rs
@@ -4,10 +4,8 @@
use std::collections::BTreeMap;
use std::error::Error as StdError;
-use std::ffi::{CStr, CString};
use std::fmt::{self, Display};
-use std::fs::File;
-use std::io::{self, Seek};
+use std::io::{self};
use std::sync::Arc;
use arch::{
@@ -15,13 +13,8 @@ use arch::{
VmComponents, VmImage,
};
use base::Event;
-use devices::{
- Bus, BusError, IrqChip, IrqChipAArch64, PciAddress, PciConfigMmio, PciDevice, PciInterruptPin,
- ProtectionType,
-};
-use hypervisor::{
- DeviceKind, Hypervisor, HypervisorCap, PsciVersion, VcpuAArch64, VcpuFeature, VmAArch64,
-};
+use devices::{Bus, BusError, IrqChip, IrqChipAArch64, PciConfigMmio, PciDevice, ProtectionType};
+use hypervisor::{DeviceKind, Hypervisor, HypervisorCap, VcpuAArch64, VcpuFeature, VmAArch64};
use minijail::Minijail;
use remain::sorted;
use resources::SystemAllocator;
@@ -222,13 +215,19 @@ pub struct AArch64;
impl arch::LinuxArch for AArch64 {
type Error = Error;
- fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>(
+ fn guest_memory_layout(
+ components: &VmComponents,
+ ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error> {
+ Ok(arch_memory_regions(components.memory_size))
+ }
+
+ fn build_vm<V, Vcpu, I, FD, FI, E1, E2>(
mut components: VmComponents,
serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>,
serial_jail: Option<Minijail>,
_battery: (&Option<BatteryType>, Option<Minijail>),
+ mut vm: V,
create_devices: FD,
- create_vm: FV,
create_irq_chip: FI,
) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error>
where
@@ -241,20 +240,17 @@ impl arch::LinuxArch for AArch64 {
&mut SystemAllocator,
&Event,
) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>,
- FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>,
- FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>,
+ FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E2>,
E1: StdError + 'static,
E2: StdError + 'static,
- E3: StdError + 'static,
{
let has_bios = match components.vm_image {
VmImage::Bios(_) => true,
_ => false,
};
+ let mem = vm.get_memory().clone();
let mut resources = Self::get_resource_allocator(components.memory_size);
- let mem = Self::setup_memory(components.memory_size)?;
- let mut vm = create_vm(mem.clone()).map_err(|e| Error::CreateVm(Box::new(e)))?;
if components.protected_vm == ProtectionType::Protected {
vm.enable_protected_vm(
@@ -270,7 +266,7 @@ impl arch::LinuxArch for AArch64 {
let vcpu_count = components.vcpu_count;
let mut vcpus = Vec::with_capacity(vcpu_count);
for vcpu_id in 0..vcpu_count {
- let vcpu = *vm
+ let vcpu: Vcpu = *vm
.create_vcpu(vcpu_id)
.map_err(Error::CreateVcpu)?
.downcast::<Vcpu>()
@@ -429,12 +425,6 @@ impl arch::LinuxArch for AArch64 {
}
impl AArch64 {
- fn setup_memory(mem_size: u64) -> Result<GuestMemory> {
- let arch_mem_regions = arch_memory_regions(mem_size);
- let mem = GuestMemory::new(&arch_mem_regions).map_err(Error::SetupGuestMemory)?;
- Ok(mem)
- }
-
fn get_high_mmio_base_size(mem_size: u64) -> (u64, u64) {
let base = AARCH64_PHYS_MEM_START + mem_size;
let size = u64::max_value() - base;
@@ -507,31 +497,27 @@ impl AArch64 {
}
vcpu.init(&features).map_err(Error::VcpuInit)?;
- // set up registers
- let mut data: u64;
- let mut reg_id: u64;
-
// All interrupts masked
- data = PSR_D_BIT | PSR_A_BIT | PSR_I_BIT | PSR_F_BIT | PSR_MODE_EL1H;
- reg_id = arm64_core_reg!(pstate);
- vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?;
+ let pstate = PSR_D_BIT | PSR_A_BIT | PSR_I_BIT | PSR_F_BIT | PSR_MODE_EL1H;
+ vcpu.set_one_reg(arm64_core_reg!(pstate), pstate)
+ .map_err(Error::SetReg)?;
// Other cpus are powered off initially
if vcpu_id == 0 {
- if has_bios {
- data = AARCH64_PHYS_MEM_START + AARCH64_BIOS_OFFSET;
+ let entry_addr = if has_bios {
+ AARCH64_PHYS_MEM_START + AARCH64_BIOS_OFFSET
} else {
- data = AARCH64_PHYS_MEM_START + AARCH64_KERNEL_OFFSET;
- }
- reg_id = arm64_core_reg!(pc);
- vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?;
+ AARCH64_PHYS_MEM_START + AARCH64_KERNEL_OFFSET
+ };
+ vcpu.set_one_reg(arm64_core_reg!(pc), entry_addr)
+ .map_err(Error::SetReg)?;
/* X0 -- fdt address */
let mem_size = guest_mem.memory_size();
- data = (AARCH64_PHYS_MEM_START + fdt_offset(mem_size, has_bios)) as u64;
+ let fdt_addr = (AARCH64_PHYS_MEM_START + fdt_offset(mem_size, has_bios)) as u64;
// hack -- can't get this to do offsetof(regs[0]) but luckily it's at offset 0
- reg_id = arm64_core_reg!(regs);
- vcpu.set_one_reg(reg_id, data).map_err(Error::SetReg)?;
+ vcpu.set_one_reg(arm64_core_reg!(regs), fdt_addr)
+ .map_err(Error::SetReg)?;
}
Ok(())
diff --git a/arch/Android.bp b/arch/Android.bp
index 5e5c40ddd..943e0750e 100644
--- a/arch/Android.bp
+++ b/arch/Android.bp
@@ -1,5 +1,4 @@
-// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb.
-// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this
+// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace.
package {
// See: http://go/android-license-faq
@@ -18,17 +17,6 @@ rust_defaults {
test_suites: ["general-tests"],
auto_gen_config: true,
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- ],
- },
- },
rustlibs: [
"libacpi_tables",
"libbase_rust",
@@ -70,17 +58,6 @@ rust_library {
crate_name: "arch",
srcs: ["src/lib.rs"],
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- ],
- },
- },
rustlibs: [
"libacpi_tables",
"libbase_rust",
@@ -149,11 +126,10 @@ rust_library {
// ../vm_control/src/lib.rs
// ../vm_memory/src/lib.rs
// async-task-4.0.3 "default,std"
-// async-trait-0.1.48
+// async-trait-0.1.45
// autocfg-1.0.1
// base-0.1.0
// bitflags-1.2.1 "default"
-// cfg-if-0.1.10
// cfg-if-1.0.0
// downcast-rs-1.2.0 "default,std"
// futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std"
@@ -165,15 +141,12 @@ rust_library {
// futures-sink-0.3.13 "alloc,std"
// futures-task-0.3.13 "alloc,std"
// futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std"
-// gdbstub-0.4.4 "alloc,default,std"
// getrandom-0.2.2 "std"
// intrusive-collections-0.9.0 "alloc,default"
-// libc-0.2.88 "default,std"
+// libc-0.2.87 "default,std"
// log-0.4.14
-// managed-0.8.0 "alloc"
// memchr-2.3.4 "default,std"
// memoffset-0.5.6 "default"
-// num-traits-0.2.14
// paste-1.0.4
// pin-project-lite-0.2.6
// pin-utils-0.1.0
@@ -189,10 +162,10 @@ rust_library {
// rand_core-0.6.2 "alloc,getrandom,std"
// remain-0.2.2
// remove_dir_all-0.5.3
-// serde-1.0.124 "default,derive,serde_derive,std"
-// serde_derive-1.0.124 "default"
+// serde-1.0.123 "default,derive,serde_derive,std"
+// serde_derive-1.0.123 "default"
// slab-0.4.2
-// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
+// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
// tempfile-3.2.0
// thiserror-1.0.24
// thiserror-impl-1.0.24
diff --git a/arch/Cargo.toml b/arch/Cargo.toml
index b72d11741..095cc480c 100644
--- a/arch/Cargo.toml
+++ b/arch/Cargo.toml
@@ -16,7 +16,6 @@ hypervisor = { path = "../hypervisor" }
kernel_cmdline = { path = "../kernel_cmdline" }
libc = "*"
minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild
-msg_socket = { path = "../msg_socket" }
resources = { path = "../resources" }
sync = { path = "../sync" }
base = { path = "../base" }
diff --git a/arch/src/fdt.rs b/arch/src/fdt.rs
index f5c82ddd6..944e5025a 100644
--- a/arch/src/fdt.rs
+++ b/arch/src/fdt.rs
@@ -3,7 +3,7 @@
// found in the LICENSE file.
//! This module writes Flattened Devicetree blobs as defined here:
-//! https://devicetree-specification.readthedocs.io/en/stable/flattened-format.html
+//! <https://devicetree-specification.readthedocs.io/en/stable/flattened-format.html>
use std::collections::BTreeMap;
use std::convert::TryInto;
@@ -48,7 +48,7 @@ const FDT_END_NODE: u32 = 0x00000002;
const FDT_PROP: u32 = 0x00000003;
const FDT_END: u32 = 0x00000009;
-/// Interface for writing a Flattened Devicetree (FDT) and emitting a Devicetree Blob (FDT).
+/// Interface for writing a Flattened Devicetree (FDT) and emitting a Devicetree Blob (DTB).
///
/// # Example
///
diff --git a/arch/src/lib.rs b/arch/src/lib.rs
index c59a2cd13..5070e728c 100644
--- a/arch/src/lib.rs
+++ b/arch/src/lib.rs
@@ -17,7 +17,7 @@ use std::sync::Arc;
use acpi_tables::aml::Aml;
use acpi_tables::sdt::SDT;
-use base::{syslog, AsRawDescriptor, Event};
+use base::{syslog, AsRawDescriptor, Event, Tube};
use devices::virtio::VirtioDevice;
use devices::{
Bus, BusDevice, BusError, IrqChip, PciAddress, PciDevice, PciDeviceError, PciInterruptPin,
@@ -27,11 +27,7 @@ use hypervisor::{IoEventAddress, Vm};
use minijail::Minijail;
use resources::{MmioType, SystemAllocator};
use sync::Mutex;
-#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
-use vm_control::VmControlRequestSocket;
-use vm_control::{
- BatControl, BatControlCommand, BatControlRequestSocket, BatControlResult, BatteryType,
-};
+use vm_control::{BatControl, BatteryType};
use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError};
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
@@ -83,6 +79,7 @@ pub struct VmComponents {
pub vcpu_count: usize,
pub vcpu_affinity: Option<VcpuAffinity>,
pub no_smt: bool,
+ pub hugepages: bool,
pub vm_image: VmImage,
pub android_fstab: Option<File>,
pub pstore: Option<Pstore>,
@@ -93,7 +90,8 @@ pub struct VmComponents {
pub rt_cpus: Vec<usize>,
pub protected_vm: ProtectionType,
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
- pub gdb: Option<(u32, VmControlRequestSocket)>, // port and control socket.
+ pub gdb: Option<(u32, Tube)>, // port and control tube.
+ pub dmi_path: Option<PathBuf>,
}
/// Holds the elements needed to run a Linux VM. Created by `build_vm`.
@@ -116,7 +114,7 @@ pub struct RunnableLinuxVm<V: VmArch, Vcpu: VcpuArch, I: IrqChipArch> {
pub rt_cpus: Vec<usize>,
pub bat_control: Option<BatControl>,
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
- pub gdb: Option<(u32, VmControlRequestSocket)>,
+ pub gdb: Option<(u32, Tube)>,
}
/// The device and optional jail.
@@ -130,6 +128,16 @@ pub struct VirtioDeviceStub {
pub trait LinuxArch {
type Error: StdError;
+ /// Returns a Vec of the valid memory addresses as pairs of address and length. These should be
+ /// used to configure the `GuestMemory` structure for the platform.
+ ///
+ /// # Arguments
+ ///
+ /// * `components` - Parts used to determine the memory layout.
+ fn guest_memory_layout(
+ components: &VmComponents,
+ ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error>;
+
/// Takes `VmComponents` and generates a `RunnableLinuxVm`.
///
/// # Arguments
@@ -138,15 +146,14 @@ pub trait LinuxArch {
/// * `serial_parameters` - definitions for how the serial devices should be configured.
/// * `battery` - defines what battery device will be created.
/// * `create_devices` - Function to generate a list of devices.
- /// * `create_vm` - Function to generate a VM.
/// * `create_irq_chip` - Function to generate an IRQ chip.
- fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>(
+ fn build_vm<V, Vcpu, I, FD, FI, E1, E2>(
components: VmComponents,
serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>,
serial_jail: Option<Minijail>,
battery: (&Option<BatteryType>, Option<Minijail>),
+ vm: V,
create_devices: FD,
- create_vm: FV,
create_irq_chip: FI,
) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error>
where
@@ -159,11 +166,9 @@ pub trait LinuxArch {
&mut SystemAllocator,
&Event,
) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>,
- FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>,
- FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>,
+ FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E2>,
E1: StdError + 'static,
- E2: StdError + 'static,
- E3: StdError + 'static;
+ E2: StdError + 'static;
/// Configures the vcpu and should be called once per vcpu from the vcpu's thread.
///
@@ -240,8 +245,8 @@ pub enum DeviceRegistrationError {
CreatePipe(base::Error),
// Unable to create serial device from serial parameters
CreateSerialDevice(serial::Error),
- // Unable to create socket
- CreateSocket(io::Error),
+ // Unable to create tube
+ CreateTube(base::TubeError),
/// Could not clone an event.
EventClone(base::Error),
/// Could not create an event.
@@ -279,7 +284,7 @@ impl Display for DeviceRegistrationError {
AllocateIrq => write!(f, "Allocating IRQ number"),
CreatePipe(e) => write!(f, "failed to create pipe: {}", e),
CreateSerialDevice(e) => write!(f, "failed to create serial device: {}", e),
- CreateSocket(e) => write!(f, "failed to create socket: {}", e),
+ CreateTube(e) => write!(f, "failed to create tube: {}", e),
Cmdline(e) => write!(f, "unable to add device to kernel command line: {}", e),
EventClone(e) => write!(f, "failed to clone event: {}", e),
EventCreate(e) => write!(f, "failed to create event: {}", e),
@@ -434,7 +439,7 @@ pub fn add_goldfish_battery(
irq_chip: &mut impl IrqChip,
irq_num: u32,
resources: &mut SystemAllocator,
-) -> Result<BatControlRequestSocket, DeviceRegistrationError> {
+) -> Result<Tube, DeviceRegistrationError> {
let alloc = resources.get_anon_alloc();
let mmio_base = resources
.mmio_allocator(MmioType::Low)
@@ -453,9 +458,8 @@ pub fn add_goldfish_battery(
.register_irq_event(irq_num, &irq_evt, Some(&irq_resample_evt))
.map_err(DeviceRegistrationError::RegisterIrqfd)?;
- let (control_socket, response_socket) =
- msg_socket::pair::<BatControlCommand, BatControlResult>()
- .map_err(DeviceRegistrationError::CreateSocket)?;
+ let (control_tube, response_tube) =
+ Tube::pair().map_err(DeviceRegistrationError::CreateTube)?;
#[cfg(feature = "power-monitor-powerd")]
let create_monitor = Some(Box::new(power_monitor::powerd::DBusMonitor::connect)
@@ -469,7 +473,7 @@ pub fn add_goldfish_battery(
irq_num,
irq_evt,
irq_resample_evt,
- response_socket,
+ response_tube,
create_monitor,
)
.map_err(DeviceRegistrationError::RegisterBattery)?;
@@ -501,7 +505,7 @@ pub fn add_goldfish_battery(
}
}
- Ok(control_socket)
+ Ok(control_tube)
}
/// Errors for image loading.
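
The trait change above removes the create_vm closure from build_vm: an architecture now only reports its memory layout through guest_memory_layout, and the caller builds the GuestMemory and the hypervisor VM itself before passing the VM into build_vm directly. A minimal sketch of the caller's half of that split, assuming only the signatures shown above (the String error mapping is purely illustrative):

    use arch::{LinuxArch, VmComponents};
    use vm_memory::GuestMemory;

    // Sketch only: A is any LinuxArch implementor (e.g. the AArch64 struct above).
    fn create_guest_memory<A: LinuxArch>(components: &VmComponents) -> Result<GuestMemory, String> {
        // Ask the architecture for its (GuestAddress, size) regions...
        let regions = A::guest_memory_layout(components)
            .map_err(|e| format!("memory layout: {}", e))?;
        // ...and build guest memory up front; the hypervisor VM created from it
        // is then handed to A::build_vm(..., vm, create_devices, create_irq_chip).
        GuestMemory::new(&regions).map_err(|e| format!("guest memory: {}", e))
    }
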
diff --git a/arch/src/pstore.rs b/arch/src/pstore.rs
index 3a12aab8c..9335a1d46 100644
--- a/arch/src/pstore.rs
+++ b/arch/src/pstore.rs
@@ -63,7 +63,7 @@ pub fn create_memory_region(
.map_err(Error::ResourcesError)?;
let memory_mapping = MemoryMappingBuilder::new(pstore.size as usize)
- .from_descriptor(&file)
+ .from_file(&file)
.build()
.map_err(Error::MmapError)?;
diff --git a/base/Cargo.toml b/base/Cargo.toml
index 2520d4256..aecbb3eaa 100644
--- a/base/Cargo.toml
+++ b/base/Cargo.toml
@@ -11,5 +11,9 @@ chromeos = ["sys_util/chromeos"]
cros_async = { path = "../cros_async" }
data_model = { path = "../data_model" }
libc = "*"
+serde = { version = "1", features = [ "derive" ] }
+serde_json = "*"
+smallvec = "1.6.1"
sync = { path = "../sync" }
sys_util = { path = "../sys_util" }
+thiserror = "1.0.20"
diff --git a/base/src/event.rs b/base/src/event.rs
index 4a6c0cdf0..4f34914a8 100644
--- a/base/src/event.rs
+++ b/base/src/event.rs
@@ -8,13 +8,16 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
use std::ptr;
use std::time::Duration;
+use serde::{Deserialize, Serialize};
+
use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, Result};
use sys_util::EventFd;
pub use sys_util::EventReadResult;
/// See [EventFd](sys_util::EventFd) for struct- and method-level
/// documentation.
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(transparent)]
pub struct Event(pub EventFd);
impl Event {
pub fn new() -> Result<Event> {
diff --git a/base/src/lib.rs b/base/src/lib.rs
index 2e9b2ebfa..b562d48d3 100644
--- a/base/src/lib.rs
+++ b/base/src/lib.rs
@@ -10,6 +10,7 @@ mod ioctl;
mod mmap;
mod shm;
mod timer;
+mod tube;
mod wait_context;
pub use async_types::*;
@@ -18,7 +19,7 @@ pub use ioctl::{
ioctl, ioctl_with_mut_ptr, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref, ioctl_with_val,
};
pub use mmap::Unix as MemoryMappingUnix;
-pub use mmap::{MemoryMapping, MemoryMappingBuilder};
+pub use mmap::{MemoryMapping, MemoryMappingBuilder, MemoryMappingBuilderUnix};
pub use shm::{SharedMemory, Unix as SharedMemoryUnix};
pub use sys_util::ioctl::*;
pub use sys_util::sched::*;
@@ -28,6 +29,7 @@ pub use sys_util::{
};
pub use sys_util::{SeekHole, WriteZeroesAt};
pub use timer::{FakeTimer, Timer};
+pub use tube::{AsyncTube, Error as TubeError, Result as TubeResult, Tube};
pub use wait_context::{EventToken, EventType, TriggeredEvent, WaitContext};
/// Wraps an AsRawDescriptor in the simple Descriptor struct, which
diff --git a/base/src/mmap.rs b/base/src/mmap.rs
index 85b6acd1b..d2bd4cb33 100644
--- a/base/src/mmap.rs
+++ b/base/src/mmap.rs
@@ -2,9 +2,10 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use crate::{wrap_descriptor, AsRawDescriptor, MappedRegion, MmapError, Protection};
+use crate::{wrap_descriptor, AsRawDescriptor, MappedRegion, MmapError, Protection, SharedMemory};
use data_model::volatile_memory::*;
use data_model::DataInit;
+use std::fs::File;
use sys_util::MemoryMapping as SysUtilMmap;
pub type Result<T> = std::result::Result<T, MmapError>;
@@ -12,26 +13,33 @@ pub type Result<T> = std::result::Result<T, MmapError>;
/// See [MemoryMapping](sys_util::MemoryMapping) for struct- and method-level
/// documentation.
#[derive(Debug)]
-pub struct MemoryMapping(SysUtilMmap);
+pub struct MemoryMapping {
+ mapping: SysUtilMmap,
+}
+
impl MemoryMapping {
pub fn write_slice(&self, buf: &[u8], offset: usize) -> Result<usize> {
- self.0.write_slice(buf, offset)
+ self.mapping.write_slice(buf, offset)
}
pub fn read_slice(&self, buf: &mut [u8], offset: usize) -> Result<usize> {
- self.0.read_slice(buf, offset)
+ self.mapping.read_slice(buf, offset)
}
pub fn write_obj<T: DataInit>(&self, val: T, offset: usize) -> Result<()> {
- self.0.write_obj(val, offset)
+ self.mapping.write_obj(val, offset)
}
pub fn read_obj<T: DataInit>(&self, offset: usize) -> Result<T> {
- self.0.read_obj(offset)
+ self.mapping.read_obj(offset)
}
pub fn msync(&self) -> Result<()> {
- self.0.msync()
+ self.mapping.msync()
+ }
+
+ pub fn use_hugepages(&self) -> Result<()> {
+ self.mapping.use_hugepages()
}
pub fn read_to_memory(
@@ -40,7 +48,7 @@ impl MemoryMapping {
src: &dyn AsRawDescriptor,
count: usize,
) -> Result<()> {
- self.0
+ self.mapping
.read_to_memory(mem_offset, &wrap_descriptor(src), count)
}
@@ -50,7 +58,7 @@ impl MemoryMapping {
dst: &dyn AsRawDescriptor,
count: usize,
) -> Result<()> {
- self.0
+ self.mapping
.write_from_memory(mem_offset, &wrap_descriptor(dst), count)
}
}
@@ -61,10 +69,14 @@ pub trait Unix {
impl Unix for MemoryMapping {
fn remove_range(&self, mem_offset: usize, count: usize) -> Result<()> {
- self.0.remove_range(mem_offset, count)
+ self.mapping.remove_range(mem_offset, count)
}
}
+pub trait MemoryMappingBuilderUnix<'a> {
+ fn from_descriptor(self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder;
+}
+
pub struct MemoryMappingBuilder<'a> {
descriptor: Option<&'a dyn AsRawDescriptor>,
size: usize,
@@ -73,6 +85,16 @@ pub struct MemoryMappingBuilder<'a> {
populate: bool,
}
+impl<'a> MemoryMappingBuilderUnix<'a> for MemoryMappingBuilder<'a> {
+ /// Build the memory mapping given the specified descriptor to mapped memory
+ ///
+ /// Default: Create a new memory mapping.
+ fn from_descriptor(mut self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder {
+ self.descriptor = Some(descriptor);
+ self
+ }
+}
+
/// Builds a MemoryMapping object from the specified arguments.
impl<'a> MemoryMappingBuilder<'a> {
/// Creates a new builder specifying size of the memory region in bytes.
@@ -86,11 +108,23 @@ impl<'a> MemoryMappingBuilder<'a> {
}
}
- /// Build the memory mapping given the specified descriptor to mapped memory
+ /// Build the memory mapping given the specified File to mapped memory
///
/// Default: Create a new memory mapping.
- pub fn from_descriptor(mut self, descriptor: &'a dyn AsRawDescriptor) -> MemoryMappingBuilder {
- self.descriptor = Some(descriptor);
+ ///
+    /// Note: this is a forward-looking interface to accommodate platforms that
+ /// require special handling for file backed mappings.
+ #[allow(unused_mut)]
+ pub fn from_file(mut self, file: &'a File) -> MemoryMappingBuilder {
+ self.descriptor = Some(file as &dyn AsRawDescriptor);
+ self
+ }
+
+ /// Build the memory mapping given the specified SharedMemory to mapped memory
+ ///
+ /// Default: Create a new memory mapping.
+ pub fn from_shared_memory(mut self, shm: &'a SharedMemory) -> MemoryMappingBuilder {
+ self.descriptor = Some(shm as &dyn AsRawDescriptor);
self
}
@@ -176,23 +210,23 @@ impl<'a> MemoryMappingBuilder<'a> {
}
fn wrap(result: Result<SysUtilMmap>) -> Result<MemoryMapping> {
- result.map(MemoryMapping)
+ result.map(|mapping| MemoryMapping { mapping })
}
}
impl VolatileMemory for MemoryMapping {
fn get_slice(&self, offset: usize, count: usize) -> VolatileMemoryResult<VolatileSlice> {
- self.0.get_slice(offset, count)
+ self.mapping.get_slice(offset, count)
}
}
// Safe because it exclusively forwards calls to a safe implementation.
unsafe impl MappedRegion for MemoryMapping {
fn as_ptr(&self) -> *mut u8 {
- self.0.as_ptr()
+ self.mapping.as_ptr()
}
fn size(&self) -> usize {
- self.0.size()
+ self.mapping.size()
}
}
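
With the builder reworked above, the backing object is now chosen by type: from_file for a File, from_shared_memory for a SharedMemory, and from_descriptor (moved into the Unix-only MemoryMappingBuilderUnix trait) for raw descriptors. A minimal sketch of the File-backed path, mirroring the pstore.rs change earlier in this patch; it assumes MmapError is re-exported at the base crate root, as the `use crate::` line in mmap.rs implies:

    use std::fs::File;

    use base::{MemoryMapping, MemoryMappingBuilder, MmapError};

    // Sketch only: map `len` bytes of `file` using the new builder entry point.
    fn map_file(file: &File, len: usize) -> Result<MemoryMapping, MmapError> {
        MemoryMappingBuilder::new(len) // size of the region in bytes
            .from_file(file)           // file-backed mapping (replaces from_descriptor for Files)
            .build()
    }
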
diff --git a/base/src/shm.rs b/base/src/shm.rs
index 36c6f7e1a..a445f3d62 100644
--- a/base/src/shm.rs
+++ b/base/src/shm.rs
@@ -8,11 +8,15 @@ use crate::{
};
use std::ffi::CStr;
use std::fs::File;
-use std::os::unix::io::AsRawFd;
+use std::os::unix::io::{AsRawFd, IntoRawFd};
+
+use serde::{Deserialize, Serialize};
use sys_util::SharedMemory as SysUtilSharedMemory;
/// See [SharedMemory](sys_util::SharedMemory) for struct- and method-level
/// documentation.
+#[derive(Serialize, Deserialize)]
+#[serde(transparent)]
pub struct SharedMemory(SysUtilSharedMemory);
impl SharedMemory {
pub fn named<T: Into<Vec<u8>>>(name: T, size: u64) -> Result<SharedMemory> {
@@ -77,8 +81,15 @@ impl AsRawDescriptor for SharedMemory {
}
}
-impl Into<File> for SharedMemory {
- fn into(self) -> File {
- self.0.into()
+impl IntoRawDescriptor for SharedMemory {
+ fn into_raw_descriptor(self) -> RawDescriptor {
+ self.0.into_raw_fd()
+ }
+}
+
+impl Into<SafeDescriptor> for SharedMemory {
+ fn into(self) -> SafeDescriptor {
+ // Safe because we own the SharedMemory at this point.
+ unsafe { SafeDescriptor::from_raw_descriptor(self.into_raw_descriptor()) }
}
}
diff --git a/base/src/tube.rs b/base/src/tube.rs
new file mode 100644
index 000000000..21917c931
--- /dev/null
+++ b/base/src/tube.rs
@@ -0,0 +1,239 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::io::{self, IoSlice};
+use std::marker::PhantomData;
+use std::ops::Deref;
+use std::os::unix::prelude::{AsRawFd, RawFd};
+use std::time::Duration;
+
+use crate::{net::UnixSeqpacket, FromRawDescriptor, SafeDescriptor, ScmSocket, UnsyncMarker};
+
+use cros_async::{Executor, IntoAsync, IoSourceExt};
+use serde::{de::DeserializeOwned, Serialize};
+use sys_util::{
+ deserialize_with_descriptors, AsRawDescriptor, RawDescriptor, SerializeDescriptors,
+};
+use thiserror::Error as ThisError;
+
+#[derive(ThisError, Debug)]
+pub enum Error {
+ #[error("failed to serialize/deserialize json from packet: {0}")]
+ Json(serde_json::Error),
+ #[error("failed to send packet: {0}")]
+ Send(sys_util::Error),
+ #[error("failed to receive packet: {0}")]
+ Recv(io::Error),
+ #[error("tube was disconnected")]
+ Disconnected,
+    #[error("failed to create tube pair: {0}")]
+ Pair(io::Error),
+ #[error("failed to set send timeout: {0}")]
+ SetSendTimeout(io::Error),
+ #[error("failed to set recv timeout: {0}")]
+ SetRecvTimeout(io::Error),
+ #[error("failed to create async tube: {0}")]
+ CreateAsync(cros_async::AsyncError),
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Bidirectional tube that supports both send and recv.
+pub struct Tube {
+ socket: UnixSeqpacket,
+ _unsync_marker: UnsyncMarker,
+}
+
+impl Tube {
+    /// Create a pair of connected tubes. A request is sent in one direction while the response is in the
+ /// other direction.
+ pub fn pair() -> Result<(Tube, Tube)> {
+ let (socket1, socket2) = UnixSeqpacket::pair().map_err(Error::Pair)?;
+ let tube1 = Tube::new(socket1);
+ let tube2 = Tube::new(socket2);
+ Ok((tube1, tube2))
+ }
+
+ // Create a new `Tube`.
+ pub fn new(socket: UnixSeqpacket) -> Tube {
+ Tube {
+ socket,
+ _unsync_marker: PhantomData,
+ }
+ }
+
+ pub fn into_async_tube(self, ex: &Executor) -> Result<AsyncTube> {
+ let inner = ex.async_from(self).map_err(Error::CreateAsync)?;
+ Ok(AsyncTube { inner })
+ }
+
+ pub fn send<T: Serialize>(&self, msg: &T) -> Result<()> {
+ let msg_serialize = SerializeDescriptors::new(&msg);
+ let msg_json = serde_json::to_vec(&msg_serialize).map_err(Error::Json)?;
+ let msg_descriptors = msg_serialize.into_descriptors();
+
+ self.socket
+ .send_with_fds(&[IoSlice::new(&msg_json)], &msg_descriptors)
+ .map_err(Error::Send)?;
+ Ok(())
+ }
+
+ pub fn recv<T: DeserializeOwned>(&self) -> Result<T> {
+ let (msg_json, msg_descriptors) =
+ self.socket.recv_as_vec_with_fds().map_err(Error::Recv)?;
+
+ if msg_json.is_empty() {
+ return Err(Error::Disconnected);
+ }
+
+ let mut msg_descriptors_safe = msg_descriptors
+ .into_iter()
+ .map(|v| {
+ Some(unsafe {
+ // Safe because the socket returns new fds that are owned locally by this scope.
+ SafeDescriptor::from_raw_descriptor(v)
+ })
+ })
+ .collect();
+
+ deserialize_with_descriptors(
+ || serde_json::from_slice(&msg_json),
+ &mut msg_descriptors_safe,
+ )
+ .map_err(Error::Json)
+ }
+
+ /// Returns true if there is a packet ready to `recv` without blocking.
+ ///
+ /// If there is an error trying to determine if there is a packet ready, this returns false.
+ pub fn is_packet_ready(&self) -> bool {
+ self.socket.get_readable_bytes().unwrap_or(0) > 0
+ }
+
+ pub fn set_send_timeout(&self, timeout: Option<Duration>) -> Result<()> {
+ self.socket
+ .set_write_timeout(timeout)
+ .map_err(Error::SetSendTimeout)
+ }
+
+ pub fn set_recv_timeout(&self, timeout: Option<Duration>) -> Result<()> {
+ self.socket
+ .set_read_timeout(timeout)
+ .map_err(Error::SetRecvTimeout)
+ }
+}
+
+impl AsRawDescriptor for Tube {
+ fn as_raw_descriptor(&self) -> RawDescriptor {
+ self.socket.as_raw_descriptor()
+ }
+}
+
+impl AsRawFd for Tube {
+ fn as_raw_fd(&self) -> RawFd {
+ self.socket.as_raw_fd()
+ }
+}
+
+impl IntoAsync for Tube {}
+
+pub struct AsyncTube {
+ inner: Box<dyn IoSourceExt<Tube>>,
+}
+
+impl AsyncTube {
+ pub async fn next<T: DeserializeOwned>(&self) -> Result<T> {
+ self.inner.wait_readable().await.unwrap();
+ self.inner.as_source().recv()
+ }
+}
+
+impl Deref for AsyncTube {
+ type Target = Tube;
+
+ fn deref(&self) -> &Self::Target {
+ self.inner.as_source()
+ }
+}
+
+impl Into<Tube> for AsyncTube {
+ fn into(self) -> Tube {
+ self.inner.into_source()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::Event;
+
+ use std::collections::HashMap;
+ use std::time::Duration;
+
+ use serde::{Deserialize, Serialize};
+
+ #[track_caller]
+ fn test_event_pair(send: Event, mut recv: Event) {
+ send.write(1).unwrap();
+ recv.read_timeout(Duration::from_secs(1)).unwrap();
+ }
+
+ #[test]
+ fn send_recv_no_fd() {
+ let (s1, s2) = Tube::pair().unwrap();
+
+ let test_msg = "hello world";
+ s1.send(&test_msg).unwrap();
+ let recv_msg: String = s2.recv().unwrap();
+
+ assert_eq!(test_msg, recv_msg);
+ }
+
+ #[test]
+ fn send_recv_one_fd() {
+ #[derive(Serialize, Deserialize)]
+ struct EventStruct {
+ x: u32,
+ b: Event,
+ }
+
+ let (s1, s2) = Tube::pair().unwrap();
+
+ let test_msg = EventStruct {
+ x: 100,
+ b: Event::new().unwrap(),
+ };
+ s1.send(&test_msg).unwrap();
+ let recv_msg: EventStruct = s2.recv().unwrap();
+
+ assert_eq!(test_msg.x, recv_msg.x);
+
+ test_event_pair(test_msg.b, recv_msg.b);
+ }
+
+ #[test]
+ fn send_recv_hash_map() {
+ let (s1, s2) = Tube::pair().unwrap();
+
+ let mut test_msg = HashMap::new();
+ test_msg.insert("Red".to_owned(), Event::new().unwrap());
+ test_msg.insert("White".to_owned(), Event::new().unwrap());
+ test_msg.insert("Blue".to_owned(), Event::new().unwrap());
+ test_msg.insert("Orange".to_owned(), Event::new().unwrap());
+ test_msg.insert("Green".to_owned(), Event::new().unwrap());
+ s1.send(&test_msg).unwrap();
+ let mut recv_msg: HashMap<String, Event> = s2.recv().unwrap();
+
+ let mut test_msg_keys: Vec<_> = test_msg.keys().collect();
+ test_msg_keys.sort();
+ let mut recv_msg_keys: Vec<_> = recv_msg.keys().collect();
+ recv_msg_keys.sort();
+ assert_eq!(test_msg_keys, recv_msg_keys);
+
+ for (key, test_event) in test_msg {
+ let recv_event = recv_msg.remove(&key).unwrap();
+ test_event_pair(test_event, recv_event);
+ }
+ }
+}
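
Taken together, the new `Tube` is a typed, serde-based wrapper over a `UnixSeqpacket` pair. A minimal sketch of the round trip, modeled directly on the `send_recv_no_fd` test above (the `base::Tube` re-export path is an assumption):

```rust
// Sketch only: a Tube round trip mirroring the send_recv_no_fd test above.
use base::Tube;

fn tube_round_trip() {
    let (s1, s2) = Tube::pair().expect("failed to create tube pair");

    // Any Serialize/DeserializeOwned type works; descriptors embedded in the
    // message (e.g. Event) travel as SCM_RIGHTS alongside the JSON payload.
    s1.send(&"hello world").expect("send failed");
    let msg: String = s2.recv().expect("recv failed");
    assert_eq!(msg, "hello world");
}
```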
diff --git a/base/src/wait_context.rs b/base/src/wait_context.rs
index 4e492b7bf..f3afb6ae4 100644
--- a/base/src/wait_context.rs
+++ b/base/src/wait_context.rs
@@ -6,6 +6,7 @@ use std::os::unix::io::AsRawFd;
use std::time::Duration;
use crate::{wrap_descriptor, AsRawDescriptor, RawDescriptor, Result};
+use smallvec::SmallVec;
use sys_util::{PollContext, PollToken, WatchingEvents};
// Typedef PollToken as EventToken for better adherence to base naming.
@@ -124,23 +125,22 @@ impl<T: EventToken> WaitContext<T> {
}
/// Waits for one or more of the registered triggers to become signaled.
- pub fn wait(&self) -> Result<Vec<TriggeredEvent<T>>> {
+ pub fn wait(&self) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> {
self.wait_timeout(Duration::new(i64::MAX as u64, 0))
}
/// Waits for one or more of the registered triggers to become signaled, failing if no triggers
/// are signaled before the designated timeout has elapsed.
- pub fn wait_timeout(&self, timeout: Duration) -> Result<Vec<TriggeredEvent<T>>> {
+ pub fn wait_timeout(&self, timeout: Duration) -> Result<SmallVec<[TriggeredEvent<T>; 16]>> {
let events = self.0.wait_timeout(timeout)?;
- let mut return_vec: Vec<TriggeredEvent<T>> = vec![];
- for event in events.iter() {
- return_vec.push(TriggeredEvent {
+ Ok(events
+ .iter()
+ .map(|event| TriggeredEvent {
token: event.token(),
is_readable: event.readable(),
is_writable: event.writable(),
is_hungup: event.hungup(),
- });
- }
- Ok(return_vec)
+ })
+ .collect())
}
}
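
Returning a `SmallVec` keeps the common case (16 or fewer triggered events) off the heap while callers keep iterating the result as before. A hedged sketch of consuming it (the surrounding setup and registered tokens are assumed):

```rust
// Sketch only: draining the SmallVec returned by WaitContext::wait().
// `ctx` is assumed to already have triggers registered with tokens of type T.
fn drain_events<T: EventToken>(ctx: &WaitContext<T>) {
    // wait() blocks until at least one trigger fires; the SmallVec spills to
    // the heap only if more than 16 events are returned at once.
    for event in ctx.wait().expect("wait failed").iter() {
        if event.is_readable {
            // Handle the trigger identified by `event.token`.
        }
        if event.is_hungup {
            // The peer hung up; the trigger can be removed.
        }
    }
}
```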
diff --git a/bin/clippy b/bin/clippy
index a9b70f675..0e930793c 100755
--- a/bin/clippy
+++ b/bin/clippy
@@ -54,6 +54,10 @@ SUPPRESS=(
useless_format
wrong_self_convention
+ # False positives affecting WlVfd @ `devices/src/virtio/wl.rs`.
+ # Bug: https://github.com/rust-lang/rust-clippy/issues/6312
+ field_reassign_with_default
+
# We don't care about these lints. Okay to remain suppressed globally.
blacklisted_name
cast_lossless
diff --git a/bit_field/bit_field_derive/bit_field_derive.rs b/bit_field/bit_field_derive/bit_field_derive.rs
index 92fea946b..df3e2825e 100644
--- a/bit_field/bit_field_derive/bit_field_derive.rs
+++ b/bit_field/bit_field_derive/bit_field_derive.rs
@@ -473,7 +473,7 @@ fn get_fields_impl(fields: &[FieldSpec]) -> Vec<TokenStream> {
let span = expected_bits.span();
quote_spanned! {span=>
#[allow(dead_code)]
- const EXPECTED_BITS: [(); #expected_bits as usize] =
+ const EXPECTED_BITS: [(); #expected_bits] =
[(); <#ty as ::bit_field::BitFieldSpecifier>::FIELD_WIDTH as usize];
}
});
diff --git a/ci/crosvm_aarch64_builder/entrypoint b/ci/crosvm_aarch64_builder/entrypoint
index 21d3087e0..74ff7e510 100755
--- a/ci/crosvm_aarch64_builder/entrypoint
+++ b/ci/crosvm_aarch64_builder/entrypoint
@@ -27,8 +27,13 @@ if [ "$1" = "--vm" ]; then
echo "Starting testing vm..."
(cd /workspace/vm && screen -Sdm vm ./start_vm)
export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_RUNNER="\
- /workspace/vm/exec_file"
- test_target="Virtual Machine (See 'screen -r vm')"
+ /workspace/src/platform/crosvm/ci/vm_tools/exec_binary_in_vm"
+
+ if [[ $# -eq 0 ]]; then
+ test_target="Virtual Machine (See 'screen -r vm' or 'ssh vm')"
+ else
+ test_target="Virtual Machine"
+ fi
export CROSVM_USE_VM=1
else
test_target="User-space emulation"
diff --git a/ci/crosvm_base/rust-toolchain b/ci/crosvm_base/rust-toolchain
index 9db5ea12f..5a5c7211d 100644
--- a/ci/crosvm_base/rust-toolchain
+++ b/ci/crosvm_base/rust-toolchain
@@ -1 +1 @@
-1.48.0
+1.50.0
diff --git a/ci/crosvm_builder/entrypoint b/ci/crosvm_builder/entrypoint
index 6b51b2a2e..ec4c0765b 100755
--- a/ci/crosvm_builder/entrypoint
+++ b/ci/crosvm_builder/entrypoint
@@ -27,8 +27,13 @@ if [ "$1" = "--vm" ]; then
echo "Starting testing vm..."
(cd /workspace/vm && screen -Sdm vm ./start_vm)
export CARGO_TARGET_X86_64_UNKNOWN_LINUX_GNU_RUNNER="\
- /workspace/vm/exec_file"
- test_target="Virtual Machine (See 'screen -r vm')"
+ /workspace/src/platform/crosvm/ci/vm_tools/exec_binary_in_vm"
+
+ if [[ $# -eq 0 ]]; then
+ test_target="Virtual Machine (See 'screen -r vm' or 'ssh vm')"
+ else
+ test_target="Virtual Machine"
+ fi
export CROSVM_USE_VM=1
else
test_target="Native execution"
diff --git a/ci/crosvm_test_vm/Dockerfile b/ci/crosvm_test_vm/Dockerfile
index 997405214..aa19f1748 100644
--- a/ci/crosvm_test_vm/Dockerfile
+++ b/ci/crosvm_test_vm/Dockerfile
@@ -24,7 +24,7 @@ RUN apt-get update && apt-get install --yes \
WORKDIR /workspace/vm
RUN curl -sSfL -o rootfs.qcow2 \
- "https://cdimage.debian.org/cdimage/cloud/buster/20201214-484/debian-10-generic-${VM_ARCH}-20201214-484.qcow2"
+ "http://cloud.debian.org/images/cloud/bullseye/daily/20210208-542/debian-11-generic-${VM_ARCH}-daily-20210208-542.qcow2"
# Package `cloud_init_data.yaml` to be loaded during `first_boot.expect`
COPY build/cloud_init_data.yaml ./
@@ -73,11 +73,6 @@ RUN chmod 0600 /root/.ssh/id_rsa
# Copy utility scripts
COPY runtime/start_vm.${VM_ARCH} ./start_vm
-COPY runtime/exec \
- runtime/exec_file \
- runtime/sync_so \
- runtime/wait_for_vm \
- ./
# Automatically start the VM.
ENTRYPOINT [ "/workspace/vm/start_vm" ]
diff --git a/ci/crosvm_test_vm/build/cloud_init_data.yaml b/ci/crosvm_test_vm/build/cloud_init_data.yaml
index 5544d5f2e..a9d4dc953 100644
--- a/ci/crosvm_test_vm/build/cloud_init_data.yaml
+++ b/ci/crosvm_test_vm/build/cloud_init_data.yaml
@@ -3,6 +3,7 @@ users:
- name: crosvm
sudo: ALL=(ALL) NOPASSWD:ALL
lock_passwd: False
+ shell: /bin/bash
# Hashed password is 'crosvm'
passwd: $6$rounds=4096$os6Q9Ok4Y9a8hKvG$EwQ1bbS0qd4IJyRP.bnRbyjPbSS8BwxEJh18PfhsyD0w7a4GhTwakrmYZ6KuBoyP.cSjYYSW9wYwko4oCPoJr.
# Pubkey for `../vm_key`
@@ -12,11 +13,7 @@ users:
crosvm@localhost
groups: kvm, disk, tty
-apt:
- sources:
- testing:
- source: "deb $MIRROR bullseye main"
- conf: APT::Default-Release "stable";
+hostname: testvm
# Store working data on tmpfs to reduce unnecessary disk IO
mounts:
@@ -37,17 +34,14 @@ packages:
- rsync
runcmd:
- # Install testing (debian bullseye) versions of some libraries.
- - [
- apt-get,
- install,
- --yes,
- -t,
- testing,
- --no-install-recommends,
- libdrm2,
- libepoxy0,
- ]
+ # Prevent those annoying "host not found" errors.
+ - echo 127.0.0.1 testvm >> /etc/hosts
+
+ # Make it easier to identify which VM we are in.
+ - echo "export PS1=\"testvm-$(arch):\\\\w# \"" >> /etc/bash.bashrc
+
+ # Enable core dumps for debugging crashes
+ - echo "* soft core unlimited" > /etc/security/limits.conf
# Trim some fat
- [apt-get, remove, --yes, vim-runtime, iso-codes, perl, grub-common]
diff --git a/ci/crosvm_test_vm/runtime/sync_so b/ci/crosvm_test_vm/runtime/sync_so
deleted file mode 100755
index 7214945bc..000000000
--- a/ci/crosvm_test_vm/runtime/sync_so
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/bin/bash
-# Copyright 2021 The Chromium OS Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-#
-# Synchronizes shared objects into the virtual machine to allow crosvm binaries
-# to run.
-
-${0%/*}/exec exit || exit 1 # Wait for VM to be available
-
-rust_toolchain=$(cat /workspace/src/platform/crosvm/rust-toolchain)
-
-# List of shared objects used by crosvm that need to be synced
-shared_objects=(
- /workspace/scratch/lib/*.so*
- /root/.rustup/toolchains/${rust_toolchain}-*/lib/libstd-*.so
- /root/.rustup/toolchains/${rust_toolchain}-*/lib/libtest-*.so
-)
-
-rsync -azPL --rsync-path="sudo rsync" ${shared_objects[@]} vm:/usr/lib
diff --git a/ci/image_tag b/ci/image_tag
index 75d30fb53..b7be5a127 100644
--- a/ci/image_tag
+++ b/ci/image_tag
@@ -1 +1 @@
-r0001
+r0004
diff --git a/ci/kokoro/common.sh b/ci/kokoro/common.sh
index 54aba174f..c1e1e8b3b 100755
--- a/ci/kokoro/common.sh
+++ b/ci/kokoro/common.sh
@@ -5,6 +5,21 @@
crosvm_root="${KOKORO_ARTIFACTS_DIR}"/git/crosvm
+# Enable SSH access to the kokoro builder.
+# Use the fusion2/ UI to trigger a build and set the DEBUG_SSH_KEY environment
+# variable to your public key; this will allow you to connect to the builder
+# via SSH.
+# Note: Access is restricted to the google corporate network.
+# Details: https://yaqs.corp.google.com/eng/q/6628551334035456
+if [[ ! -z "${DEBUG_SSH_KEY}" ]]; then
+ echo "${DEBUG_SSH_KEY}" >>~/.ssh/authorized_keys
+ external_ip=$(
+ curl -s -H "Metadata-Flavor: Google"
+ http://metadata/computeMetadata/v1/instance/network-interfaces/0/access-configs/0/external-ip
+ )
+ echo "SSH Debug enabled. Connect to: kbuilder@${external_ip}"
+fi
+
setup_source() {
if [ -z "${KOKORO_ARTIFACTS_DIR}" ]; then
echo "This script must be run in kokoro"
@@ -30,7 +45,7 @@ setup_source() {
-u https://chromium.googlesource.com/chromiumos/manifest.git \
--repo-url https://chromium.googlesource.com/external/repo.git \
-g crosvm || return 1
- ./repo sync -j8 -c -m "${crosvm_root}/ci/kokoro/manifest.xml" || return 1
+ ./repo sync -j8 -c || return 1
# Bind mount source into cros checkout.
echo ""
@@ -45,6 +60,12 @@ setup_source() {
}
cleanup() {
+ # Sleep after the build to allow for SSH debugging to continue.
+ if [[ ! -z "${DEBUG_SSH_KEY}" ]]; then
+ echo "Build done. Blocking for SSH debugging."
+ sleep 1h
+ fi
+
if command -v bindfs >/dev/null; then
fusermount -uz "${KOKORO_ARTIFACTS_DIR}/cros/src/platform/crosvm"
else
diff --git a/ci/kokoro/manifest.xml b/ci/kokoro/manifest.xml
deleted file mode 100644
index 1a337ad1f..000000000
--- a/ci/kokoro/manifest.xml
+++ /dev/null
@@ -1,24 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<manifest>
- <remote fetch="https://android.googlesource.com" name="aosp" review="https://android-review.googlesource.com"/>
- <remote alias="cros" fetch="https://chromium.googlesource.com/" name="chromium"/>
- <remote fetch="https://chromium.googlesource.com" name="cros" review="https://chromium-review.googlesource.com"/>
-
- <default remote="cros" revision="refs/heads/main" sync-j="8"/>
-
- <project dest-branch="refs/heads/master" groups="minilayout,firmware,buildtools,chromeos-admin,labtools,sysmon,devserver,crosvm" name="chromiumos/chromite" path="chromite" revision="2c0017fef941137472b9b58f8acbb41780c0f14f" upstream="refs/heads/master">
- <copyfile dest="AUTHORS" src="AUTHORS"/>
- <copyfile dest="LICENSE" src="LICENSE"/>
- </project>
- <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/docs" path="docs" revision="73161502e320ebea00bf620a54d3a606ca1b9836" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/platform/crosvm" path="src/platform/crosvm" revision="f4d1cdaaeb7b1f8e747e1af4c91e81a3255defe9" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/platform/minigbm" path="src/platform/minigbm" revision="6e27708ff2093a19c01d51cef61507bf8a804bf9" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/platform2" path="src/platform2" revision="c2c0bbfe867d53bc8f7aca0edb2ba65ec5849de1" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="minilayout,firmware,buildtools,labtools,crosvm" name="chromiumos/repohooks" path="src/repohooks" revision="d9ed85f771d9aaf012f850d0b15e69af029023b1" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="crosvm" name="chromiumos/third_party/adhd" path="src/third_party/adhd" revision="4cde9efd2bfde57def27da9ee864cce6dc430046" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/main" groups="firmware,crosvm" name="chromiumos/third_party/tpm2" path="src/third_party/tpm2" revision="86e93379322f012d354b9b8a369373ed9b62718c" upstream="refs/heads/main"/>
- <project dest-branch="refs/heads/master" groups="crosvm" name="chromiumos/third_party/virglrenderer" path="src/third_party/virglrenderer" revision="51f45f343b77c01897ddddc6bce84117a6278793" upstream="refs/heads/master"/>
- <project dest-branch="refs/heads/master" groups="crosvm" name="platform/external/minijail" path="src/aosp/external/minijail" remote="aosp" revision="e119bbb81cb42aaddef61882b3747cf7995465f7" upstream="refs/heads/master"/>
-
- <repo-hooks enabled-list="pre-upload" in-project="chromiumos/repohooks"/>
-</manifest>
diff --git a/ci/kokoro/uprev b/ci/kokoro/uprev
deleted file mode 100755
index 515f84f5a..000000000
--- a/ci/kokoro/uprev
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/bin/bash
-# Copyright 2021 The Chromium OS Authors. All rights reserved.
-# Use of this source code is governed by a BSD-style license that can be
-# found in the LICENSE file.
-#
-# Uprevs manifest.xml to the latest versions.
-#
-# This is just a wrapper around `repo manifest`. Usually we have unsubmitted
-# CLs in our local repo, so it's safer to pull the latest manifest revisions
-# from a fresh repo checkout to make sure all commit sha's are available.
-
-tmp=$(mktemp -d)
-manifest_path=$(realpath $(dirname $0)/manifest.xml)
-
-cleanup() {
- rm -rf "${tmp}"
-}
-
-main() {
- cd "${tmp}"
- repo init --depth 1 \
- -u https://chromium.googlesource.com/chromiumos/manifest.git \
- --repo-url https://chromium.googlesource.com/external/repo.git \
- -g crosvm
- repo sync -j8 -c
- repo manifest --revision-as-HEAD -o "${manifest_path}"
-}
-
-trap cleanup EXIT
-main "$@"
diff --git a/ci/test_runner.py b/ci/test_runner.py
index eb7ec8712..1207207c1 100644
--- a/ci/test_runner.py
+++ b/ci/test_runner.py
@@ -28,7 +28,9 @@ VERY_VERBOSE = False
# Runs tests using the exec_file wrapper, which will run the test inside the
# builder's built-in VM.
-VM_TEST_RUNNER = "/workspace/vm/exec_file --no-sync"
+VM_TEST_RUNNER = (
+ os.path.abspath("./ci/vm_tools/exec_binary_in_vm") + " --no-sync"
+)
# Runs tests using QEMU user-space emulation.
QEMU_TEST_RUNNER = (
@@ -52,12 +54,21 @@ class Requirements(enum.Enum):
# Test is disabled explicitly.
DISABLED = "disabled"
- # Test needs to be executed with expanded privileges for device access.
+ # Test needs to be executed with expanded privileges for device access and
+ # will be run inside a VM.
PRIVILEGED = "privileged"
# Test needs to run single-threaded
SINGLE_THREADED = "single_threaded"
+ # Separate workspaces that have dev-dependencies cannot be built from the
+ # crosvm workspace and need to be built separately.
+ # Note: Separate workspaces are built with no features enabled.
+ SEPARATE_WORKSPACE = "separate_workspace"
+
+ # Build, but do not run.
+ DO_NOT_RUN = "do_not_run"
+
BUILD_TIME_REQUIREMENTS = [
Requirements.AARCH64,
@@ -84,8 +95,13 @@ class CrateInfo(object):
build_reqs = requirements.intersection(BUILD_TIME_REQUIREMENTS)
self.can_build = all(req in capabilities for req in build_reqs)
- self.can_run = self.can_build and (
- not self.needs_privilege or Requirements.PRIVILEGED in capabilities
+ self.can_run = (
+ self.can_build
+ and (
+ not self.needs_privilege
+ or Requirements.PRIVILEGED in capabilities
+ )
+ and Requirements.DO_NOT_RUN not in self.requirements
)
def __repr__(self):
@@ -209,7 +225,7 @@ def results_summary(results: Union[RunResults, CrateResults]):
num_pass = results.count(TestResult.PASS)
num_skip = results.count(TestResult.SKIP)
num_fail = results.count(TestResult.FAIL)
- msg = []
+ msg: List[str] = []
if num_pass:
msg.append(f"{num_pass} passed")
if num_skip:
@@ -219,9 +235,42 @@ def results_summary(results: Union[RunResults, CrateResults]):
return ", ".join(msg)
+def cargo_build_process(
+ cwd: str = ".", crates: List[CrateInfo] = [], features: Set[str] = set()
+):
+ """Builds the main crosvm crate."""
+ cmd = [
+ "cargo",
+ "build",
+ "--color=never",
+ "--no-default-features",
+ "--features",
+ ",".join(features),
+ ]
+
+ for crate in sorted(crate.name for crate in crates):
+ cmd += ["-p", crate]
+
+ if VERY_VERBOSE:
+ print("CMD", " ".join(cmd))
+
+ process = subprocess.run(
+ cmd,
+ cwd=cwd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ text=True,
+ )
+ if process.returncode != 0 or VERBOSE:
+ print()
+ print(process.stdout)
+ return process
+
+
def cargo_test_process(
- crates: List[CrateInfo],
- features: Set[str],
+ cwd: str,
+ crates: List[CrateInfo] = [],
+ features: Set[str] = set(),
run: bool = True,
single_threaded: bool = False,
use_vm: bool = False,
@@ -233,6 +282,11 @@ def cargo_test_process(
cmd += ["--no-run"]
if features:
cmd += ["--no-default-features", "--features", ",".join(features)]
+
+ # Skip doc tests as these cannot be run in the VM.
+ if use_vm:
+ cmd += ["--bins", "--tests"]
+
for crate in sorted(crate.name for crate in crates):
cmd += ["-p", crate]
@@ -247,6 +301,7 @@ def cargo_test_process(
process = subprocess.run(
cmd,
+ cwd=cwd,
env=env,
timeout=timeout,
stdout=subprocess.PIPE,
@@ -261,9 +316,41 @@ def cargo_test_process(
def cargo_build_tests(crates: List[CrateInfo], features: Set[str]):
"""Runs cargo test --no-run to build all listed `crates`."""
- print("Building: ", ", ".join(crate.name for crate in crates))
- process = cargo_test_process(crates, features, run=False)
- return process.returncode == 0
+ separate_workspace_crates = [
+ crate
+ for crate in crates
+ if Requirements.SEPARATE_WORKSPACE in crate.requirements
+ ]
+ workspace_crates = [
+ crate
+ for crate in crates
+ if Requirements.SEPARATE_WORKSPACE not in crate.requirements
+ ]
+
+ print(
+ "Building workspace: ",
+ ", ".join(crate.name for crate in workspace_crates),
+ )
+ build_process = cargo_build_process(
+ cwd=".", crates=workspace_crates, features=features
+ )
+ if build_process.returncode != 0:
+ return False
+ test_process = cargo_test_process(
+ cwd=".", crates=workspace_crates, features=features, run=False
+ )
+ if test_process.returncode != 0:
+ return False
+
+ for crate in separate_workspace_crates:
+ print("Building crate:", crate.name)
+ build_process = cargo_build_process(cwd=crate.name)
+ if build_process.returncode != 0:
+ return False
+ test_process = cargo_test_process(cwd=crate.name, run=False)
+ if test_process.returncode != 0:
+ return False
+ return True
def cargo_test(
@@ -279,16 +366,29 @@ def cargo_test(
msg.append("in vm")
if single_threaded:
msg.append("(single-threaded)")
+ if Requirements.SEPARATE_WORKSPACE in crate.requirements:
+ msg.append("(separate workspace)")
sys.stdout.write(f"{' '.join(msg)}... ")
sys.stdout.flush()
- process = cargo_test_process(
- [crate],
- features,
- run=True,
- single_threaded=single_threaded,
- use_vm=use_vm,
- timeout=TEST_TIMEOUT_SECS,
- )
+
+ if Requirements.SEPARATE_WORKSPACE in crate.requirements:
+ process = cargo_test_process(
+ cwd=crate.name,
+ run=True,
+ single_threaded=single_threaded,
+ use_vm=use_vm,
+ timeout=TEST_TIMEOUT_SECS,
+ )
+ else:
+ process = cargo_test_process(
+ cwd=".",
+ crates=[crate],
+ features=features,
+ run=True,
+ single_threaded=single_threaded,
+ use_vm=use_vm,
+ timeout=TEST_TIMEOUT_SECS,
+ )
results = CrateResults(
crate.name, process.returncode == 0, process.stdout
)
@@ -318,7 +418,6 @@ def execute_batched_by_privilege(
Non-privileged tests are run first. Privileged tests are executed in
a VM if use_vm is set.
"""
-
build_crates = [crate for crate in crates if crate.can_build]
if not cargo_build_tests(build_crates, features):
return []
@@ -335,7 +434,7 @@ def execute_batched_by_privilege(
]
if privileged_crates:
if use_vm:
- subprocess.run("/workspace/vm/sync_so", check=True)
+ subprocess.run("./ci/vm_tools/sync_deps", check=True)
yield from execute_batched_by_parallelism(
privileged_crates, features, use_vm=True
)
@@ -491,8 +590,11 @@ def main(
help="Path to file where to store junit xml results",
)
args = parser.parse_args()
+
+ global VERBOSE, VERY_VERBOSE
VERBOSE = args.verbose or args.very_verbose # type: ignore
VERY_VERBOSE = args.very_verbose # type: ignore
+
use_vm = os.environ.get("CROSVM_USE_VM") != None or args.use_vm
cros_build = os.environ.get("CROSVM_CROS_BUILD") != None or args.cros_build
diff --git a/ci/vm_tools/README.md b/ci/vm_tools/README.md
new file mode 100644
index 000000000..0dcc87f06
--- /dev/null
+++ b/ci/vm_tools/README.md
@@ -0,0 +1,8 @@
+# VM Tools
+
+The scripts in this directory are used to make it easier to work with the test
+VM.
+
+The VM is expected to be accessible via `ssh vm` without a password prompt. This
+is set up by the builders, so it works out of the box when running inside the
+builder shell.
diff --git a/ci/crosvm_test_vm/runtime/exec_file b/ci/vm_tools/exec_binary_in_vm
index c14f64885..7b2c230d1 100755
--- a/ci/crosvm_test_vm/runtime/exec_file
+++ b/ci/vm_tools/exec_binary_in_vm
@@ -3,15 +3,16 @@
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
#
-# Uploads and executes a file in the VM.
+# Uploads and executes a file in the VM. This script can be set as a runner
+# for cargo to execute tests inside the VM.
-${0%/*}/exec exit || exit 1 # Wait for VM to be available
+${0%/*}/wait_for_vm_with_timeout || exit 1
if [ "$1" = "--no-sync" ]; then
shift
else
- echo "Syncing shared objects..."
- ${0%/*}/sync_so || exit 1
+ echo "Syncing dependencies..."
+ ${0%/*}/sync_deps || exit 1
fi
filepath=$1
@@ -19,8 +20,9 @@ filename=$(basename $filepath)
echo "Executing $filename ${@:2}"
scp -q $filepath vm:/tmp/$filename
+ssh vm -q -t "cd /tmp && sudo ./$filename ${@:2}"
+
# Make sure to preserve the exit code of $filename after cleaning up the file.
-ssh vm -q -t "cd /tmp && ./$filename ${@:2}"
ret=$?
ssh vm -q -t "rm /tmp/$filename"
exit $ret
diff --git a/ci/vm_tools/sync_deps b/ci/vm_tools/sync_deps
new file mode 100755
index 000000000..a97b019e0
--- /dev/null
+++ b/ci/vm_tools/sync_deps
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright 2021 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file.
+#
+# Synchronizes dependencies of crosvm into the virtual machine to allow test
+# binaries to execute.
+
+${0%/*}/wait_for_vm_with_timeout || exit 1
+
+crosvm_root="/workspace/src/platform/crosvm"
+rust_toolchain=$(cat ${crosvm_root}/rust-toolchain)
+target_dir=$(
+ cargo metadata --no-deps --format-version 1 |
+ jq -r ".target_directory"
+)
+
+# List of shared objects used by crosvm that need to be synced.
+shared_objects=(
+ /workspace/scratch/lib/*.so*
+ /root/.rustup/toolchains/${rust_toolchain}-*/lib/libstd-*.so
+ /root/.rustup/toolchains/${rust_toolchain}-*/lib/libtest-*.so
+)
+rsync -azPLq --rsync-path="sudo rsync" ${shared_objects[@]} vm:/usr/lib
+
+# Files needed by binaries at runtime in the working directory.
+if [ -z "${CARGO_BUILD_TARGET}" ]; then
+ runtime_files=(
+ "${target_dir}/debug/crosvm"
+ )
+else
+ runtime_files=(
+ "${target_dir}/${CARGO_BUILD_TARGET}/debug/crosvm"
+ )
+fi
+
+rsync -azPLq --rsync-path="sudo rsync" ${runtime_files} vm:/tmp
diff --git a/ci/crosvm_test_vm/runtime/wait_for_vm b/ci/vm_tools/wait_for_vm
index dca090f9a..dca090f9a 100755
--- a/ci/crosvm_test_vm/runtime/wait_for_vm
+++ b/ci/vm_tools/wait_for_vm
diff --git a/ci/crosvm_test_vm/runtime/exec b/ci/vm_tools/wait_for_vm_with_timeout
index 1d7043caf..7502177d1 100755
--- a/ci/crosvm_test_vm/runtime/exec
+++ b/ci/vm_tools/wait_for_vm_with_timeout
@@ -2,8 +2,6 @@
# Copyright 2021 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
-#
-# Executes a command in the VM once it is available.
if ! timeout --foreground 180s ${0%/*}/wait_for_vm; then
echo ""
diff --git a/cros_async/Cargo.toml b/cros_async/Cargo.toml
index 5b497c13b..fcaf95bc7 100644
--- a/cros_async/Cargo.toml
+++ b/cros_async/Cargo.toml
@@ -7,15 +7,15 @@ edition = "2018"
[dependencies]
async-trait = "0.1.36"
async-task = "4"
-io_uring = { path = "../io_uring" }
+intrusive-collections = "0.9"
+io_uring = { path = "../io_uring" } # provided by ebuild
libc = "*"
paste = "1.0"
pin-utils = "0.1.0-alpha.4"
slab = "0.4"
-sync = { path = "../sync" }
-sys_util = { path = "../sys_util" }
-data_model = { path = "../data_model" }
-syscall_defines = { path = "../syscall_defines" }
+sync = { path = "../sync" } # provided by ebuild
+sys_util = { path = "../sys_util" } # provided by ebuild
+data_model = { path = "../data_model" } # provided by ebuild
thiserror = "1.0.20"
[dependencies.futures]
@@ -25,5 +25,8 @@ features = ["alloc"]
[dev-dependencies]
futures = { version = "*", features = ["executor"] }
-tempfile = { path = "../tempfile" }
-vm_memory = { path = "../vm_memory" }
+futures-executor = { version = "0.3", features = ["thread-pool"] }
+futures-util = "0.3"
+tempfile = { path = "../tempfile" } # provided by ebuild
+
+[workspace]
\ No newline at end of file
diff --git a/cros_async/src/fd_executor.rs b/cros_async/src/fd_executor.rs
index 4ab5dc578..ec6a72061 100644
--- a/cros_async/src/fd_executor.rs
+++ b/cros_async/src/fd_executor.rs
@@ -49,6 +49,9 @@ pub enum Error {
/// PollContext failure.
#[error("PollContext failure: {0}")]
PollContextError(sys_util::Error),
+ /// An error occurred when setting the FD non-blocking.
+ #[error("An error occurred setting the FD non-blocking: {0}.")]
+ SettingNonBlocking(sys_util::Error),
/// Failed to submit the waker to the polling context.
#[error("An error adding to the Aio context: {0}")]
SubmittingWaker(sys_util::Error),
@@ -70,6 +73,62 @@ enum OpStatus {
Completed,
}
+// An IO source previously registered with an FdExecutor. Used to initiate asynchronous IO with the
+// associated executor.
+pub struct RegisteredSource<F> {
+ source: F,
+ ex: Weak<RawExecutor>,
+}
+
+impl<F: AsRawFd> RegisteredSource<F> {
+ // Start an asynchronous operation to wait for this source to become readable. The returned
+ // future will not be ready until the source is readable.
+ pub fn wait_readable(&self) -> Result<PendingOperation> {
+ let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?;
+
+ let token =
+ ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_read())?;
+
+ Ok(PendingOperation {
+ token: Some(token),
+ ex: self.ex.clone(),
+ })
+ }
+
+ // Start an asynchronous operation to wait for this source to become writable. The returned
+ // future will not be ready until the source is writable.
+ pub fn wait_writable(&self) -> Result<PendingOperation> {
+ let ex = self.ex.upgrade().ok_or(Error::ExecutorGone)?;
+
+ let token =
+ ex.add_operation(self.source.as_raw_fd(), WatchingEvents::empty().set_write())?;
+
+ Ok(PendingOperation {
+ token: Some(token),
+ ex: self.ex.clone(),
+ })
+ }
+}
+
+impl<F> RegisteredSource<F> {
+ // Consume this RegisteredSource and return the inner IO source.
+ pub fn into_source(self) -> F {
+ self.source
+ }
+}
+
+impl<F> AsRef<F> for RegisteredSource<F> {
+ fn as_ref(&self) -> &F {
+ &self.source
+ }
+}
+
+impl<F> AsMut<F> for RegisteredSource<F> {
+ fn as_mut(&mut self) -> &mut F {
+ &mut self.source
+ }
+}
+
/// A token returned from `add_operation` that can be used to cancel the waker before it completes.
/// Used to manage getting the result from the underlying executor for a completed operation.
/// Dropping a `PendingOperation` will get the result from the executor.
@@ -227,11 +286,12 @@ impl RawExecutor {
let raw = Arc::downgrade(self);
let schedule = move |runnable| {
if let Some(r) = raw.upgrade() {
- r.queue.schedule(runnable);
+ r.queue.push_back(runnable);
+ r.wake();
}
};
let (runnable, task) = async_task::spawn(f, schedule);
- self.queue.schedule(runnable);
+ runnable.schedule();
task
}
@@ -243,11 +303,12 @@ impl RawExecutor {
let raw = Arc::downgrade(self);
let schedule = move |runnable| {
if let Some(r) = raw.upgrade() {
- r.queue.schedule(runnable);
+ r.queue.push_back(runnable);
+ r.wake();
}
};
let (runnable, task) = async_task::spawn_local(f, schedule);
- self.queue.schedule(runnable);
+ runnable.schedule();
task
}
@@ -257,7 +318,6 @@ impl RawExecutor {
loop {
self.state.store(PROCESSING, Ordering::Release);
- self.queue.set_waker(cx.waker().clone());
for runnable in self.queue.iter() {
runnable.run();
}
@@ -266,10 +326,13 @@ impl RawExecutor {
return Ok(val);
}
- let oldstate = self
- .state
- .compare_and_swap(PROCESSING, WAITING, Ordering::Acquire);
- if oldstate != PROCESSING {
+ let oldstate = self.state.compare_exchange(
+ PROCESSING,
+ WAITING,
+ Ordering::Acquire,
+ Ordering::Acquire,
+ );
+ if let Err(oldstate) = oldstate {
debug_assert_eq!(oldstate, WOKEN);
// One or more futures have become runnable.
continue;
@@ -430,24 +493,10 @@ impl FdExecutor {
self.raw.run(&mut ctx, f)
}
- pub fn wait_readable<F: AsRawFd>(&self, f: &F) -> Result<PendingOperation> {
- let token = self
- .raw
- .add_operation(f.as_raw_fd(), WatchingEvents::empty().set_read())?;
-
- Ok(PendingOperation {
- token: Some(token),
- ex: Arc::downgrade(&self.raw),
- })
- }
-
- pub fn wait_writable<F: AsRawFd>(&self, f: &F) -> Result<PendingOperation> {
- let token = self
- .raw
- .add_operation(f.as_raw_fd(), WatchingEvents::empty().set_read())?;
-
- Ok(PendingOperation {
- token: Some(token),
+ pub(crate) fn register_source<F: AsRawFd>(&self, f: F) -> Result<RegisteredSource<F>> {
+ add_fd_flags(f.as_raw_fd(), libc::O_NONBLOCK).map_err(Error::SettingNonBlocking)?;
+ Ok(RegisteredSource {
+ source: f,
ex: Arc::downgrade(&self.raw),
})
}
@@ -479,7 +528,8 @@ mod test {
async fn do_test(ex: &FdExecutor) {
let (r, _w) = sys_util::pipe(true).unwrap();
let done = Box::pin(async { 5usize });
- let pending = ex.wait_readable(&r).unwrap();
+ let source = ex.register_source(r).unwrap();
+ let pending = source.wait_readable().unwrap();
match futures::future::select(pending, done).await {
Either::Right((5, pending)) => std::mem::drop(pending),
_ => panic!("unexpected select result"),
@@ -520,7 +570,8 @@ mod test {
let ex = FdExecutor::new().unwrap();
- let op = ex.wait_writable(&tx).unwrap();
+ let source = ex.register_source(tx.try_clone().unwrap()).unwrap();
+ let op = source.wait_writable().unwrap();
ex.spawn_local(write_value(tx)).detach();
ex.spawn_local(check_op(op)).detach();
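
The executor-side change replaces the ad-hoc `wait_readable`/`wait_writable` methods with a `register_source` step that also marks the FD non-blocking. A crate-internal sketch of the new flow, following the tests above (the pipe write is only there to make the FD readable):

```rust
// Sketch only (crate-internal, since register_source is pub(crate)):
// register a pipe read end once, then await readability as needed.
async fn await_pipe_readable(ex: &FdExecutor) {
    use std::io::Write;

    let (rx, mut tx) = sys_util::pipe(true).unwrap();
    let source = ex.register_source(rx).unwrap();

    // Something must make the FD readable; write a byte to the other end.
    tx.write_all(&[0u8]).unwrap();

    // wait_readable returns a PendingOperation future tied to this executor.
    let op = source.wait_readable().unwrap();
    op.await.unwrap();
}
```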
diff --git a/cros_async/src/lib.rs b/cros_async/src/lib.rs
index 4188779ae..271115bd4 100644
--- a/cros_async/src/lib.rs
+++ b/cros_async/src/lib.rs
@@ -67,6 +67,7 @@ pub mod mem;
mod poll_source;
mod queue;
mod select;
+pub mod sync;
mod timer;
mod uring_executor;
mod uring_source;
diff --git a/cros_async/src/poll_source.rs b/cros_async/src/poll_source.rs
index 748da78da..7fc674c62 100644
--- a/cros_async/src/poll_source.rs
+++ b/cros_async/src/poll_source.rs
@@ -10,11 +10,9 @@ use std::ops::{Deref, DerefMut};
use std::os::unix::io::AsRawFd;
use std::sync::Arc;
-use libc::O_NONBLOCK;
-use sys_util::{self, add_fd_flags};
use thiserror::Error as ThisError;
-use crate::fd_executor::{self, FdExecutor};
+use crate::fd_executor::{self, FdExecutor, RegisteredSource};
use crate::mem::{BackingMemory, MemRegion};
use crate::{AsyncError, AsyncResult};
use crate::{IoSourceExt, ReadAsync, WriteAsync};
@@ -35,15 +33,11 @@ pub enum Error {
#[error("An error occurred when executing fsync synchronously: {0}")]
Fsync(sys_util::Error),
/// An error occurred when reading the FD.
- ///
#[error("An error occurred when reading the FD: {0}.")]
Read(sys_util::Error),
/// Can't seek file.
#[error("An error occurred when seeking the FD: {0}.")]
Seeking(sys_util::Error),
- /// An error occurred when setting the FD non-blocking.
- #[error("An error occurred setting the FD non-blocking: {0}.")]
- SettingNonBlocking(sys_util::Error),
/// An error occurred when writing the FD.
#[error("An error occurred when writing the FD: {0}.")]
Write(sys_util::Error),
@@ -52,25 +46,19 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Async wrapper for an IO source that uses the FD executor to drive async operations.
/// Used by `IoSourceExt::new` when uring isn't available.
-pub struct PollSource<F: AsRawFd> {
- source: F,
- ex: FdExecutor,
-}
+pub struct PollSource<F>(RegisteredSource<F>);
impl<F: AsRawFd> PollSource<F> {
/// Create a new `PollSource` from the given IO source.
pub fn new(f: F, ex: &FdExecutor) -> Result<Self> {
- let fd = f.as_raw_fd();
- add_fd_flags(fd, O_NONBLOCK).map_err(Error::SettingNonBlocking)?;
- Ok(Self {
- source: f,
- ex: ex.clone(),
- })
+ ex.register_source(f)
+ .map(PollSource)
+ .map_err(Error::Executor)
}
/// Return the inner source.
pub fn into_source(self) -> F {
- self.source
+ self.0.into_source()
}
}
@@ -78,7 +66,13 @@ impl<F: AsRawFd> Deref for PollSource<F> {
type Target = F;
fn deref(&self) -> &Self::Target {
- &self.source
+ self.0.as_ref()
+ }
+}
+
+impl<F: AsRawFd> DerefMut for PollSource<F> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ self.0.as_mut()
}
}
@@ -94,7 +88,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
// Safe because this will only modify `vec` and we check the return value.
let res = unsafe {
libc::pread64(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
vec.as_mut_ptr() as *mut libc::c_void,
vec.len(),
file_offset as libc::off64_t,
@@ -107,10 +101,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
match sys_util::Error::last() {
e if e.errno() == libc::EWOULDBLOCK => {
- let op = self
- .ex
- .wait_readable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_readable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
}
e => return Err(Error::Read(e).into()),
@@ -135,7 +126,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
// guaranteed to be valid from the pointer by io_slice_mut.
let res = unsafe {
libc::preadv64(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
iovecs.as_mut_ptr() as *mut _,
iovecs.len() as i32,
file_offset as libc::off64_t,
@@ -148,10 +139,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
match sys_util::Error::last() {
e if e.errno() == libc::EWOULDBLOCK => {
- let op = self
- .ex
- .wait_readable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_readable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
}
e => return Err(Error::Read(e).into()),
@@ -161,10 +149,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
/// Wait for the FD of `self` to be readable.
async fn wait_readable(&self) -> AsyncResult<()> {
- let op = self
- .ex
- .wait_readable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_readable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
Ok(())
}
@@ -175,7 +160,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
// Safe because this will only modify `buf` and we check the return value.
let res = unsafe {
libc::read(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
buf.as_mut_ptr() as *mut libc::c_void,
buf.len(),
)
@@ -187,10 +172,7 @@ impl<F: AsRawFd> ReadAsync for PollSource<F> {
match sys_util::Error::last() {
e if e.errno() == libc::EWOULDBLOCK => {
- let op = self
- .ex
- .wait_readable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_readable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
}
e => return Err(Error::Read(e).into()),
@@ -211,7 +193,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
// Safe because this will not modify any memory and we check the return value.
let res = unsafe {
libc::pwrite64(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
vec.as_ptr() as *const libc::c_void,
vec.len(),
file_offset as libc::off64_t,
@@ -224,10 +206,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
match sys_util::Error::last() {
e if e.errno() == libc::EWOULDBLOCK => {
- let op = self
- .ex
- .wait_writable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_writable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
}
e => return Err(Error::Write(e).into()),
@@ -253,7 +232,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
// guaranteed to be valid from the pointer by io_slice_mut.
let res = unsafe {
libc::pwritev64(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
iovecs.as_ptr() as *mut _,
iovecs.len() as i32,
file_offset as libc::off64_t,
@@ -266,10 +245,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
match sys_util::Error::last() {
e if e.errno() == libc::EWOULDBLOCK => {
- let op = self
- .ex
- .wait_writable(&self.source)
- .map_err(Error::AddingWaker)?;
+ let op = self.0.wait_writable().map_err(Error::AddingWaker)?;
op.await.map_err(Error::Executor)?;
}
e => return Err(Error::Write(e).into()),
@@ -281,7 +257,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
async fn fallocate(&self, file_offset: u64, len: u64, mode: u32) -> AsyncResult<()> {
let ret = unsafe {
libc::fallocate64(
- self.source.as_raw_fd(),
+ self.as_raw_fd(),
mode as libc::c_int,
file_offset as libc::off64_t,
len as libc::off64_t,
@@ -296,7 +272,7 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
/// Sync all completed write operations to the backing storage.
async fn fsync(&self) -> AsyncResult<()> {
- let ret = unsafe { libc::fsync(self.source.as_raw_fd()) };
+ let ret = unsafe { libc::fsync(self.as_raw_fd()) };
if ret == 0 {
Ok(())
} else {
@@ -309,23 +285,17 @@ impl<F: AsRawFd> WriteAsync for PollSource<F> {
impl<F: AsRawFd> IoSourceExt<F> for PollSource<F> {
/// Yields the underlying IO source.
fn into_source(self: Box<Self>) -> F {
- self.source
+ self.0.into_source()
}
/// Provides a mutable ref to the underlying IO source.
fn as_source_mut(&mut self) -> &mut F {
- &mut self.source
+ self
}
/// Provides a ref to the underlying IO source.
fn as_source(&self) -> &F {
- &self.source
- }
-}
-
-impl<F: AsRawFd> DerefMut for PollSource<F> {
- fn deref_mut(&mut self) -> &mut Self::Target {
- &mut self.source
+ self
}
}
@@ -393,4 +363,23 @@ mod tests {
let ex = FdExecutor::new().unwrap();
ex.run_until(go(&ex)).unwrap();
}
+
+ #[test]
+ fn memory_leak() {
+ // This test needs to run under ASAN to detect memory leaks.
+
+ async fn owns_poll_source(source: PollSource<File>) {
+ let _ = source.wait_readable().await;
+ }
+
+ let (rx, _tx) = sys_util::pipe(true).unwrap();
+ let ex = FdExecutor::new().unwrap();
+ let source = PollSource::new(rx, &ex).unwrap();
+ ex.spawn_local(owns_poll_source(source)).detach();
+
+ // Drop `ex` without running it. This would cause a memory leak if PollSource owned a strong
+ // reference to the executor, because the executor owns a reference to the future that owns
+ // the PollSource (via its Runnable). Such a cycle would keep the executor's drop impl from
+ // running; that drop impl otherwise polls the future and has it return with an error.
+ }
}
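
With this refactor `PollSource` is a thin newtype over `RegisteredSource`, so construction, `Deref` access, and recovering the inner source all delegate to the executor-owned registration. A hedged sketch, following the tests above:

```rust
// Sketch only: PollSource construction and unwrapping after the refactor.
use std::fs::File;

fn wrap_and_unwrap() {
    let (rx, _tx) = sys_util::pipe(true).unwrap();
    let ex = FdExecutor::new().unwrap();

    // new() registers the FD with the executor (and marks it non-blocking).
    let source: PollSource<File> = PollSource::new(rx, &ex).unwrap();

    // Deref/DerefMut expose the inner File; into_source() hands it back.
    let _file: File = source.into_source();
}
```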
diff --git a/cros_async/src/queue.rs b/cros_async/src/queue.rs
index dc7bb2d87..3192703a5 100644
--- a/cros_async/src/queue.rs
+++ b/cros_async/src/queue.rs
@@ -3,50 +3,32 @@
// found in the LICENSE file.
use std::collections::VecDeque;
-use std::mem;
-use std::task::Waker;
use async_task::Runnable;
use sync::Mutex;
-struct Inner {
- runnables: VecDeque<Runnable>,
- waker: Option<Waker>,
-}
-
/// A queue of `Runnables`. Intended to be used by executors to keep track of futures that have been
/// scheduled to run.
pub struct RunnableQueue {
- inner: Mutex<Inner>,
+ runnables: Mutex<VecDeque<Runnable>>,
}
impl RunnableQueue {
/// Create a new, empty `RunnableQueue`.
pub fn new() -> RunnableQueue {
RunnableQueue {
- inner: Mutex::new(Inner {
- runnables: VecDeque::new(),
- waker: None,
- }),
+ runnables: Mutex::new(VecDeque::new()),
}
}
- /// Schedule `runnable` to run in the future by adding it to this `RunnableQueue`. Also wakes up
- /// the waker associated with this `RunnableQueue`, if any.
- pub fn schedule(&self, runnable: Runnable) {
- let mut inner = self.inner.lock();
- inner.runnables.push_back(runnable);
- let waker = inner.waker.take();
- mem::drop(inner);
-
- if let Some(w) = waker {
- w.wake();
- }
+ /// Schedule `runnable` to run in the future by adding it to this `RunnableQueue`.
+ pub fn push_back(&self, runnable: Runnable) {
+ self.runnables.lock().push_back(runnable);
}
/// Remove and return the first `Runnable` in this `RunnableQueue` or `None` if it is empty.
pub fn pop_front(&self) -> Option<Runnable> {
- self.inner.lock().runnables.pop_front()
+ self.runnables.lock().pop_front()
}
/// Create an iterator over this `RunnableQueue` that repeatedly calls `pop_front()` until it is
@@ -54,12 +36,6 @@ impl RunnableQueue {
pub fn iter(&self) -> RunnableQueueIter {
self.into_iter()
}
-
- /// Associate `waker` with this `RunnableQueue`. `waker` will be woken up the next time
- /// `schedule()` is called.
- pub fn set_waker(&self, waker: Waker) {
- self.inner.lock().waker = Some(waker);
- }
}
impl<'q> IntoIterator for &'q RunnableQueue {
diff --git a/cros_async/src/sync.rs b/cros_async/src/sync.rs
new file mode 100644
index 000000000..80dec3858
--- /dev/null
+++ b/cros_async/src/sync.rs
@@ -0,0 +1,14 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+mod blocking;
+mod cv;
+mod mu;
+mod spin;
+mod waiter;
+
+pub use blocking::block_on;
+pub use cv::Condvar;
+pub use mu::Mutex;
+pub use spin::SpinLock;
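
This new `sync` module gives cros_async futures-aware versions of the usual primitives plus a `block_on` bridge for non-async callers. A hedged usage sketch in the spirit of the `Condvar` documentation example below:

```rust
// Sketch only: driving the async Mutex from plain threads via block_on.
use std::sync::Arc;
use std::thread;

use cros_async::sync::{block_on, Mutex};

fn shared_counter() {
    let data = Arc::new(Mutex::new(0u32));

    let handles: Vec<_> = (0..4)
        .map(|_| {
            let data = Arc::clone(&data);
            thread::spawn(move || {
                // Mutex::lock() is async; block_on drives it to completion on
                // this thread using the futex-based waker from blocking.rs.
                *block_on(data.lock()) += 1;
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
    assert_eq!(*block_on(data.lock()), 4);
}
```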
diff --git a/cros_async/src/sync/blocking.rs b/cros_async/src/sync/blocking.rs
new file mode 100644
index 000000000..ce02e4dd7
--- /dev/null
+++ b/cros_async/src/sync/blocking.rs
@@ -0,0 +1,192 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::future::Future;
+use std::ptr;
+use std::sync::atomic::{AtomicI32, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll};
+
+use futures::pin_mut;
+use futures::task::{waker_ref, ArcWake};
+
+// Randomly generated values to indicate the state of the current thread.
+const WAITING: i32 = 0x25de_74d1;
+const WOKEN: i32 = 0x72d3_2c9f;
+
+const FUTEX_WAIT_PRIVATE: libc::c_int = libc::FUTEX_WAIT | libc::FUTEX_PRIVATE_FLAG;
+const FUTEX_WAKE_PRIVATE: libc::c_int = libc::FUTEX_WAKE | libc::FUTEX_PRIVATE_FLAG;
+
+thread_local!(static PER_THREAD_WAKER: Arc<Waker> = Arc::new(Waker(AtomicI32::new(WAITING))));
+
+#[repr(transparent)]
+struct Waker(AtomicI32);
+
+extern "C" {
+ #[cfg_attr(target_os = "android", link_name = "__errno")]
+ #[cfg_attr(target_os = "linux", link_name = "__errno_location")]
+ fn errno_location() -> *mut libc::c_int;
+}
+
+impl ArcWake for Waker {
+ fn wake_by_ref(arc_self: &Arc<Self>) {
+ let state = arc_self.0.swap(WOKEN, Ordering::Release);
+ if state == WAITING {
+ // The thread hasn't already been woken up so wake it up now. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &arc_self.0,
+ FUTEX_WAKE_PRIVATE,
+ libc::INT_MAX, // val
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0_i32, // val3
+ )
+ };
+ if res < 0 {
+ panic!("unexpected error from FUTEX_WAKE_PRIVATE: {}", unsafe {
+ *errno_location()
+ });
+ }
+ }
+ }
+}
+
+/// Run a future to completion on the current thread.
+///
+/// This method will block the current thread until `f` completes. Useful when you need to call an
+/// async fn from a non-async context.
+pub fn block_on<F: Future>(f: F) -> F::Output {
+ pin_mut!(f);
+
+ PER_THREAD_WAKER.with(|thread_waker| {
+ let waker = waker_ref(thread_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ loop {
+ if let Poll::Ready(t) = f.as_mut().poll(&mut cx) {
+ return t;
+ }
+
+ let state = thread_waker.0.swap(WAITING, Ordering::Acquire);
+ if state == WAITING {
+ // If we weren't already woken up then wait until we are. Safe because this doesn't
+ // modify any memory and we check the return value.
+ let res = unsafe {
+ libc::syscall(
+ libc::SYS_futex,
+ &thread_waker.0,
+ FUTEX_WAIT_PRIVATE,
+ state,
+ ptr::null() as *const libc::timespec, // timeout
+ ptr::null() as *const libc::c_int, // uaddr2
+ 0_i32, // val3
+ )
+ };
+
+ if res < 0 {
+ // Safe because libc guarantees that this is a valid pointer.
+ match unsafe { *errno_location() } {
+ libc::EAGAIN | libc::EINTR => {}
+ e => panic!("unexpected error from FUTEX_WAIT_PRIVATE: {}", e),
+ }
+ }
+
+ // Clear the state to prevent unnecessary extra loop iterations and also to allow
+ // nested usage of `block_on`.
+ thread_waker.0.store(WAITING, Ordering::Release);
+ }
+ }
+ })
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::pin::Pin;
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll, Waker};
+ use std::thread;
+ use std::time::Duration;
+
+ use crate::sync::SpinLock;
+
+ struct TimerState {
+ fired: bool,
+ waker: Option<Waker>,
+ }
+ struct Timer {
+ state: Arc<SpinLock<TimerState>>,
+ }
+
+ impl Future for Timer {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ let mut state = self.state.lock();
+ if state.fired {
+ return Poll::Ready(());
+ }
+
+ state.waker = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+
+ fn start_timer(dur: Duration, notify: Option<Sender<()>>) -> Timer {
+ let state = Arc::new(SpinLock::new(TimerState {
+ fired: false,
+ waker: None,
+ }));
+
+ let thread_state = Arc::clone(&state);
+ thread::spawn(move || {
+ thread::sleep(dur);
+ let mut ts = thread_state.lock();
+ ts.fired = true;
+ if let Some(waker) = ts.waker.take() {
+ waker.wake();
+ }
+ drop(ts);
+
+ if let Some(tx) = notify {
+ tx.send(()).expect("Failed to send completion notification");
+ }
+ });
+
+ Timer { state }
+ }
+
+ #[test]
+ fn it_works() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ #[test]
+ fn nested() {
+ async fn inner() {
+ block_on(start_timer(Duration::from_millis(100), None));
+ }
+
+ block_on(inner());
+ }
+
+ #[test]
+ fn ready_before_poll() {
+ let (tx, rx) = channel();
+
+ let timer = start_timer(Duration::from_millis(50), Some(tx));
+
+ rx.recv()
+ .expect("Failed to receive completion notification");
+
+ // We know the timer has already fired so the poll should complete immediately.
+ block_on(timer);
+ }
+}
diff --git a/cros_async/src/sync/cv.rs b/cros_async/src/sync/cv.rs
new file mode 100644
index 000000000..4da0a12ef
--- /dev/null
+++ b/cros_async/src/sync/cv.rs
@@ -0,0 +1,1159 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::mem;
+use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
+use std::sync::Arc;
+
+use crate::sync::mu::{MutexGuard, MutexReadGuard, RawMutex};
+use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor};
+
+const SPINLOCK: usize = 1 << 0;
+const HAS_WAITERS: usize = 1 << 1;
+
+/// A primitive to wait for an event to occur without consuming CPU time.
+///
+/// Condition variables are used in combination with a `Mutex` when a thread wants to wait for some
+/// condition to become true. The condition must always be verified while holding the `Mutex` lock.
+/// It is an error to use a `Condvar` with more than one `Mutex` while there are threads waiting on
+/// the `Condvar`.
+///
+/// # Examples
+///
+/// ```edition2018
+/// use std::sync::Arc;
+/// use std::thread;
+/// use std::sync::mpsc::channel;
+///
+/// use cros_async::sync::{block_on, Condvar, Mutex};
+///
+/// const N: usize = 13;
+///
+/// // Spawn a few threads to increment a shared variable (non-atomically), and
+/// // let all threads waiting on the Condvar know once the increments are done.
+/// let data = Arc::new(Mutex::new(0));
+/// let cv = Arc::new(Condvar::new());
+///
+/// for _ in 0..N {
+/// let (data, cv) = (data.clone(), cv.clone());
+/// thread::spawn(move || {
+/// let mut data = block_on(data.lock());
+/// *data += 1;
+/// if *data == N {
+/// cv.notify_all();
+/// }
+/// });
+/// }
+///
+/// let mut val = block_on(data.lock());
+/// while *val != N {
+/// val = block_on(cv.wait(val));
+/// }
+/// ```
+#[repr(align(128))]
+pub struct Condvar {
+ state: AtomicUsize,
+ waiters: UnsafeCell<WaiterList>,
+ mu: UnsafeCell<usize>,
+}
+
+impl Condvar {
+ /// Creates a new condition variable ready to be waited on and notified.
+ pub fn new() -> Condvar {
+ Condvar {
+ state: AtomicUsize::new(0),
+ waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
+ mu: UnsafeCell::new(0),
+ }
+ }
+
+ /// Block the current thread until this `Condvar` is notified by another thread.
+ ///
+ /// This method will atomically unlock the `Mutex` held by `guard` and then block the current
+ /// thread. Any call to `notify_one` or `notify_all` after the `Mutex` is unlocked may wake up
+ /// the thread.
+ ///
+ /// To allow for more efficient scheduling, this call may return even when the programmer
+ /// doesn't expect the thread to be woken. Therefore, calls to `wait()` should be used inside a
+ /// loop that checks the predicate before continuing.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Condvar` is notified.
+ ///
+ /// # Panics
+ ///
+ /// This method will panic if used with more than one `Mutex` at the same time.
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// # use std::sync::Arc;
+ /// # use std::thread;
+ ///
+ /// # use cros_async::sync::{block_on, Condvar, Mutex};
+ ///
+ /// # let mu = Arc::new(Mutex::new(false));
+ /// # let cv = Arc::new(Condvar::new());
+ /// # let (mu2, cv2) = (mu.clone(), cv.clone());
+ ///
+ /// # let t = thread::spawn(move || {
+ /// # *block_on(mu2.lock()) = true;
+ /// # cv2.notify_all();
+ /// # });
+ ///
+ /// let mut ready = block_on(mu.lock());
+ /// while !*ready {
+ /// ready = block_on(cv.wait(ready));
+ /// }
+ ///
+ /// # t.join().expect("failed to join thread");
+ /// ```
+ // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
+ // that doesn't compile.
+ #[allow(clippy::needless_lifetimes)]
+ pub async fn wait<'g, T>(&self, guard: MutexGuard<'g, T>) -> MutexGuard<'g, T> {
+ let waiter = Arc::new(Waiter::new(
+ WaiterKind::Exclusive,
+ cancel_waiter,
+ self as *const Condvar as usize,
+ WaitingFor::Condvar,
+ ));
+
+ self.add_waiter(waiter.clone(), guard.as_raw_mutex());
+
+ // Get a reference to the mutex and then drop the lock.
+ let mu = guard.into_inner();
+
+ // Wait to be woken up.
+ waiter.wait().await;
+
+ // Now re-acquire the lock.
+ mu.lock_from_cv().await
+ }
+
+ /// Like `wait()` but takes and returns a `MutexReadGuard` instead.
+ // Clippy doesn't like the lifetime parameters here but doing what it suggests leads to code
+ // that doesn't compile.
+ #[allow(clippy::needless_lifetimes)]
+ pub async fn wait_read<'g, T>(&self, guard: MutexReadGuard<'g, T>) -> MutexReadGuard<'g, T> {
+ let waiter = Arc::new(Waiter::new(
+ WaiterKind::Shared,
+ cancel_waiter,
+ self as *const Condvar as usize,
+ WaitingFor::Condvar,
+ ));
+
+ self.add_waiter(waiter.clone(), guard.as_raw_mutex());
+
+ // Get a reference to the mutex and then drop the lock.
+ let mu = guard.into_inner();
+
+ // Wait to be woken up.
+ waiter.wait().await;
+
+ // Now re-acquire the lock.
+ mu.read_lock_from_cv().await
+ }
+
+ fn add_waiter(&self, waiter: Arc<Waiter>, raw_mutex: &RawMutex) {
+ // Acquire the spin lock.
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK | HAS_WAITERS,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access and the reference does not escape
+ // this function.
+ let mu = unsafe { &mut *self.mu.get() };
+ let muptr = raw_mutex as *const RawMutex as usize;
+
+ match *mu {
+ 0 => *mu = muptr,
+ p if p == muptr => {}
+ _ => panic!("Attempting to use Condvar with more than one Mutex at the same time"),
+ }
+
+ // Safe because the spin lock guarantees exclusive access.
+ unsafe { (*self.waiters.get()).push_back(waiter) };
+
+ // Release the spin lock. Use a direct store here because no other thread can modify
+ // `self.state` while we hold the spin lock. Keep the `HAS_WAITERS` bit that we set earlier
+ // because we just added a waiter.
+ self.state.store(HAS_WAITERS, Ordering::Release);
+ }
+
+ /// Notify at most one thread currently waiting on the `Condvar`.
+ ///
+ /// If there is a thread currently waiting on the `Condvar` it will be woken up from its call to
+ /// `wait`.
+ ///
+ /// Unlike more traditional condition variable interfaces, this method requires a reference to
+ /// the `Mutex` associated with this `Condvar`. This is because it is inherently racy to call
+ /// `notify_one` or `notify_all` without first acquiring the `Mutex` lock. Additionally, taking
+ /// a reference to the `Mutex` here allows us to make some optimizations that can improve
+ /// performance by reducing unnecessary wakeups.
+ pub fn notify_one(&self) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 {
+ // No waiters.
+ return;
+ }
+
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+ let wake_list = get_wake_list(waiters);
+
+ let newstate = if waiters.is_empty() {
+ // Also clear the mutex associated with this Condvar since there are no longer any
+ // waiters. Safe because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ // We are releasing the spin lock and there are no more waiters so we can clear all bits
+ // in `self.state`.
+ 0
+ } else {
+ // There are still waiters so we need to keep the HAS_WAITERS bit in the state.
+ HAS_WAITERS
+ };
+
+ // Release the spin lock.
+ self.state.store(newstate, Ordering::Release);
+
+ // Now wake any waiters in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+ }
+
+ /// Notify all threads currently waiting on the `Condvar`.
+ ///
+ /// All threads currently waiting on the `Condvar` will be woken up from their call to `wait`.
+ ///
+ /// Unlike more traditional condition variable interfaces, this method does not take the
+ /// associated `Mutex` as a parameter: the `Condvar` records which `Mutex` its waiters are
+ /// using when they call `wait`. It is still inherently racy to call `notify_one` or
+ /// `notify_all` without first acquiring the `Mutex` lock, and knowing the associated `Mutex`
+ /// allows some optimizations that can improve performance by reducing unnecessary wakeups.
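+ ///
+ /// # Example
+ ///
+ /// A minimal sketch that wakes several waiting threads at once (modeled on this module's
+ /// tests; it assumes `Condvar`, `Mutex` and `block_on` are all re-exported from
+ /// `cros_async::sync`):
+ ///
+ /// ```edition2018
+ /// use std::sync::Arc;
+ /// use std::thread;
+ ///
+ /// use cros_async::sync::{block_on, Condvar, Mutex};
+ ///
+ /// let mu = Arc::new(Mutex::new(false));
+ /// let cv = Arc::new(Condvar::new());
+ ///
+ /// let mut waiters = Vec::new();
+ /// for _ in 0..4 {
+ ///     let (mu, cv) = (mu.clone(), cv.clone());
+ ///     waiters.push(thread::spawn(move || {
+ ///         let mut ready = block_on(mu.lock());
+ ///         while !*ready {
+ ///             ready = block_on(cv.wait(ready));
+ ///         }
+ ///     }));
+ /// }
+ ///
+ /// // Make the condition true and wake every waiting thread.
+ /// *block_on(mu.lock()) = true;
+ /// cv.notify_all();
+ ///
+ /// for t in waiters {
+ ///     t.join().unwrap();
+ /// }
+ /// ```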
+ pub fn notify_all(&self) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 {
+ // No waiters.
+ return;
+ }
+
+ while (oldstate & SPINLOCK) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock guarantees exclusive access to `self.waiters`.
+ let wake_list = unsafe { (*self.waiters.get()).take() };
+
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ // Mark any waiters left as no longer waiting for the Condvar.
+ for w in &wake_list {
+ w.set_waiting_for(WaitingFor::None);
+ }
+
+ // Release the spin lock. We can clear all bits in the state since we took all the waiters.
+ self.state.store(0, Ordering::Release);
+
+ // Now wake any waiters in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+ }
+
+ fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while oldstate & SPINLOCK != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock provides exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let waiting_for = waiter.is_waiting_for();
+ // Don't drop the old waiter now as we're still holding the spin lock.
+ let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Condvar {
+ // Safe because we know that the waiter is still linked and is waiting for the Condvar,
+ // which guarantees that it is still in `self.waiters`.
+ let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
+ cursor.remove()
+ } else {
+ None
+ };
+
+ let wake_list = if wake_next || waiting_for == WaitingFor::None {
+ // Either the waiter was already woken or it's been removed from the condvar's waiter
+ // list and is going to be woken. Either way, we need to wake up another thread.
+ get_wake_list(waiters)
+ } else {
+ WaiterList::new(WaiterAdapter::new())
+ };
+
+ let set_on_release = if waiters.is_empty() {
+ // Clear the mutex associated with this Condvar since there are no longer any waiters. Safe
+ // because the spin lock guarantees exclusive access.
+ unsafe { *self.mu.get() = 0 };
+
+ 0
+ } else {
+ HAS_WAITERS
+ };
+
+ self.state.store(set_on_release, Ordering::Release);
+
+ // Now wake any waiters still left in the wake list.
+ for w in wake_list {
+ w.wake();
+ }
+
+ mem::drop(old_waiter);
+ }
+}
+
+unsafe impl Send for Condvar {}
+unsafe impl Sync for Condvar {}
+
+impl Default for Condvar {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+// Scan `waiters` and return all waiters that should be woken up.
+//
+// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
+// waiting for a shared lock are also woken up. In addition, one writer is woken up, if possible.
+//
+// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
+// the rest of the list is not scanned.
+fn get_wake_list(waiters: &mut WaiterList) -> WaiterList {
+ let mut to_wake = WaiterList::new(WaiterAdapter::new());
+ let mut cursor = waiters.front_mut();
+
+ let mut waking_readers = false;
+ let mut all_readers = true;
+ while let Some(w) = cursor.get() {
+ match w.kind() {
+ WaiterKind::Exclusive if !waking_readers => {
+ // This is the first waiter and it's a writer. No need to check the other waiters.
+ // Also mark the waiter as having been removed from the Condvar's waiter list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ break;
+ }
+
+ WaiterKind::Shared => {
+ // This is a reader and the first waiter in the list was not a writer so wake up all
+ // the readers in the wait list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ waking_readers = true;
+ }
+
+ WaiterKind::Exclusive => {
+ debug_assert!(waking_readers);
+ if all_readers {
+ // We are waking readers but we need to ensure that at least one writer is woken
+ // up. Since we haven't yet woken up a writer, wake up this one.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ all_readers = false;
+ } else {
+ // We are waking readers and have already woken one writer. Skip this one.
+ cursor.move_next();
+ }
+ }
+ }
+ }
+
+ to_wake
+}
+
+fn cancel_waiter(cv: usize, waiter: &Waiter, wake_next: bool) {
+ let condvar = cv as *const Condvar;
+
+ // Safe because the thread that owns the waiter being canceled must also own a reference to the
+ // Condvar, which guarantees that this pointer is valid.
+ unsafe { (*condvar).cancel_waiter(waiter, wake_next) }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::mem;
+ use std::ptr;
+ use std::rc::Rc;
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll};
+ use std::thread::{self, JoinHandle};
+ use std::time::Duration;
+
+ use futures::channel::oneshot;
+ use futures::task::{waker_ref, ArcWake};
+ use futures::{select, FutureExt};
+ use futures_executor::{LocalPool, LocalSpawner, ThreadPool};
+ use futures_util::task::LocalSpawnExt;
+
+ use crate::sync::{block_on, Mutex};
+
+ // Dummy waker used when we want to manually drive futures.
+ struct TestWaker;
+ impl ArcWake for TestWaker {
+ fn wake_by_ref(_arc_self: &Arc<Self>) {}
+ }
+
+ #[test]
+ fn smoke() {
+ let cv = Condvar::new();
+ cv.notify_one();
+ cv.notify_all();
+ }
+
+ #[test]
+ fn notify_one() {
+ let mu = Arc::new(Mutex::new(()));
+ let cv = Arc::new(Condvar::new());
+
+ let mu2 = mu.clone();
+ let cv2 = cv.clone();
+
+ let guard = block_on(mu.lock());
+ thread::spawn(move || {
+ let _g = block_on(mu2.lock());
+ cv2.notify_one();
+ });
+
+ let guard = block_on(cv.wait(guard));
+ mem::drop(guard);
+ }
+
+ #[test]
+ fn multi_mutex() {
+ const NUM_THREADS: usize = 5;
+
+ let mu = Arc::new(Mutex::new(false));
+ let cv = Arc::new(Condvar::new());
+
+ let mut threads = Vec::with_capacity(NUM_THREADS);
+ for _ in 0..NUM_THREADS {
+ let mu = mu.clone();
+ let cv = cv.clone();
+
+ threads.push(thread::spawn(move || {
+ let mut ready = block_on(mu.lock());
+ while !*ready {
+ ready = block_on(cv.wait(ready));
+ }
+ }));
+ }
+
+ let mut g = block_on(mu.lock());
+ *g = true;
+ mem::drop(g);
+ cv.notify_all();
+
+ threads
+ .into_iter()
+ .try_for_each(JoinHandle::join)
+ .expect("Failed to join threads");
+
+ // Now use the Condvar with a different mutex.
+ let alt_mu = Arc::new(Mutex::new(None));
+ let alt_mu2 = alt_mu.clone();
+ let cv2 = cv.clone();
+ let handle = thread::spawn(move || {
+ let mut g = block_on(alt_mu2.lock());
+ while g.is_none() {
+ g = block_on(cv2.wait(g));
+ }
+ });
+
+ let mut alt_g = block_on(alt_mu.lock());
+ *alt_g = Some(());
+ mem::drop(alt_g);
+ cv.notify_all();
+
+ handle
+ .join()
+ .expect("Failed to join thread alternate mutex");
+ }
+
+ #[test]
+ fn notify_one_single_thread_async() {
+ async fn notify(mu: Rc<Mutex<()>>, cv: Rc<Condvar>) {
+ let _g = mu.lock().await;
+ cv.notify_one();
+ }
+
+ async fn wait(mu: Rc<Mutex<()>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
+ let mu2 = Rc::clone(&mu);
+ let cv2 = Rc::clone(&cv);
+
+ let g = mu.lock().await;
+ // Has to be spawned _after_ acquiring the lock to prevent a race
+ // where the notify happens before the waiter has acquired the lock.
+ spawner
+ .spawn_local(notify(mu2, cv2))
+ .expect("Failed to spawn `notify` task");
+ let _g = cv.wait(g).await;
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(()));
+ let cv = Rc::new(Condvar::new());
+
+ spawner
+ .spawn_local(wait(mu, cv, spawner.clone()))
+ .expect("Failed to spawn `wait` task");
+
+ ex.run();
+ }
+
+ #[test]
+ fn notify_one_multi_thread_async() {
+ async fn notify(mu: Arc<Mutex<()>>, cv: Arc<Condvar>) {
+ let _g = mu.lock().await;
+ cv.notify_one();
+ }
+
+ async fn wait(mu: Arc<Mutex<()>>, cv: Arc<Condvar>, tx: Sender<()>, pool: ThreadPool) {
+ let mu2 = Arc::clone(&mu);
+ let cv2 = Arc::clone(&cv);
+
+ let g = mu.lock().await;
+ // Has to be spawned _after_ acquiring the lock to prevent a race
+ // where the notify happens before the waiter has acquired the lock.
+ pool.spawn_ok(notify(mu2, cv2));
+ let _g = cv.wait(g).await;
+
+ tx.send(()).expect("Failed to send completion notification");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(()));
+ let cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ ex.spawn_ok(wait(mu, cv, tx, ex.clone()));
+
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion notification");
+ }
+
+ #[test]
+ fn notify_one_with_cancel() {
+ const TASKS: usize = 17;
+ const OBSERVERS: usize = 7;
+ const ITERATIONS: usize = 103;
+
+ async fn observe(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+ let _ = unsafe { ptr::read_volatile(&*count as *const usize) };
+ }
+
+ async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+ *count -= 1;
+ }
+
+ async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
+ for _ in 0..TASKS * OBSERVERS * ITERATIONS {
+ *mu.lock().await += 1;
+ cv.notify_one();
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn observe_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = observe(&mu, &cv).fuse() => {},
+ () = observe(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn decrement_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = decrement(&mu, &cv).fuse() => {},
+ () = decrement(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let cv = Arc::new(Condvar::new());
+ let alt_cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(decrement_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ for _ in 0..OBSERVERS {
+ ex.spawn_ok(observe_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
+ ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
+
+ for _ in 0..TASKS + OBSERVERS + 2 {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(20)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ (TASKS * OBSERVERS * ITERATIONS * 2) - (TASKS * ITERATIONS)
+ );
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_all_with_cancel() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ async fn decrement(mu: &Arc<Mutex<usize>>, cv: &Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+ *count -= 1;
+ }
+
+ async fn increment(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>, done: Sender<()>) {
+ for _ in 0..TASKS * ITERATIONS {
+ *mu.lock().await += 1;
+ cv.notify_all();
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ async fn decrement_either(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ alt_mu: Arc<Mutex<usize>>,
+ alt_cv: Arc<Condvar>,
+ done: Sender<()>,
+ ) {
+ for _ in 0..ITERATIONS {
+ select! {
+ () = decrement(&mu, &cv).fuse() => {},
+ () = decrement(&alt_mu, &alt_cv).fuse() => {},
+ }
+ }
+
+ done.send(()).expect("Failed to send completion message");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let cv = Arc::new(Condvar::new());
+ let alt_cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(decrement_either(
+ Arc::clone(&mu),
+ Arc::clone(&cv),
+ Arc::clone(&alt_mu),
+ Arc::clone(&alt_cv),
+ tx.clone(),
+ ));
+ }
+
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&cv), tx.clone()));
+ ex.spawn_ok(increment(Arc::clone(&alt_mu), Arc::clone(&alt_cv), tx));
+
+ for _ in 0..TASKS + 2 {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ TASKS * ITERATIONS
+ );
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_cv.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_all() {
+ const THREADS: usize = 13;
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+ let (tx, rx) = channel();
+
+ let mut threads = Vec::with_capacity(THREADS);
+ for _ in 0..THREADS {
+ let mu2 = mu.clone();
+ let cv2 = cv.clone();
+ let tx2 = tx.clone();
+
+ threads.push(thread::spawn(move || {
+ let mut count = block_on(mu2.lock());
+ *count += 1;
+ if *count == THREADS {
+ tx2.send(()).unwrap();
+ }
+
+ while *count != 0 {
+ count = block_on(cv2.wait(count));
+ }
+ }));
+ }
+
+ mem::drop(tx);
+
+ // Wait till all threads have started.
+ rx.recv_timeout(Duration::from_secs(5)).unwrap();
+
+ let mut count = block_on(mu.lock());
+ *count = 0;
+ mem::drop(count);
+ cv.notify_all();
+
+ for t in threads {
+ t.join().unwrap();
+ }
+ }
+
+ #[test]
+ fn notify_all_single_thread_async() {
+ const TASKS: usize = 13;
+
+ async fn reset(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>) {
+ let mut count = mu.lock().await;
+ *count = 0;
+ cv.notify_all();
+ }
+
+ async fn watcher(mu: Rc<Mutex<usize>>, cv: Rc<Condvar>, spawner: LocalSpawner) {
+ let mut count = mu.lock().await;
+ *count += 1;
+ if *count == TASKS {
+ spawner
+ .spawn_local(reset(mu.clone(), cv.clone()))
+ .expect("Failed to spawn reset task");
+ }
+
+ while *count != 0 {
+ count = cv.wait(count).await;
+ }
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(0));
+ let cv = Rc::new(Condvar::new());
+
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(watcher(mu.clone(), cv.clone(), spawner.clone()))
+ .expect("Failed to spawn watcher task");
+ }
+
+ ex.run();
+ }
+
+ #[test]
+ fn notify_all_multi_thread_async() {
+ const TASKS: usize = 13;
+
+ async fn reset(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ *count = 0;
+ cv.notify_all();
+ }
+
+ async fn watcher(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ pool: ThreadPool,
+ tx: Sender<()>,
+ ) {
+ let mut count = mu.lock().await;
+ *count += 1;
+ if *count == TASKS {
+ pool.spawn_ok(reset(mu.clone(), cv.clone()));
+ }
+
+ while *count != 0 {
+ count = cv.wait(count).await;
+ }
+
+ tx.send(()).expect("Failed to send completion notification");
+ }
+
+ let pool = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ pool.spawn_ok(watcher(mu.clone(), cv.clone(), pool.clone(), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion notification");
+ }
+ }
+
+ #[test]
+ fn wake_all_readers() {
+ async fn read(mu: Arc<Mutex<bool>>, cv: Arc<Condvar>) {
+ let mut ready = mu.read_lock().await;
+ while !*ready {
+ ready = cv.wait_read(ready).await;
+ }
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let cv = Arc::new(Condvar::new());
+ let mut readers = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First have all the readers wait on the Condvar.
+ for r in &mut readers {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("reader unexpectedly ready");
+ }
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ // Now make the condition true and notify the condvar. Even though we will call notify_one,
+ // all the readers should be woken up.
+ *block_on(mu.lock()) = true;
+ cv.notify_one();
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ // All readers should now be able to complete.
+ for r in &mut readers {
+ if r.as_mut().poll(&mut cx).is_pending() {
+ panic!("reader unable to complete");
+ }
+ }
+ }
+
+ #[test]
+ fn cancel_before_notify() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ *block_on(mu.lock()) = 2;
+ // Drop fut1 before notifying the cv.
+ mem::drop(fut1);
+ cv.notify_one();
+
+ // fut2 should now be ready to complete.
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ if fut2.as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn cancel_after_notify_one() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ *block_on(mu.lock()) = 2;
+ cv.notify_one();
+
+ // fut1 should now be ready to complete. Drop it before polling. This should wake up fut2.
+ mem::drop(fut1);
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ if fut2.as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn cancel_after_notify_all() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut fut1 = Box::pin(dec(mu.clone(), cv.clone()));
+ let mut fut2 = Box::pin(dec(mu.clone(), cv.clone()));
+
+ if let Poll::Ready(()) = fut1.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = fut2.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ assert_eq!(cv.state.load(Ordering::Relaxed) & HAS_WAITERS, HAS_WAITERS);
+
+ let mut count = block_on(mu.lock());
+ *count = 2;
+
+ // Notify the cv while holding the lock. This should wake up both waiters.
+ cv.notify_all();
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+
+ mem::drop(count);
+
+ mem::drop(fut1);
+
+ if fut2.as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 1);
+ }
+
+ #[test]
+ fn timed_wait() {
+ async fn wait_deadline(
+ mu: Arc<Mutex<usize>>,
+ cv: Arc<Condvar>,
+ timeout: oneshot::Receiver<()>,
+ ) {
+ let mut count = mu.lock().await;
+
+ if *count == 0 {
+ let mut rx = timeout.fuse();
+
+ while *count == 0 {
+ select! {
+ res = rx => {
+ if let Err(e) = res {
+ panic!("Error while receiving timeout notification: {}", e);
+ }
+
+ return;
+ },
+ c = cv.wait(count).fuse() => count = c,
+ }
+ }
+ }
+
+ *count += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let (tx, rx) = oneshot::channel();
+ let mut wait = Box::pin(wait_deadline(mu.clone(), cv.clone(), rx));
+
+ if let Poll::Ready(()) = wait.as_mut().poll(&mut cx) {
+ panic!("wait_deadline unexpectedly ready");
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), HAS_WAITERS);
+
+ // Signal the channel, which should cancel the wait.
+ tx.send(()).expect("Failed to send wakeup");
+
+ // The wait should now complete because the simulated timeout has fired.
+ if wait.as_mut().poll(&mut cx).is_pending() {
+ panic!("wait_deadline unable to complete in time");
+ }
+
+ assert_eq!(cv.state.load(Ordering::Relaxed), 0);
+ assert_eq!(*block_on(mu.lock()), 0);
+ }
+}
diff --git a/cros_async/src/sync/mu.rs b/cros_async/src/sync/mu.rs
new file mode 100644
index 000000000..4f3443fb5
--- /dev/null
+++ b/cros_async/src/sync/mu.rs
@@ -0,0 +1,2289 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::mem;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{spin_loop_hint, AtomicUsize, Ordering};
+use std::sync::Arc;
+use std::thread::yield_now;
+
+use crate::sync::waiter::{Kind as WaiterKind, Waiter, WaiterAdapter, WaiterList, WaitingFor};
+
+// Set when the mutex is exclusively locked.
+const LOCKED: usize = 1 << 0;
+// Set when there are one or more threads waiting to acquire the lock.
+const HAS_WAITERS: usize = 1 << 1;
+// Set when a thread has been woken up from the wait queue. Cleared when that thread either acquires
+// the lock or adds itself back into the wait queue. Used to prevent unnecessary wake ups when a
+// thread has been removed from the wait queue but has not gotten CPU time yet.
+const DESIGNATED_WAKER: usize = 1 << 2;
+// Used to provide exclusive access to the `waiters` field in `Mutex`. Should only be held while
+// modifying the waiter list.
+const SPINLOCK: usize = 1 << 3;
+// Set when a thread that wants an exclusive lock adds itself to the wait queue. New threads
+// attempting to acquire a shared lock will be prevented from getting it when this bit is set.
+// However, this bit is ignored once a thread has gone through the wait queue at least once.
+const WRITER_WAITING: usize = 1 << 4;
+// Set when a thread has gone through the wait queue many times but has failed to acquire the lock
+// every time it is woken up. When this bit is set, all other threads are prevented from acquiring
+// the lock until the thread that set the `LONG_WAIT` bit has acquired the lock.
+const LONG_WAIT: usize = 1 << 5;
+// The bit that is added to the mutex state in order to acquire a shared lock. Since more than one
+// thread can acquire a shared lock, we cannot use a single bit. Instead we use all the remaining
+// bits in the state to track the number of threads that have acquired a shared lock.
+const READ_LOCK: usize = 1 << 8;
+// Mask used for checking if any threads currently hold a shared lock.
+const READ_MASK: usize = !0xff;
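+
+// As a worked illustration of how these bits combine (an informal note rather than an exhaustive
+// list of states): a mutex currently held by three readers while a single writer waits in the
+// queue would have a state of `(3 * READ_LOCK) | HAS_WAITERS | WRITER_WAITING`.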
+
+// The number of times the thread should just spin and attempt to re-acquire the lock.
+const SPIN_THRESHOLD: usize = 7;
+
+// The number of times the thread needs to go through the wait queue before it sets the `LONG_WAIT`
+// bit and forces all other threads to wait for it to acquire the lock. This value is set relatively
+// high so that we don't lose the benefit of having running threads unless it is absolutely
+// necessary.
+const LONG_WAIT_THRESHOLD: usize = 19;
+
+// Common methods between shared and exclusive locks.
+trait Kind {
+ // The bits that must be zero for the thread to acquire this kind of lock. If any of these bits
+ // are not zero then the thread will first spin and retry a few times before adding itself to
+ // the wait queue.
+ fn zero_to_acquire() -> usize;
+
+ // The bit that must be added in order to acquire this kind of lock. This should either be
+ // `LOCKED` or `READ_LOCK`.
+ fn add_to_acquire() -> usize;
+
+ // The bits that should be set when a thread adds itself to the wait queue while waiting to
+ // acquire this kind of lock.
+ fn set_when_waiting() -> usize;
+
+ // The bits that should be cleared when a thread acquires this kind of lock.
+ fn clear_on_acquire() -> usize;
+
+ // The waiter that a thread should use when waiting to acquire this kind of lock.
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter>;
+}
+
+// A lock type for shared read-only access to the data. More than one thread may hold this kind of
+// lock simultaneously.
+struct Shared;
+
+impl Kind for Shared {
+ fn zero_to_acquire() -> usize {
+ LOCKED | WRITER_WAITING | LONG_WAIT
+ }
+
+ fn add_to_acquire() -> usize {
+ READ_LOCK
+ }
+
+ fn set_when_waiting() -> usize {
+ 0
+ }
+
+ fn clear_on_acquire() -> usize {
+ 0
+ }
+
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter> {
+ Arc::new(Waiter::new(
+ WaiterKind::Shared,
+ cancel_waiter,
+ raw as *const RawMutex as usize,
+ WaitingFor::Mutex,
+ ))
+ }
+}
+
+// A lock type for mutually exclusive read-write access to the data. Only one thread can hold this
+// kind of lock at a time.
+struct Exclusive;
+
+impl Kind for Exclusive {
+ fn zero_to_acquire() -> usize {
+ LOCKED | READ_MASK | LONG_WAIT
+ }
+
+ fn add_to_acquire() -> usize {
+ LOCKED
+ }
+
+ fn set_when_waiting() -> usize {
+ WRITER_WAITING
+ }
+
+ fn clear_on_acquire() -> usize {
+ WRITER_WAITING
+ }
+
+ fn new_waiter(raw: &RawMutex) -> Arc<Waiter> {
+ Arc::new(Waiter::new(
+ WaiterKind::Exclusive,
+ cancel_waiter,
+ raw as *const RawMutex as usize,
+ WaitingFor::Mutex,
+ ))
+ }
+}
+
+// Scan `waiters` and return the ones that should be woken up. Also returns any bits that should be
+// set in the mutex state when the current thread releases the spin lock protecting the waiter list.
+//
+// If the first waiter is trying to acquire a shared lock, then all waiters in the list that are
+// waiting for a shared lock are also woken up. If any waiters waiting for an exclusive lock are
+// found when iterating through the list, then the returned `usize` contains the `WRITER_WAITING`
+// bit, which should be set when the thread releases the spin lock.
+//
+// If the first waiter is trying to acquire an exclusive lock, then only that waiter is returned and
+// no bits are set in the returned `usize`.
+fn get_wake_list(waiters: &mut WaiterList) -> (WaiterList, usize) {
+ let mut to_wake = WaiterList::new(WaiterAdapter::new());
+ let mut set_on_release = 0;
+ let mut cursor = waiters.front_mut();
+
+ let mut waking_readers = false;
+ while let Some(w) = cursor.get() {
+ match w.kind() {
+ WaiterKind::Exclusive if !waking_readers => {
+ // This is the first waiter and it's a writer. No need to check the other waiters.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ break;
+ }
+
+ WaiterKind::Shared => {
+ // This is a reader and the first waiter in the list was not a writer so wake up all
+ // the readers in the wait list.
+ let waiter = cursor.remove().unwrap();
+ waiter.set_waiting_for(WaitingFor::None);
+ to_wake.push_back(waiter);
+ waking_readers = true;
+ }
+
+ WaiterKind::Exclusive => {
+ // We found a writer while looking for more readers to wake up. Set the
+ // `WRITER_WAITING` bit to prevent any new readers from acquiring the lock. All
+ // readers currently in the wait list will ignore this bit since they already waited
+ // once.
+ set_on_release |= WRITER_WAITING;
+ cursor.move_next();
+ }
+ }
+ }
+
+ (to_wake, set_on_release)
+}
+
+#[inline]
+fn cpu_relax(iterations: usize) {
+ for _ in 0..iterations {
+ spin_loop_hint();
+ }
+}
+
+pub(crate) struct RawMutex {
+ state: AtomicUsize,
+ waiters: UnsafeCell<WaiterList>,
+}
+
+impl RawMutex {
+ pub fn new() -> RawMutex {
+ RawMutex {
+ state: AtomicUsize::new(0),
+ waiters: UnsafeCell::new(WaiterList::new(WaiterAdapter::new())),
+ }
+ }
+
+ #[inline]
+ pub async fn lock(&self) {
+ match self
+ .state
+ .compare_exchange_weak(0, LOCKED, Ordering::Acquire, Ordering::Relaxed)
+ {
+ Ok(_) => {}
+ Err(oldstate) => {
+ // If any bits that should be zero are not zero or if we fail to acquire the lock
+ // with a single compare_exchange then go through the slow path.
+ if (oldstate & Exclusive::zero_to_acquire()) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + Exclusive::add_to_acquire())
+ & !Exclusive::clear_on_acquire(),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ self.lock_slow::<Exclusive>(0, 0).await;
+ }
+ }
+ }
+ }
+
+ #[inline]
+ pub async fn read_lock(&self) {
+ match self
+ .state
+ .compare_exchange_weak(0, READ_LOCK, Ordering::Acquire, Ordering::Relaxed)
+ {
+ Ok(_) => {}
+ Err(oldstate) => {
+ if (oldstate & Shared::zero_to_acquire()) != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + Shared::add_to_acquire()) & !Shared::clear_on_acquire(),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ self.lock_slow::<Shared>(0, 0).await;
+ }
+ }
+ }
+ }
+
+ // Slow path for acquiring the lock. `clear` should contain any bits that need to be cleared
+ // when the lock is acquired. Any bits set in `zero_mask` are cleared from the bits returned by
+ // `K::zero_to_acquire()`.
+ #[cold]
+ async fn lock_slow<K: Kind>(&self, mut clear: usize, zero_mask: usize) {
+ let mut zero_to_acquire = K::zero_to_acquire() & !zero_mask;
+
+ let mut spin_count = 0;
+ let mut wait_count = 0;
+ let mut waiter = None;
+ loop {
+ let oldstate = self.state.load(Ordering::Relaxed);
+ // If all the bits in `zero_to_acquire` are actually zero then try to acquire the lock
+ // directly.
+ if (oldstate & zero_to_acquire) == 0 {
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate + K::add_to_acquire()) & !(clear | K::clear_on_acquire()),
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ return;
+ }
+ } else if (oldstate & SPINLOCK) == 0 {
+ // The mutex is locked and the spin lock is available. Try to add this thread
+ // to the waiter queue.
+ let w = waiter.get_or_insert_with(|| K::new_waiter(self));
+ w.reset(WaitingFor::Mutex);
+
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate | SPINLOCK | HAS_WAITERS | K::set_when_waiting()) & !clear,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ let mut set_on_release = 0;
+
+ // Safe because we have acquired the spin lock and it provides exclusive
+ // access to the waiter queue.
+ if wait_count < LONG_WAIT_THRESHOLD {
+ // Add the waiter to the back of the queue.
+ unsafe { (*self.waiters.get()).push_back(w.clone()) };
+ } else {
+ // This waiter has gone through the queue too many times. Put it in the
+ // front of the queue and block all other threads from acquiring the lock
+ // until this one has acquired it at least once.
+ unsafe { (*self.waiters.get()).push_front(w.clone()) };
+
+ // Set the LONG_WAIT bit to prevent all other threads from acquiring the
+ // lock.
+ set_on_release |= LONG_WAIT;
+
+ // Make sure we clear the LONG_WAIT bit when we do finally get the lock.
+ clear |= LONG_WAIT;
+
+ // Since we set the LONG_WAIT bit we shouldn't allow that bit to prevent us
+ // from acquiring the lock.
+ zero_to_acquire &= !LONG_WAIT;
+ }
+
+ // Release the spin lock.
+ let mut state = oldstate;
+ loop {
+ match self.state.compare_exchange_weak(
+ state,
+ (state | set_on_release) & !SPINLOCK,
+ Ordering::Release,
+ Ordering::Relaxed,
+ ) {
+ Ok(_) => break,
+ Err(w) => state = w,
+ }
+ }
+
+ // Now wait until we are woken.
+ w.wait().await;
+
+ // The `DESIGNATED_WAKER` bit gets set when this thread is woken up by the
+ // thread that originally held the lock. While this bit is set, no other waiters
+ // will be woken up so it's important to clear it the next time we try to
+ // acquire the main lock or the spin lock.
+ clear |= DESIGNATED_WAKER;
+
+ // Now that the thread has waited once, we no longer care if there is a writer
+ // waiting. Only the limits of mutual exclusion can prevent us from acquiring
+ // the lock.
+ zero_to_acquire &= !WRITER_WAITING;
+
+ // Reset the spin count since we just went through the wait queue.
+ spin_count = 0;
+
+ // Increment the wait count since we went through the wait queue.
+ wait_count += 1;
+
+ // Skip the `cpu_relax` below.
+ continue;
+ }
+ }
+
+ // Both the lock and the spin lock are held by one or more other threads. First, we'll
+ // spin a few times in case we can acquire the lock or the spin lock. If that fails then
+ // we yield because we might be preventing the threads that do hold the two locks from
+ // getting CPU time.
+ if spin_count < SPIN_THRESHOLD {
+ cpu_relax(1 << spin_count);
+ spin_count += 1;
+ } else {
+ yield_now();
+ }
+ }
+ }
+
+ #[inline]
+ pub fn unlock(&self) {
+ // Fast path, if possible. We can directly clear the locked bit since we have exclusive
+ // access to the mutex.
+ let oldstate = self.state.fetch_sub(LOCKED, Ordering::Release);
+
+ // Panic if we just tried to unlock a mutex that wasn't held by this thread. This shouldn't
+ // really be possible since `unlock` is not a public method.
+ debug_assert_eq!(
+ oldstate & READ_MASK,
+ 0,
+ "`unlock` called on mutex held in read-mode"
+ );
+ debug_assert_ne!(
+ oldstate & LOCKED,
+ 0,
+ "`unlock` called on mutex not held in write-mode"
+ );
+
+ if (oldstate & HAS_WAITERS) != 0 && (oldstate & DESIGNATED_WAKER) == 0 {
+ // The oldstate has waiters but no designated waker has been chosen yet.
+ self.unlock_slow();
+ }
+ }
+
+ #[inline]
+ pub fn read_unlock(&self) {
+ // Fast path, if possible. We can directly subtract the READ_LOCK bit since we had
+ // previously added it.
+ let oldstate = self.state.fetch_sub(READ_LOCK, Ordering::Release);
+
+ debug_assert_eq!(
+ oldstate & LOCKED,
+ 0,
+ "`read_unlock` called on mutex held in write-mode"
+ );
+ debug_assert_ne!(
+ oldstate & READ_MASK,
+ 0,
+ "`read_unlock` called on mutex not held in read-mode"
+ );
+
+ if (oldstate & HAS_WAITERS) != 0
+ && (oldstate & DESIGNATED_WAKER) == 0
+ && (oldstate & READ_MASK) == READ_LOCK
+ {
+ // There are waiters, no designated waker has been chosen yet, and the last reader is
+ // unlocking so we have to take the slow path.
+ self.unlock_slow();
+ }
+ }
+
+ #[cold]
+ fn unlock_slow(&self) {
+ let mut spin_count = 0;
+
+ loop {
+ let oldstate = self.state.load(Ordering::Relaxed);
+ if (oldstate & HAS_WAITERS) == 0 || (oldstate & DESIGNATED_WAKER) != 0 {
+ // No more waiters or a designated waker has been chosen. Nothing left for us to do.
+ return;
+ } else if (oldstate & SPINLOCK) == 0 {
+ // The spin lock is not held by another thread. Try to acquire it. Also set the
+ // `DESIGNATED_WAKER` bit since we are likely going to wake up one or more threads.
+ if self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK | DESIGNATED_WAKER,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_ok()
+ {
+ // Acquired the spinlock. Try to wake a waiter. We may also end up wanting to
+ // clear the HAS_WAITERS and DESIGNATED_WAKER bits so start collecting the bits
+ // to be cleared.
+ let mut clear = SPINLOCK;
+
+ // Safe because the spinlock guarantees exclusive access to the waiter list and
+ // the reference does not escape this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+ let (wake_list, set_on_release) = get_wake_list(waiters);
+
+ // If the waiter list is now empty, clear the HAS_WAITERS bit.
+ if waiters.is_empty() {
+ clear |= HAS_WAITERS;
+ }
+
+ if wake_list.is_empty() {
+ // Since we are not going to wake any waiters clear the DESIGNATED_WAKER bit
+ // that we set when we acquired the spin lock.
+ clear |= DESIGNATED_WAKER;
+ }
+
+ // Release the spin lock and clear any other bits as necessary. Also, set any
+ // bits returned by `get_wake_list`. For now, this is just the `WRITER_WAITING`
+ // bit, which needs to be set when we are waking up a bunch of readers and there
+ // are still writers in the wait queue. This will prevent any readers that
+ // aren't in `wake_list` from acquiring the read lock.
+ let mut state = oldstate;
+ loop {
+ match self.state.compare_exchange_weak(
+ state,
+ (state | set_on_release) & !clear,
+ Ordering::Release,
+ Ordering::Relaxed,
+ ) {
+ Ok(_) => break,
+ Err(w) => state = w,
+ }
+ }
+
+ // Now wake the waiters, if any.
+ for w in wake_list {
+ w.wake();
+ }
+
+ // We're done.
+ return;
+ }
+ }
+
+ // Spin and try again. It's ok to block here as we have already released the lock.
+ if spin_count < SPIN_THRESHOLD {
+ cpu_relax(1 << spin_count);
+ spin_count += 1;
+ } else {
+ yield_now();
+ }
+ }
+ }
+
+ fn cancel_waiter(&self, waiter: &Waiter, wake_next: bool) {
+ let mut oldstate = self.state.load(Ordering::Relaxed);
+ while oldstate & SPINLOCK != 0
+ || self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ oldstate | SPINLOCK,
+ Ordering::Acquire,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ // Safe because the spin lock provides exclusive access and the reference does not escape
+ // this function.
+ let waiters = unsafe { &mut *self.waiters.get() };
+
+ let mut clear = SPINLOCK;
+
+ // If we are about to remove the first waiter in the wait list, then clear the LONG_WAIT
+ // bit. Also clear the bit if we are going to be waking some other waiters. In this case the
+ // waiter that set the bit may have already been removed from the waiter list (and could be
+ // the one that is currently being dropped). If it is still in the waiter list then clearing
+ // this bit may starve it for one more iteration through the lock_slow() loop, whereas not
+ // clearing this bit could cause a deadlock if the waiter that set it is the one that is
+ // being dropped.
+ if wake_next
+ || waiters
+ .front()
+ .get()
+ .map(|front| std::ptr::eq(front, waiter))
+ .unwrap_or(false)
+ {
+ clear |= LONG_WAIT;
+ }
+
+ let waiting_for = waiter.is_waiting_for();
+
+ // Don't drop the old waiter while holding the spin lock.
+ let old_waiter = if waiter.is_linked() && waiting_for == WaitingFor::Mutex {
+ // We know that the waiter is still linked and is waiting for the mutex, which
+ // guarantees that it is still linked into `self.waiters`.
+ let mut cursor = unsafe { waiters.cursor_mut_from_ptr(waiter as *const Waiter) };
+ cursor.remove()
+ } else {
+ None
+ };
+
+ let (wake_list, set_on_release) = if wake_next || waiting_for == WaitingFor::None {
+ // Either the waiter was already woken or it's been removed from the mutex's waiter
+ // list and is going to be woken. Either way, we need to wake up another thread.
+ get_wake_list(waiters)
+ } else {
+ (WaiterList::new(WaiterAdapter::new()), 0)
+ };
+
+ if waiters.is_empty() {
+ clear |= HAS_WAITERS;
+ }
+
+ if wake_list.is_empty() {
+ // We're not waking any other threads so clear the DESIGNATED_WAKER bit. In the worst
+ // case this leads to an additional thread being woken up but we risk a deadlock if we
+ // don't clear it.
+ clear |= DESIGNATED_WAKER;
+ }
+
+ if let WaiterKind::Exclusive = waiter.kind() {
+ // The waiter being dropped is a writer so clear the writer waiting bit for now. If we
+ // found more writers in the list while fetching waiters to wake up then this bit will
+ // be set again via `set_on_release`.
+ clear |= WRITER_WAITING;
+ }
+
+ while self
+ .state
+ .compare_exchange_weak(
+ oldstate,
+ (oldstate & !clear) | set_on_release,
+ Ordering::Release,
+ Ordering::Relaxed,
+ )
+ .is_err()
+ {
+ spin_loop_hint();
+ oldstate = self.state.load(Ordering::Relaxed);
+ }
+
+ for w in wake_list {
+ w.wake();
+ }
+
+ mem::drop(old_waiter);
+ }
+}
+
+unsafe impl Send for RawMutex {}
+unsafe impl Sync for RawMutex {}
+
+fn cancel_waiter(raw: usize, waiter: &Waiter, wake_next: bool) {
+ let raw_mutex = raw as *const RawMutex;
+
+ // Safe because the thread that owns the waiter that is being canceled must
+ // also own a reference to the mutex, which ensures that this pointer is
+ // valid.
+ unsafe { (*raw_mutex).cancel_waiter(waiter, wake_next) }
+}
+
+/// A high-level primitive that provides safe, mutable access to a shared resource.
+///
+/// Unlike more traditional mutexes, `Mutex` can safely provide both shared, immutable access (via
+/// `read_lock()`) as well as exclusive, mutable access (via `lock()`) to an underlying resource
+/// with no loss of performance.
+///
+/// # Poisoning
+///
+/// `Mutex` does not support lock poisoning so if a thread panics while holding the lock, the
+/// poisoned data will be accessible by other threads in your program. If you need to guarantee that
+/// other threads cannot access poisoned data then you may wish to wrap this `Mutex` inside another
+/// type that provides the poisoning feature. See the implementation of `std::sync::Mutex` for an
+/// example of this.
+///
+/// # Fairness
+///
+/// This `Mutex` implementation does not guarantee that threads will acquire the lock in the same
+/// order that they call `lock()` or `read_lock()`. However it will attempt to prevent long-term
+/// starvation: if a thread repeatedly fails to acquire the lock beyond a threshold then all other
+/// threads will fail to acquire the lock until the starved thread has acquired it.
+///
+/// Similarly, this `Mutex` will attempt to balance reader and writer threads: once there is a
+/// writer thread waiting to acquire the lock no new reader threads will be allowed to acquire it.
+/// However, any reader threads that were already waiting will still be allowed to acquire it.
+///
+/// # Examples
+///
+/// ```edition2018
+/// use std::sync::Arc;
+/// use std::thread;
+/// use std::sync::mpsc::channel;
+///
+/// use cros_async::sync::{block_on, Mutex};
+///
+/// const N: usize = 10;
+///
+/// // Spawn a few threads to increment a shared variable (non-atomically), and
+/// // let the main thread know once all increments are done.
+/// //
+/// // Here we're using an Arc to share memory among threads, and the data inside
+/// // the Arc is protected with a mutex.
+/// let data = Arc::new(Mutex::new(0));
+///
+/// let (tx, rx) = channel();
+/// for _ in 0..N {
+/// let (data, tx) = (Arc::clone(&data), tx.clone());
+/// thread::spawn(move || {
+/// // The shared state can only be accessed once the lock is held.
+/// // Our non-atomic increment is safe because we're the only thread
+/// // which can access the shared state when the lock is held.
+/// let mut data = block_on(data.lock());
+/// *data += 1;
+/// if *data == N {
+/// tx.send(()).unwrap();
+/// }
+/// // the lock is unlocked here when `data` goes out of scope.
+/// });
+/// }
+///
+/// rx.recv().unwrap();
+/// ```
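+///
+/// A second, smaller sketch showing exclusive and shared access from non-async code via
+/// `block_on` (same assumptions as above):
+///
+/// ```edition2018
+/// use cros_async::sync::{block_on, Mutex};
+///
+/// let mu = Mutex::new(0isize);
+///
+/// // Exclusive, mutable access.
+/// *block_on(mu.lock()) = 23;
+///
+/// // Shared, immutable access.
+/// assert_eq!(*block_on(mu.read_lock()), 23);
+/// ```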
+#[repr(align(128))]
+pub struct Mutex<T: ?Sized> {
+ raw: RawMutex,
+ value: UnsafeCell<T>,
+}
+
+impl<T> Mutex<T> {
+ /// Create a new, unlocked `Mutex` ready for use.
+ pub fn new(v: T) -> Mutex<T> {
+ Mutex {
+ raw: RawMutex::new(),
+ value: UnsafeCell::new(v),
+ }
+ }
+
+ /// Consume the `Mutex` and return the contained value. This method does not perform any locking
+ /// as the compiler will guarantee that there are no other references to `self` and the caller
+ /// owns the `Mutex`.
+ pub fn into_inner(self) -> T {
+ // Don't need to acquire the lock because the compiler guarantees that there are
+ // no other references to `self`.
+ self.value.into_inner()
+ }
+}
+
+impl<T: ?Sized> Mutex<T> {
+ /// Acquires exclusive, mutable access to the resource protected by the `Mutex`, blocking the
+ /// current thread until it is able to do so. Upon returning, the current thread will be the
+ /// only thread with access to the resource. The `Mutex` will be released when the returned
+ /// `MutexGuard` is dropped.
+ ///
+ /// Calling `lock()` while holding a `MutexGuard` or a `MutexReadGuard` will cause a deadlock.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Mutex` is acquired.
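+ ///
+ /// # Example
+ ///
+ /// A brief sketch of async use inside a task (a hypothetical helper; any executor from the
+ /// `futures` crate can drive it):
+ ///
+ /// ```edition2018
+ /// use std::sync::Arc;
+ ///
+ /// use cros_async::sync::Mutex;
+ ///
+ /// async fn add_one(counter: Arc<Mutex<usize>>) {
+ ///     // The guard is dropped at the end of the statement, releasing the lock.
+ ///     *counter.lock().await += 1;
+ /// }
+ /// ```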
+ #[inline]
+ pub async fn lock(&self) -> MutexGuard<'_, T> {
+ self.raw.lock().await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexGuard {
+ mu: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ /// Acquires shared, immutable access to the resource protected by the `Mutex`, blocking the
+ /// current thread until it is able to do so. Upon returning there may be other threads that
+ /// also have immutable access to the resource but there will not be any threads that have
+ /// mutable access to the resource. When the returned `MutexReadGuard` is dropped the thread
+ /// releases its access to the resource.
+ ///
+ /// Calling `read_lock()` while holding a `MutexReadGuard` may deadlock. Calling `read_lock()`
+ /// while holding a `MutexGuard` will deadlock.
+ ///
+ /// Callers that are not in an async context may wish to use the `block_on` method to block the
+ /// thread until the `Mutex` is acquired.
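+ ///
+ /// # Example
+ ///
+ /// A brief sketch of shared access from a task (a hypothetical helper, with the same
+ /// assumptions as the `lock()` example above):
+ ///
+ /// ```edition2018
+ /// use std::sync::Arc;
+ ///
+ /// use cros_async::sync::Mutex;
+ ///
+ /// async fn is_positive(value: Arc<Mutex<i64>>) -> bool {
+ ///     // Multiple tasks may hold the read lock at the same time.
+ ///     *value.read_lock().await > 0
+ /// }
+ /// ```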
+ #[inline]
+ pub async fn read_lock(&self) -> MutexReadGuard<'_, T> {
+ self.raw.read_lock().await;
+
+ // Safe because we have shared read-only access to `self.value`.
+ MutexReadGuard {
+ mu: self,
+ value: unsafe { &*self.value.get() },
+ }
+ }
+
+ // Called from `Condvar::wait` when the thread wants to reacquire the lock.
+ #[inline]
+ pub(crate) async fn lock_from_cv(&self) -> MutexGuard<'_, T> {
+ self.raw.lock_slow::<Exclusive>(DESIGNATED_WAKER, 0).await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexGuard {
+ mu: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ // Like `lock_from_cv` but for acquiring a shared lock.
+ #[inline]
+ pub(crate) async fn read_lock_from_cv(&self) -> MutexReadGuard<'_, T> {
+ // Threads that have waited in the Condvar's waiter list don't have to care if there is a
+ // writer waiting since they have already waited once.
+ self.raw
+ .lock_slow::<Shared>(DESIGNATED_WAKER, WRITER_WAITING)
+ .await;
+
+ // Safe because we have exclusive access to `self.value`.
+ MutexReadGuard {
+ mu: self,
+ value: unsafe { &*self.value.get() },
+ }
+ }
+
+ #[inline]
+ fn unlock(&self) {
+ self.raw.unlock();
+ }
+
+ #[inline]
+ fn read_unlock(&self) {
+ self.raw.read_unlock();
+ }
+
+ pub fn get_mut(&mut self) -> &mut T {
+ // Safe because the compiler statically guarantees that there are no other references to `self`.
+ // This is also why we don't need to acquire the lock first.
+ unsafe { &mut *self.value.get() }
+ }
+}
+
+unsafe impl<T: ?Sized + Send> Send for Mutex<T> {}
+unsafe impl<T: ?Sized + Send> Sync for Mutex<T> {}
+
+impl<T: ?Sized + Default> Default for Mutex<T> {
+ fn default() -> Self {
+ Self::new(Default::default())
+ }
+}
+
+impl<T> From<T> for Mutex<T> {
+ fn from(source: T) -> Self {
+ Self::new(source)
+ }
+}
+
+/// An RAII implementation of a "scoped exclusive lock" for a `Mutex`. When this structure is
+/// dropped, the lock will be released. The resource protected by the `Mutex` can be accessed via
+/// the `Deref` and `DerefMut` implementations of this structure.
+pub struct MutexGuard<'a, T: ?Sized + 'a> {
+ mu: &'a Mutex<T>,
+ value: &'a mut T,
+}
+
+impl<'a, T: ?Sized> MutexGuard<'a, T> {
+ pub(crate) fn into_inner(self) -> &'a Mutex<T> {
+ self.mu
+ }
+
+ pub(crate) fn as_raw_mutex(&self) -> &RawMutex {
+ &self.mu.raw
+ }
+}
+
+impl<'a, T: ?Sized> Deref for MutexGuard<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> DerefMut for MutexGuard<'a, T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for MutexGuard<'a, T> {
+ fn drop(&mut self) {
+ self.mu.unlock()
+ }
+}
+
+/// An RAII implementation of a "scoped shared lock" for a `Mutex`. When this structure is dropped,
+/// the lock will be released. The resource protected by the `Mutex` can be accessed via the `Deref`
+/// implementation of this structure.
+pub struct MutexReadGuard<'a, T: ?Sized + 'a> {
+ mu: &'a Mutex<T>,
+ value: &'a T,
+}
+
+impl<'a, T: ?Sized> MutexReadGuard<'a, T> {
+ pub(crate) fn into_inner(self) -> &'a Mutex<T> {
+ self.mu
+ }
+
+ pub(crate) fn as_raw_mutex(&self) -> &RawMutex {
+ &self.mu.raw
+ }
+}
+
+impl<'a, T: ?Sized> Deref for MutexReadGuard<'a, T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for MutexReadGuard<'a, T> {
+ fn drop(&mut self) {
+ self.mu.read_unlock()
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::future::Future;
+ use std::mem;
+ use std::pin::Pin;
+ use std::rc::Rc;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::mpsc::{channel, Sender};
+ use std::sync::Arc;
+ use std::task::{Context, Poll, Waker};
+ use std::thread;
+ use std::time::Duration;
+
+ use futures::channel::oneshot;
+ use futures::task::{waker_ref, ArcWake};
+ use futures::{pending, select, FutureExt};
+ use futures_executor::{LocalPool, ThreadPool};
+ use futures_util::task::LocalSpawnExt;
+
+ use crate::sync::{block_on, Condvar, SpinLock};
+
+ #[derive(Debug, Eq, PartialEq)]
+ struct NonCopy(u32);
+
+ // Dummy waker used when we want to manually drive futures.
+ struct TestWaker;
+ impl ArcWake for TestWaker {
+ fn wake_by_ref(_arc_self: &Arc<Self>) {}
+ }
+
+ #[test]
+ fn it_works() {
+ let mu = Mutex::new(NonCopy(13));
+
+ assert_eq!(*block_on(mu.lock()), NonCopy(13));
+ }
+
+ #[test]
+ fn smoke() {
+ let mu = Mutex::new(NonCopy(7));
+
+ mem::drop(block_on(mu.lock()));
+ mem::drop(block_on(mu.lock()));
+ }
+
+ #[test]
+ fn rw_smoke() {
+ let mu = Mutex::new(NonCopy(7));
+
+ mem::drop(block_on(mu.lock()));
+ mem::drop(block_on(mu.read_lock()));
+ mem::drop((block_on(mu.read_lock()), block_on(mu.read_lock())));
+ mem::drop(block_on(mu.lock()));
+ }
+
+ #[test]
+ fn async_smoke() {
+ async fn lock(mu: Rc<Mutex<NonCopy>>) {
+ mu.lock().await;
+ }
+
+ async fn read_lock(mu: Rc<Mutex<NonCopy>>) {
+ mu.read_lock().await;
+ }
+
+ async fn double_read_lock(mu: Rc<Mutex<NonCopy>>) {
+ let first = mu.read_lock().await;
+ mu.read_lock().await;
+
+ // Make sure first lives past the second read lock.
+ first.as_raw_mutex();
+ }
+
+ let mu = Rc::new(Mutex::new(NonCopy(7)));
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ spawner
+ .spawn_local(lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(read_lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(double_read_lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+ spawner
+ .spawn_local(lock(Rc::clone(&mu)))
+ .expect("Failed to spawn future");
+
+ ex.run();
+ }
+
+ #[test]
+ fn send() {
+ let mu = Mutex::new(NonCopy(19));
+
+ thread::spawn(move || {
+ let value = block_on(mu.lock());
+ assert_eq!(*value, NonCopy(19));
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_nested() {
+ // Tests nested mutexes and access to underlying data.
+ let mu = Mutex::new(1);
+ let arc = Arc::new(Mutex::new(mu));
+ thread::spawn(move || {
+ let nested = block_on(arc.lock());
+ let lock2 = block_on(nested.lock());
+ assert_eq!(*lock2, 1);
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_access_in_unwind() {
+ let arc = Arc::new(Mutex::new(1));
+ let arc2 = arc.clone();
+ thread::spawn(move || {
+ struct Unwinder {
+ i: Arc<Mutex<i32>>,
+ }
+ impl Drop for Unwinder {
+ fn drop(&mut self) {
+ *block_on(self.i.lock()) += 1;
+ }
+ }
+ let _u = Unwinder { i: arc2 };
+ panic!();
+ })
+ .join()
+ .expect_err("thread did not panic");
+ let lock = block_on(arc.lock());
+ assert_eq!(*lock, 2);
+ }
+
+ #[test]
+ fn unsized_value() {
+ let mutex: &Mutex<[i32]> = &Mutex::new([1, 2, 3]);
+ {
+ let b = &mut *block_on(mutex.lock());
+ b[0] = 4;
+ b[2] = 5;
+ }
+ let expected: &[i32] = &[4, 2, 5];
+ assert_eq!(&*block_on(mutex.lock()), expected);
+ }
+
+ #[test]
+ fn high_contention() {
+ const THREADS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ let mut threads = Vec::with_capacity(THREADS);
+
+ let mu = Arc::new(Mutex::new(0usize));
+ for _ in 0..THREADS {
+ let mu2 = mu.clone();
+ threads.push(thread::spawn(move || {
+ for _ in 0..ITERATIONS {
+ *block_on(mu2.lock()) += 1;
+ }
+ }));
+ }
+
+ for t in threads.into_iter() {
+ t.join().unwrap();
+ }
+
+ assert_eq!(*block_on(mu.read_lock()), THREADS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn high_contention_with_cancel() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ async fn increment(mu: Arc<Mutex<usize>>, alt_mu: Arc<Mutex<usize>>, tx: Sender<()>) {
+ for _ in 0..ITERATIONS {
+ select! {
+ mut count = mu.lock().fuse() => *count += 1,
+ mut count = alt_mu.lock().fuse() => *count += 1,
+ }
+ }
+ tx.send(()).expect("Failed to send completion signal");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let alt_mu = Arc::new(Mutex::new(0usize));
+
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(increment(Arc::clone(&mu), Arc::clone(&alt_mu), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ if let Err(e) = rx.recv_timeout(Duration::from_secs(10)) {
+ panic!("Error while waiting for threads to complete: {}", e);
+ }
+ }
+
+ assert_eq!(
+ *block_on(mu.read_lock()) + *block_on(alt_mu.read_lock()),
+ TASKS * ITERATIONS
+ );
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ assert_eq!(alt_mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn single_thread_async() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ // Async closures are unstable.
+ async fn increment(mu: Rc<Mutex<usize>>) {
+ for _ in 0..ITERATIONS {
+ *mu.lock().await += 1;
+ }
+ }
+
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ let mu = Rc::new(Mutex::new(0usize));
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(increment(Rc::clone(&mu)))
+ .expect("Failed to spawn task");
+ }
+
+ ex.run();
+
+ assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn multi_thread_async() {
+ const TASKS: usize = 17;
+ const ITERATIONS: usize = 103;
+
+ // Async closures are unstable.
+ async fn increment(mu: Arc<Mutex<usize>>, tx: Sender<()>) {
+ for _ in 0..ITERATIONS {
+ *mu.lock().await += 1;
+ }
+ tx.send(()).expect("Failed to send completion signal");
+ }
+
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let mu = Arc::new(Mutex::new(0usize));
+ let (tx, rx) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(increment(Arc::clone(&mu), tx.clone()));
+ }
+
+ for _ in 0..TASKS {
+ rx.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion signal");
+ }
+ assert_eq!(*block_on(mu.read_lock()), TASKS * ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn get_mut() {
+ let mut mu = Mutex::new(NonCopy(13));
+ *mu.get_mut() = NonCopy(17);
+
+ assert_eq!(mu.into_inner(), NonCopy(17));
+ }
+
+ #[test]
+ fn into_inner() {
+ let mu = Mutex::new(NonCopy(29));
+ assert_eq!(mu.into_inner(), NonCopy(29));
+ }
+
+ #[test]
+ fn into_inner_drop() {
+ struct NeedsDrop(Arc<AtomicUsize>);
+ impl Drop for NeedsDrop {
+ fn drop(&mut self) {
+ self.0.fetch_add(1, Ordering::AcqRel);
+ }
+ }
+
+ let value = Arc::new(AtomicUsize::new(0));
+ let needs_drop = Mutex::new(NeedsDrop(value.clone()));
+ assert_eq!(value.load(Ordering::Acquire), 0);
+
+ {
+ let inner = needs_drop.into_inner();
+ assert_eq!(inner.0.load(Ordering::Acquire), 0);
+ }
+
+ assert_eq!(value.load(Ordering::Acquire), 1);
+ }
+
+ #[test]
+ fn rw_arc() {
+ const THREADS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Arc::new(Mutex::new(0isize));
+ let mu2 = mu.clone();
+
+ let (tx, rx) = channel();
+ thread::spawn(move || {
+ let mut guard = block_on(mu2.lock());
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ thread::yield_now();
+ *guard = tmp + 1;
+ }
+ tx.send(()).unwrap();
+ });
+
+ let mut readers = Vec::with_capacity(10);
+ for _ in 0..THREADS {
+ let mu3 = mu.clone();
+ let handle = thread::spawn(move || {
+ let guard = block_on(mu3.read_lock());
+ assert!(*guard >= 0);
+ });
+
+ readers.push(handle);
+ }
+
+ // Wait for the readers to finish their checks.
+ for r in readers {
+ r.join().expect("One or more readers saw a negative value");
+ }
+
+ // Wait for the writer to finish.
+ rx.recv_timeout(Duration::from_secs(5)).unwrap();
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn rw_single_thread_async() {
+ // A Future that returns `Poll::Pending` the first time it is polled and `Poll::Ready` every
+ // time after that.
+ struct TestFuture {
+ polled: bool,
+ waker: Arc<SpinLock<Option<Waker>>>,
+ }
+
+ impl Future for TestFuture {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ if self.polled {
+ Poll::Ready(())
+ } else {
+ self.polled = true;
+ *self.waker.lock() = Some(cx.waker().clone());
+ Poll::Pending
+ }
+ }
+ }
+
+ fn wake_future(waker: Arc<SpinLock<Option<Waker>>>) {
+ loop {
+ if let Some(w) = waker.lock().take() {
+ w.wake();
+ return;
+ }
+
+ // This sleep cannot be moved into an else branch because the temporary lock guard
+ // would still be held while sleeping due to Rust's drop ordering rules for `if let`.
+ thread::sleep(Duration::from_millis(10));
+ }
+ }
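+
+ // For illustration (not in the original patch): the problematic alternative,
+ // `if let Some(w) = waker.lock().take() { w.wake(); return; } else { thread::sleep(..) }`,
+ // keeps the temporary SpinLockGuard returned by `lock()` alive for the whole
+ // `if let`/`else` expression, so the sleep would run with the lock still held.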
+
+ async fn writer(mu: Rc<Mutex<isize>>) {
+ let mut guard = mu.lock().await;
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ let waker = Arc::new(SpinLock::new(None));
+ let waker2 = Arc::clone(&waker);
+ thread::spawn(move || wake_future(waker2));
+ let fut = TestFuture {
+ polled: false,
+ waker,
+ };
+ fut.await;
+ *guard = tmp + 1;
+ }
+ }
+
+ async fn reader(mu: Rc<Mutex<isize>>) {
+ let guard = mu.read_lock().await;
+ assert!(*guard >= 0);
+ }
+
+ const TASKS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Rc::new(Mutex::new(0isize));
+ let mut ex = LocalPool::new();
+ let spawner = ex.spawner();
+
+ spawner
+ .spawn_local(writer(Rc::clone(&mu)))
+ .expect("Failed to spawn writer");
+
+ for _ in 0..TASKS {
+ spawner
+ .spawn_local(reader(Rc::clone(&mu)))
+ .expect("Failed to spawn reader");
+ }
+
+ ex.run();
+
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn rw_multi_thread_async() {
+ async fn writer(mu: Arc<Mutex<isize>>, tx: Sender<()>) {
+ let mut guard = mu.lock().await;
+ for _ in 0..ITERATIONS {
+ let tmp = *guard;
+ *guard = -1;
+ thread::yield_now();
+ *guard = tmp + 1;
+ }
+
+ mem::drop(guard);
+ tx.send(()).unwrap();
+ }
+
+ async fn reader(mu: Arc<Mutex<isize>>, tx: Sender<()>) {
+ let guard = mu.read_lock().await;
+ assert!(*guard >= 0);
+
+ mem::drop(guard);
+ tx.send(()).expect("Failed to send completion message");
+ }
+
+ const TASKS: isize = 7;
+ const ITERATIONS: isize = 13;
+
+ let mu = Arc::new(Mutex::new(0isize));
+ let ex = ThreadPool::new().expect("Failed to create ThreadPool");
+
+ let (txw, rxw) = channel();
+ ex.spawn_ok(writer(Arc::clone(&mu), txw));
+
+ let (txr, rxr) = channel();
+ for _ in 0..TASKS {
+ ex.spawn_ok(reader(Arc::clone(&mu), txr.clone()));
+ }
+
+ // Wait for the readers to finish their checks.
+ for _ in 0..TASKS {
+ rxr.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion message from reader");
+ }
+
+ // Wait for the writer to finish.
+ rxw.recv_timeout(Duration::from_secs(5))
+ .expect("Failed to receive completion message from writer");
+
+ assert_eq!(*block_on(mu.read_lock()), ITERATIONS);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn wake_all_readers() {
+ async fn read(mu: Arc<Mutex<()>>) {
+ let g = mu.read_lock().await;
+ pending!();
+ mem::drop(g);
+ }
+
+ async fn write(mu: Arc<Mutex<()>>) {
+ mu.lock().await;
+ }
+
+ let mu = Arc::new(Mutex::new(()));
+ let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [
+ Box::pin(read(mu.clone())),
+ Box::pin(read(mu.clone())),
+ Box::pin(read(mu.clone())),
+ Box::pin(write(mu.clone())),
+ Box::pin(read(mu.clone())),
+ ];
+ const NUM_READERS: usize = 4;
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // Acquire the lock so that the futures cannot get it.
+ let g = block_on(mu.lock());
+
+ for r in &mut futures {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop the lock. This should allow all readers to make progress. Since they already waited
+ // once they should ignore the WRITER_WAITING bit that is currently set.
+ mem::drop(g);
+ for r in &mut futures {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ // Check that all readers were able to acquire the lock.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ READ_LOCK * NUM_READERS
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ let mut needs_poll = None;
+
+ // All the readers can now finish but the writer needs to be polled again.
+ for (i, r) in futures.iter_mut().enumerate() {
+ match r.as_mut().poll(&mut cx) {
+ Poll::Ready(()) => {}
+ Poll::Pending => {
+ if needs_poll.is_some() {
+ panic!("More than one future unable to complete");
+ }
+ needs_poll = Some(i);
+ }
+ }
+ }
+
+ if futures[needs_poll.expect("Writer unexpectedly able to complete")]
+ .as_mut()
+ .poll(&mut cx)
+ .is_pending()
+ {
+ panic!("Writer unable to complete");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn long_wait() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // This time the tight loop will fail to acquire the lock.
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ // Which will finally allow the mark_ready function to make progress.
+ if mark.as_mut().poll(&mut cx).is_pending() {
+ panic!("mark_ready not able to make progress");
+ }
+
+ // Now the tight loop will finish.
+ if tl.as_mut().poll(&mut cx).is_pending() {
+ panic!("tight_loop not able to finish");
+ }
+
+ assert!(*block_on(mu.lock()));
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_long_wait_before_wake() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // Now drop the mark_ready future, which should clear the LONG_WAIT bit.
+ mem::drop(mark);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), LOCKED);
+
+ mem::drop(tl);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_long_wait_after_wake() {
+ async fn tight_loop(mu: Arc<Mutex<bool>>) {
+ loop {
+ let ready = mu.lock().await;
+ if *ready {
+ break;
+ }
+ pending!();
+ }
+ }
+
+ async fn mark_ready(mu: Arc<Mutex<bool>>) {
+ *mu.lock().await = true;
+ }
+
+ let mu = Arc::new(Mutex::new(false));
+ let mut tl = Box::pin(tight_loop(mu.clone()));
+ let mut mark = Box::pin(mark_ready(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ for _ in 0..=LONG_WAIT_THRESHOLD {
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ if let Poll::Ready(()) = mark.as_mut().poll(&mut cx) {
+ panic!("mark_ready unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING | LONG_WAIT
+ );
+
+ // This time the tight loop will fail to acquire the lock.
+ if let Poll::Ready(()) = tl.as_mut().poll(&mut cx) {
+ panic!("tight_loop unexpectedly ready");
+ }
+
+ // Now drop the mark_ready future, which should clear the LONG_WAIT bit.
+ mem::drop(mark);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & LONG_WAIT, 0);
+
+ // Since the lock is not held, we should be able to run a future to completion to set the
+ // ready flag.
+ block_on(mark_ready(mu.clone()));
+
+ // Now the tight loop will finish.
+ if tl.as_mut().poll(&mut cx).is_pending() {
+ panic!("tight_loop not able to finish");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn designated_waker() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut futures = [
+ Box::pin(inc(mu.clone())),
+ Box::pin(inc(mu.clone())),
+ Box::pin(inc(mu.clone())),
+ ];
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let count = block_on(mu.lock());
+
+ // Poll 2 futures. Since neither will be able to acquire the lock, they should get added to
+ // the waiter list.
+ if let Poll::Ready(()) = futures[0].as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ if let Poll::Ready(()) = futures[1].as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ LOCKED | HAS_WAITERS | WRITER_WAITING,
+ );
+
+ // Now drop the lock. This should set the DESIGNATED_WAKER bit and wake up the first future
+ // in the wait list.
+ mem::drop(count);
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed),
+ DESIGNATED_WAKER | HAS_WAITERS | WRITER_WAITING,
+ );
+
+ // Now poll the third future. It should be able to acquire the lock immediately.
+ if futures[2].as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 1);
+
+ // There should still be a waiter in the wait list and the DESIGNATED_WAKER bit should still
+ // be set.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER,
+ DESIGNATED_WAKER
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ // Now let the future that was woken up run.
+ if futures[0].as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 2);
+
+ if futures[1].as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+ assert_eq!(*block_on(mu.lock()), 3);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_designated_waker() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let count = block_on(mu.lock());
+
+ if let Poll::Ready(()) = fut.as_mut().poll(&mut cx) {
+ panic!("Future unexpectedly ready when lock is held");
+ }
+
+ // Drop the lock. This will wake up the future.
+ mem::drop(count);
+
+ // Now drop the future without polling. This should clear all the state in the mutex.
+ mem::drop(fut);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_before_wake() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut1 = Box::pin(inc(mu.clone()));
+
+ let mut fut2 = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First acquire the lock.
+ let count = block_on(mu.lock());
+
+ // Now poll the futures. Since the lock is acquired they will both get queued in the waiter
+ // list.
+ match fut1.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop fut1. This should remove it from the waiter list but shouldn't wake fut2.
+ mem::drop(fut1);
+
+ // There should be no designated waker.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER, 0);
+
+ // Since the waiter was a writer, we should clear the WRITER_WAITING bit.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ // Now drop the lock. This should mark fut2 as ready to make progress.
+ mem::drop(count);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => panic!("Future is not ready to make progress"),
+ Poll::Ready(()) => {}
+ }
+
+ // Verify that we only incremented the count once.
+ assert_eq!(*block_on(mu.lock()), 1);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn cancel_after_wake() {
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut fut1 = Box::pin(inc(mu.clone()));
+
+ let mut fut2 = Box::pin(inc(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // First acquire the lock.
+ let count = block_on(mu.lock());
+
+ // Now poll the futures. Since the lock is acquired they will both get queued in the waiter
+ // list.
+ match fut1.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => {}
+ Poll::Ready(()) => panic!("Future is unexpectedly ready"),
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // Drop the lock. This should mark fut1 as ready to make progress.
+ mem::drop(count);
+
+ // Now drop fut1. This should make fut2 ready to make progress.
+ mem::drop(fut1);
+
+ // Since there was still another waiter in the list we shouldn't have cleared the
+ // DESIGNATED_WAKER bit.
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & DESIGNATED_WAKER,
+ DESIGNATED_WAKER
+ );
+
+ // Since the waiter was a writer, we should clear the WRITER_WAITING bit.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING, 0);
+
+ match fut2.as_mut().poll(&mut cx) {
+ Poll::Pending => panic!("Future is not ready to make progress"),
+ Poll::Ready(()) => {}
+ }
+
+ // Verify that we only incremented the count once.
+ assert_eq!(*block_on(mu.lock()), 1);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn timeout() {
+ async fn timed_lock(timer: oneshot::Receiver<()>, mu: Arc<Mutex<()>>) {
+ select! {
+ res = timer.fuse() => {
+ match res {
+ Ok(()) => {},
+ Err(e) => panic!("Timer unexpectedly canceled: {}", e),
+ }
+ }
+ _ = mu.lock().fuse() => panic!("Unexpectedly acquired lock"),
+ }
+ }
+
+ let mu = Arc::new(Mutex::new(()));
+ let (tx, rx) = oneshot::channel();
+
+ let mut timeout = Box::pin(timed_lock(rx, mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ // Acquire the lock.
+ let g = block_on(mu.lock());
+
+ // Poll the future.
+ if let Poll::Ready(()) = timeout.as_mut().poll(&mut cx) {
+ panic!("timed_lock unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+
+ // Signal the channel, which should cancel the lock.
+ tx.send(()).expect("Failed to send wakeup");
+
+ // Now the future should have completed without acquiring the lock.
+ if timeout.as_mut().poll(&mut cx).is_pending() {
+ panic!("timed_lock not ready after timeout");
+ }
+
+ // The mutex state should not show any waiters.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+
+ mem::drop(g);
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn writer_waiting() {
+ async fn read_zero(mu: Arc<Mutex<usize>>) {
+ let val = mu.read_lock().await;
+ pending!();
+
+ assert_eq!(*val, 0);
+ }
+
+ async fn inc(mu: Arc<Mutex<usize>>) {
+ *mu.lock().await += 1;
+ }
+
+ async fn read_one(mu: Arc<Mutex<usize>>) {
+ let val = mu.read_lock().await;
+
+ assert_eq!(*val, 1);
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+
+ let mut r1 = Box::pin(read_zero(mu.clone()));
+ let mut r2 = Box::pin(read_zero(mu.clone()));
+
+ let mut w = Box::pin(inc(mu.clone()));
+ let mut r3 = Box::pin(read_one(mu.clone()));
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ if let Poll::Ready(()) = r1.as_mut().poll(&mut cx) {
+ panic!("read_zero unexpectedly ready");
+ }
+ if let Poll::Ready(()) = r2.as_mut().poll(&mut cx) {
+ panic!("read_zero unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ 2 * READ_LOCK
+ );
+
+ if let Poll::Ready(()) = w.as_mut().poll(&mut cx) {
+ panic!("inc unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ // The WRITER_WAITING bit should prevent the next reader from acquiring the lock.
+ if let Poll::Ready(()) = r3.as_mut().poll(&mut cx) {
+ panic!("read_one unexpectedly ready");
+ }
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ 2 * READ_LOCK
+ );
+
+ if r1.as_mut().poll(&mut cx).is_pending() {
+ panic!("read_zero unable to complete");
+ }
+ if r2.as_mut().poll(&mut cx).is_pending() {
+ panic!("read_zero unable to complete");
+ }
+ if w.as_mut().poll(&mut cx).is_pending() {
+ panic!("inc unable to complete");
+ }
+ if r3.as_mut().poll(&mut cx).is_pending() {
+ panic!("read_one unable to complete");
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_one() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+ }
+
+ async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut readers = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+ let mut writer = Box::pin(write(mu.clone(), cv.clone()));
+
+ for r in &mut readers {
+ if let Poll::Ready(()) = r.as_mut().poll(&mut cx) {
+ panic!("reader unexpectedly ready");
+ }
+ }
+ if let Poll::Ready(()) = writer.as_mut().poll(&mut cx) {
+ panic!("writer unexpectedly ready");
+ }
+
+ let mut count = block_on(mu.lock());
+ *count = 1;
+
+ // This should wake all readers + one writer.
+ cv.notify_one();
+
+ // Poll the readers and the writer so they add themselves to the mutex's waiter list.
+ for r in &mut readers {
+ if r.as_mut().poll(&mut cx).is_ready() {
+ panic!("reader unexpectedly ready");
+ }
+ }
+
+ if writer.as_mut().poll(&mut cx).is_ready() {
+ panic!("writer unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS,
+ HAS_WAITERS
+ );
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & WRITER_WAITING,
+ WRITER_WAITING
+ );
+
+ mem::drop(count);
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING),
+ HAS_WAITERS | WRITER_WAITING
+ );
+
+ for r in &mut readers {
+ if r.as_mut().poll(&mut cx).is_pending() {
+ panic!("reader unable to complete");
+ }
+ }
+
+ if writer.as_mut().poll(&mut cx).is_pending() {
+ panic!("writer unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.read_lock()), 0);
+ }
+
+ #[test]
+ fn notify_when_unlocked() {
+ async fn dec(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures = [
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ Box::pin(dec(mu.clone(), cv.clone())),
+ ];
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ *block_on(mu.lock()) = futures.len();
+ cv.notify_all();
+
+ // Since we haven't polled `futures` again after waking them, the mutex should not have any
+ // waiters.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+
+ for f in &mut futures {
+ if f.as_mut().poll(&mut cx).is_pending() {
+ panic!("future unexpectedly ready");
+ }
+ }
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_reader_writer() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+
+ // Yield once while holding the read lock, which should prevent the writer from waking
+ // up.
+ pending!();
+ }
+
+ async fn write(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.lock().await;
+ while *count == 0 {
+ count = cv.wait(count).await;
+ }
+
+ *count -= 1;
+ }
+
+ async fn lock(mu: Arc<Mutex<usize>>) {
+ mem::drop(mu.lock().await);
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures: [Pin<Box<dyn Future<Output = ()>>>; 5] = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(write(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+ const NUM_READERS: usize = 4;
+
+ let mut l = Box::pin(lock(mu.clone()));
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+
+ let mut count = block_on(mu.lock());
+ *count = 1;
+
+ // Now poll the lock function. Since the lock is held by us, it will get queued on the
+ // waiter list.
+ if let Poll::Ready(()) = l.as_mut().poll(&mut cx) {
+ panic!("lock() unexpectedly ready");
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & (HAS_WAITERS | WRITER_WAITING),
+ HAS_WAITERS | WRITER_WAITING
+ );
+
+ // Wake up waiters while holding the lock.
+ cv.notify_all();
+
+ // Drop the lock. This should wake up the lock function.
+ mem::drop(count);
+
+ if l.as_mut().poll(&mut cx).is_pending() {
+ panic!("lock() unable to complete");
+ }
+
+ // Since we haven't polled `futures` again after the notify, the mutex state should now be
+ // empty.
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+
+ // Poll everything again. The readers should be able to make progress (but not complete) but
+ // the writer should be blocked.
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ READ_LOCK * NUM_READERS
+ );
+
+ // All the readers can now finish but the writer needs to be polled again.
+ let mut needs_poll = None;
+ for (i, r) in futures.iter_mut().enumerate() {
+ match r.as_mut().poll(&mut cx) {
+ Poll::Ready(()) => {}
+ Poll::Pending => {
+ if needs_poll.is_some() {
+ panic!("More than one future unable to complete");
+ }
+ needs_poll = Some(i);
+ }
+ }
+ }
+
+ if futures[needs_poll.expect("Writer unexpectedly able to complete")]
+ .as_mut()
+ .poll(&mut cx)
+ .is_pending()
+ {
+ panic!("Writer unable to complete");
+ }
+
+ assert_eq!(*block_on(mu.lock()), 0);
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+
+ #[test]
+ fn notify_readers_with_read_lock() {
+ async fn read(mu: Arc<Mutex<usize>>, cv: Arc<Condvar>) {
+ let mut count = mu.read_lock().await;
+ while *count == 0 {
+ count = cv.wait_read(count).await;
+ }
+
+ // Yield once while holding the read lock.
+ pending!();
+ }
+
+ let mu = Arc::new(Mutex::new(0));
+ let cv = Arc::new(Condvar::new());
+
+ let arc_waker = Arc::new(TestWaker);
+ let waker = waker_ref(&arc_waker);
+ let mut cx = Context::from_waker(&waker);
+
+ let mut futures = [
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ Box::pin(read(mu.clone(), cv.clone())),
+ ];
+
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+
+ // Increment the count and then grab a read lock.
+ *block_on(mu.lock()) = 1;
+
+ let g = block_on(mu.read_lock());
+
+ // Notify the condvar while holding the read lock. This should wake up all the waiters.
+ cv.notify_all();
+
+ // Since the lock is held in shared mode, all the readers should immediately be able to
+ // acquire the read lock.
+ for f in &mut futures {
+ if let Poll::Ready(()) = f.as_mut().poll(&mut cx) {
+ panic!("future unexpectedly ready");
+ }
+ }
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed) & HAS_WAITERS, 0);
+ assert_eq!(
+ mu.raw.state.load(Ordering::Relaxed) & READ_MASK,
+ READ_LOCK * (futures.len() + 1)
+ );
+
+ mem::drop(g);
+
+ for f in &mut futures {
+ if f.as_mut().poll(&mut cx).is_pending() {
+ panic!("future unable to complete");
+ }
+ }
+
+ assert_eq!(mu.raw.state.load(Ordering::Relaxed), 0);
+ }
+}
diff --git a/cros_async/src/sync/spin.rs b/cros_async/src/sync/spin.rs
new file mode 100644
index 000000000..8b7b81193
--- /dev/null
+++ b/cros_async/src/sync/spin.rs
@@ -0,0 +1,277 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::ops::{Deref, DerefMut};
+use std::sync::atomic::{spin_loop_hint, AtomicBool, Ordering};
+
+const UNLOCKED: bool = false;
+const LOCKED: bool = true;
+
+/// A primitive that provides safe, mutable access to a shared resource.
+///
+/// Unlike `Mutex`, a `SpinLock` will not voluntarily yield its CPU time until the resource is
+/// available and will instead keep spinning until the resource is acquired. For the vast majority
+/// of cases, `Mutex` is a better choice than `SpinLock`. If a `SpinLock` must be used then users
+/// should try to do as little work as possible while holding the `SpinLock` and avoid any sort of
+/// blocking at all costs as it can severely penalize performance.
+///
+/// # Poisoning
+///
+/// This `SpinLock` does not implement lock poisoning so it is possible for threads to access
+/// poisoned data if a thread panics while holding the lock. If lock poisoning is needed, it can be
+/// implemented by wrapping the `SpinLock` in a new type that implements poisoning. See the
+/// implementation of `std::sync::Mutex` for an example of how to do this.
+#[repr(align(128))]
+pub struct SpinLock<T: ?Sized> {
+ lock: AtomicBool,
+ value: UnsafeCell<T>,
+}
+
+impl<T> SpinLock<T> {
+ /// Creates a new, unlocked `SpinLock` that's ready for use.
+ pub fn new(value: T) -> SpinLock<T> {
+ SpinLock {
+ lock: AtomicBool::new(UNLOCKED),
+ value: UnsafeCell::new(value),
+ }
+ }
+
+ /// Consumes the `SpinLock` and returns the value guarded by it. This method doesn't perform any
+ /// locking as the compiler guarantees that there are no references to `self`.
+ pub fn into_inner(self) -> T {
+ // No need to take the lock because the compiler can statically guarantee
+ // that there are no references to the SpinLock.
+ self.value.into_inner()
+ }
+}
+
+impl<T: ?Sized> SpinLock<T> {
+ /// Acquires exclusive, mutable access to the resource protected by the `SpinLock`, blocking the
+ /// current thread until it is able to do so. Upon returning, the current thread will be the
+ /// only thread with access to the resource. The `SpinLock` will be released when the returned
+ /// `SpinLockGuard` is dropped. Attempting to call `lock` while already holding the `SpinLock`
+ /// will cause a deadlock.
+ pub fn lock(&self) -> SpinLockGuard<T> {
+ loop {
+ let state = self.lock.load(Ordering::Relaxed);
+ if state == UNLOCKED
+ && self
+ .lock
+ .compare_exchange_weak(UNLOCKED, LOCKED, Ordering::Acquire, Ordering::Relaxed)
+ .is_ok()
+ {
+ break;
+ }
+ spin_loop_hint();
+ }
+
+ SpinLockGuard {
+ lock: self,
+ value: unsafe { &mut *self.value.get() },
+ }
+ }
+
+ fn unlock(&self) {
+ // Don't need to compare and swap because we exclusively hold the lock.
+ self.lock.store(UNLOCKED, Ordering::Release);
+ }
+
+ /// Returns a mutable reference to the contained value. This method doesn't perform any locking
+ /// as the compiler will statically guarantee that there are no other references to `self`.
+ pub fn get_mut(&mut self) -> &mut T {
+ // Safe because the compiler can statically guarantee that there are no other references to
+ // `self`. This is also why we don't need to acquire the lock.
+ unsafe { &mut *self.value.get() }
+ }
+}
+
+unsafe impl<T: ?Sized + Send> Send for SpinLock<T> {}
+unsafe impl<T: ?Sized + Send> Sync for SpinLock<T> {}
+
+impl<T: ?Sized + Default> Default for SpinLock<T> {
+ fn default() -> Self {
+ Self::new(Default::default())
+ }
+}
+
+impl<T> From<T> for SpinLock<T> {
+ fn from(source: T) -> Self {
+ Self::new(source)
+ }
+}
+
+/// An RAII implementation of a "scoped lock" for a `SpinLock`. When this structure is dropped, the
+/// lock will be released. The resource protected by the `SpinLock` can be accessed via the `Deref`
+/// and `DerefMut` implementations of this structure.
+pub struct SpinLockGuard<'a, T: 'a + ?Sized> {
+ lock: &'a SpinLock<T>,
+ value: &'a mut T,
+}
+
+impl<'a, T: ?Sized> Deref for SpinLockGuard<'a, T> {
+ type Target = T;
+ fn deref(&self) -> &T {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> DerefMut for SpinLockGuard<'a, T> {
+ fn deref_mut(&mut self) -> &mut T {
+ self.value
+ }
+}
+
+impl<'a, T: ?Sized> Drop for SpinLockGuard<'a, T> {
+ fn drop(&mut self) {
+ self.lock.unlock();
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ use std::mem;
+ use std::sync::atomic::{AtomicUsize, Ordering};
+ use std::sync::Arc;
+ use std::thread;
+
+ #[derive(PartialEq, Eq, Debug)]
+ struct NonCopy(u32);
+
+ #[test]
+ fn it_works() {
+ let sl = SpinLock::new(NonCopy(13));
+
+ assert_eq!(*sl.lock(), NonCopy(13));
+ }
+
+ #[test]
+ fn smoke() {
+ let sl = SpinLock::new(NonCopy(7));
+
+ mem::drop(sl.lock());
+ mem::drop(sl.lock());
+ }
+
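+ // Illustrative addition (not part of the original patch): dropping the returned
+ // SpinLockGuard releases the lock, so a subsequent lock() call succeeds instead
+ // of spinning forever.
+ #[test]
+ fn guard_drop_releases() {
+ let sl = SpinLock::new(0u32);
+ {
+ let mut guard = sl.lock();
+ *guard += 1;
+ } // `guard` dropped here; SpinLockGuard::drop calls unlock().
+ assert_eq!(*sl.lock(), 1);
+ }
+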
+ #[test]
+ fn send() {
+ let sl = SpinLock::new(NonCopy(19));
+
+ thread::spawn(move || {
+ let value = sl.lock();
+ assert_eq!(*value, NonCopy(19));
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn high_contention() {
+ const THREADS: usize = 23;
+ const ITERATIONS: usize = 101;
+
+ let mut threads = Vec::with_capacity(THREADS);
+
+ let sl = Arc::new(SpinLock::new(0usize));
+ for _ in 0..THREADS {
+ let sl2 = sl.clone();
+ threads.push(thread::spawn(move || {
+ for _ in 0..ITERATIONS {
+ *sl2.lock() += 1;
+ }
+ }));
+ }
+
+ for t in threads.into_iter() {
+ t.join().unwrap();
+ }
+
+ assert_eq!(*sl.lock(), THREADS * ITERATIONS);
+ }
+
+ #[test]
+ fn get_mut() {
+ let mut sl = SpinLock::new(NonCopy(13));
+ *sl.get_mut() = NonCopy(17);
+
+ assert_eq!(sl.into_inner(), NonCopy(17));
+ }
+
+ #[test]
+ fn into_inner() {
+ let sl = SpinLock::new(NonCopy(29));
+ assert_eq!(sl.into_inner(), NonCopy(29));
+ }
+
+ #[test]
+ fn into_inner_drop() {
+ struct NeedsDrop(Arc<AtomicUsize>);
+ impl Drop for NeedsDrop {
+ fn drop(&mut self) {
+ self.0.fetch_add(1, Ordering::AcqRel);
+ }
+ }
+
+ let value = Arc::new(AtomicUsize::new(0));
+ let needs_drop = SpinLock::new(NeedsDrop(value.clone()));
+ assert_eq!(value.load(Ordering::Acquire), 0);
+
+ {
+ let inner = needs_drop.into_inner();
+ assert_eq!(inner.0.load(Ordering::Acquire), 0);
+ }
+
+ assert_eq!(value.load(Ordering::Acquire), 1);
+ }
+
+ #[test]
+ fn arc_nested() {
+ // Tests nested spinlocks and access to underlying data.
+ let sl = SpinLock::new(1);
+ let arc = Arc::new(SpinLock::new(sl));
+ thread::spawn(move || {
+ let nested = arc.lock();
+ let lock2 = nested.lock();
+ assert_eq!(*lock2, 1);
+ })
+ .join()
+ .unwrap();
+ }
+
+ #[test]
+ fn arc_access_in_unwind() {
+ let arc = Arc::new(SpinLock::new(1));
+ let arc2 = arc.clone();
+ thread::spawn(move || {
+ struct Unwinder {
+ i: Arc<SpinLock<i32>>,
+ }
+ impl Drop for Unwinder {
+ fn drop(&mut self) {
+ *self.i.lock() += 1;
+ }
+ }
+ let _u = Unwinder { i: arc2 };
+ panic!();
+ })
+ .join()
+ .expect_err("thread did not panic");
+ let lock = arc.lock();
+ assert_eq!(*lock, 2);
+ }
+
+ #[test]
+ fn unsized_value() {
+ let sl: &SpinLock<[i32]> = &SpinLock::new([1, 2, 3]);
+ {
+ let b = &mut *sl.lock();
+ b[0] = 4;
+ b[2] = 5;
+ }
+ let expected: &[i32] = &[4, 2, 5];
+ assert_eq!(&*sl.lock(), expected);
+ }
+}
diff --git a/cros_async/src/sync/waiter.rs b/cros_async/src/sync/waiter.rs
new file mode 100644
index 000000000..072a0f506
--- /dev/null
+++ b/cros_async/src/sync/waiter.rs
@@ -0,0 +1,281 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::UnsafeCell;
+use std::future::Future;
+use std::mem;
+use std::pin::Pin;
+use std::ptr::NonNull;
+use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
+use std::sync::Arc;
+use std::task::{Context, Poll, Waker};
+
+use intrusive_collections::linked_list::{LinkedList, LinkedListOps};
+use intrusive_collections::{intrusive_adapter, DefaultLinkOps, LinkOps};
+
+use crate::sync::SpinLock;
+
+// An atomic version of a LinkedListLink. See https://github.com/Amanieu/intrusive-rs/issues/47 for
+// more details.
+#[repr(align(128))]
+pub struct AtomicLink {
+ prev: UnsafeCell<Option<NonNull<AtomicLink>>>,
+ next: UnsafeCell<Option<NonNull<AtomicLink>>>,
+ linked: AtomicBool,
+}
+
+impl AtomicLink {
+ fn new() -> AtomicLink {
+ AtomicLink {
+ linked: AtomicBool::new(false),
+ prev: UnsafeCell::new(None),
+ next: UnsafeCell::new(None),
+ }
+ }
+
+ fn is_linked(&self) -> bool {
+ self.linked.load(Ordering::Relaxed)
+ }
+}
+
+impl DefaultLinkOps for AtomicLink {
+ type Ops = AtomicLinkOps;
+
+ const NEW: Self::Ops = AtomicLinkOps;
+}
+
+// Safe because the only way to mutate `AtomicLink` is via the `LinkedListOps` trait whose methods
+// are all unsafe and require that the caller has first called `acquire_link` (and had it return
+// true) to use them safely.
+unsafe impl Send for AtomicLink {}
+unsafe impl Sync for AtomicLink {}
+
+#[derive(Copy, Clone, Default)]
+pub struct AtomicLinkOps;
+
+unsafe impl LinkOps for AtomicLinkOps {
+ type LinkPtr = NonNull<AtomicLink>;
+
+ unsafe fn acquire_link(&mut self, ptr: Self::LinkPtr) -> bool {
+ !ptr.as_ref().linked.swap(true, Ordering::Acquire)
+ }
+
+ unsafe fn release_link(&mut self, ptr: Self::LinkPtr) {
+ ptr.as_ref().linked.store(false, Ordering::Release)
+ }
+}
+
+unsafe impl LinkedListOps for AtomicLinkOps {
+ unsafe fn next(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
+ *ptr.as_ref().next.get()
+ }
+
+ unsafe fn prev(&self, ptr: Self::LinkPtr) -> Option<Self::LinkPtr> {
+ *ptr.as_ref().prev.get()
+ }
+
+ unsafe fn set_next(&mut self, ptr: Self::LinkPtr, next: Option<Self::LinkPtr>) {
+ *ptr.as_ref().next.get() = next;
+ }
+
+ unsafe fn set_prev(&mut self, ptr: Self::LinkPtr, prev: Option<Self::LinkPtr>) {
+ *ptr.as_ref().prev.get() = prev;
+ }
+}
+
+#[derive(Clone, Copy)]
+pub enum Kind {
+ Shared,
+ Exclusive,
+}
+
+enum State {
+ Init,
+ Waiting(Waker),
+ Woken,
+ Finished,
+ Processing,
+}
+
+// Indicates the queue to which the waiter belongs. It is the responsibility of the Mutex and
+// Condvar implementations to update this value when adding/removing a Waiter from their respective
+// waiter lists.
+#[repr(u8)]
+#[derive(Debug, Eq, PartialEq)]
+pub enum WaitingFor {
+ // The waiter is either not linked into a waiter list or it is linked into a temporary list.
+ None = 0,
+ // The waiter is linked into the Mutex's waiter list.
+ Mutex = 1,
+ // The waiter is linked into the Condvar's waiter list.
+ Condvar = 2,
+}
+
+// Represents a thread currently blocked on a Condvar or on acquiring a Mutex.
+pub struct Waiter {
+ link: AtomicLink,
+ state: SpinLock<State>,
+ cancel: fn(usize, &Waiter, bool),
+ cancel_data: usize,
+ kind: Kind,
+ waiting_for: AtomicU8,
+}
+
+impl Waiter {
+ // Create a new, initialized Waiter.
+ //
+ // `kind` should indicate whether this waiter represents a thread that is waiting for a shared
+ // lock or an exclusive lock.
+ //
+ // `cancel` is the function that is called when a `WaitFuture` (returned by the `wait()`
+ // function) is dropped before it can complete. `cancel_data` is used as the first parameter of
+ // the `cancel` function. The second parameter is the `Waiter` that was canceled and the third
+ // parameter indicates whether the `WaitFuture` was dropped after it was woken (but before it
+ // was polled to completion). A value of `false` for the third parameter may already be stale
+ // by the time the cancel function runs and so does not guarantee that the waiter was not woken.
+ // In this case, implementations should still check if the Waiter was woken. However, a value of
+ // `true` guarantees that the waiter was already woken up so no additional checks are necessary.
+ // In this case, the cancel implementation should wake up the next waiter in its wait list, if
+ // any.
+ //
+ // `waiting_for` indicates the waiter list to which this `Waiter` will be added. See the
+ // documentation of the `WaitingFor` enum for the meaning of the different values.
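+ //
+ // A minimal, hypothetical sketch of a caller (the real callers are the Mutex and
+ // Condvar implementations in this crate), shown only to illustrate the shape of the
+ // `cancel` callback:
+ //
+ // fn noop_cancel(_data: usize, _waiter: &Waiter, _woken: bool) {}
+ // let waiter = Arc::new(Waiter::new(Kind::Exclusive, noop_cancel, 0, WaitingFor::Mutex));
+ // // ...later, after removing `waiter` from its list: waiter.wake();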
+ pub fn new(
+ kind: Kind,
+ cancel: fn(usize, &Waiter, bool),
+ cancel_data: usize,
+ waiting_for: WaitingFor,
+ ) -> Waiter {
+ Waiter {
+ link: AtomicLink::new(),
+ state: SpinLock::new(State::Init),
+ cancel,
+ cancel_data,
+ kind,
+ waiting_for: AtomicU8::new(waiting_for as u8),
+ }
+ }
+
+ // The kind of lock that this `Waiter` is waiting to acquire.
+ pub fn kind(&self) -> Kind {
+ self.kind
+ }
+
+ // Returns true if this `Waiter` is currently linked into a waiter list.
+ pub fn is_linked(&self) -> bool {
+ self.link.is_linked()
+ }
+
+ // Indicates the waiter list to which this `Waiter` belongs.
+ pub fn is_waiting_for(&self) -> WaitingFor {
+ match self.waiting_for.load(Ordering::Acquire) {
+ 0 => WaitingFor::None,
+ 1 => WaitingFor::Mutex,
+ 2 => WaitingFor::Condvar,
+ v => panic!("Unknown value for `WaitingFor`: {}", v),
+ }
+ }
+
+ // Change the waiter list to which this `Waiter` belongs. Callers must not call this while the
+ // `Waiter` is still linked into a waiter list.
+ pub fn set_waiting_for(&self, waiting_for: WaitingFor) {
+ self.waiting_for.store(waiting_for as u8, Ordering::Release);
+ }
+
+ // Reset the Waiter back to its initial state. Panics if this `Waiter` is still linked into a
+ // waiter list.
+ pub fn reset(&self, waiting_for: WaitingFor) {
+ debug_assert!(!self.is_linked(), "Cannot reset `Waiter` while linked");
+ self.set_waiting_for(waiting_for);
+
+ let mut state = self.state.lock();
+ if let State::Waiting(waker) = mem::replace(&mut *state, State::Init) {
+ mem::drop(state);
+ mem::drop(waker);
+ }
+ }
+
+ // Wait until woken up by another thread.
+ pub fn wait(&self) -> WaitFuture<'_> {
+ WaitFuture { waiter: self }
+ }
+
+ // Wake up the thread associated with this `Waiter`. Panics if `is_waiting_for()` does not return
+ // `WaitingFor::None` or if `is_linked()` returns true.
+ pub fn wake(&self) {
+ debug_assert!(!self.is_linked(), "Cannot wake `Waiter` while linked");
+ debug_assert_eq!(self.is_waiting_for(), WaitingFor::None);
+
+ let mut state = self.state.lock();
+
+ if let State::Waiting(waker) = mem::replace(&mut *state, State::Woken) {
+ mem::drop(state);
+ waker.wake();
+ }
+ }
+}
+
+pub struct WaitFuture<'w> {
+ waiter: &'w Waiter,
+}
+
+impl<'w> Future for WaitFuture<'w> {
+ type Output = ();
+
+ fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
+ let mut state = self.waiter.state.lock();
+
+ match mem::replace(&mut *state, State::Processing) {
+ State::Init => {
+ *state = State::Waiting(cx.waker().clone());
+
+ Poll::Pending
+ }
+ State::Waiting(old_waker) => {
+ *state = State::Waiting(cx.waker().clone());
+ mem::drop(state);
+ mem::drop(old_waker);
+
+ Poll::Pending
+ }
+ State::Woken => {
+ *state = State::Finished;
+ Poll::Ready(())
+ }
+ State::Finished => {
+ panic!("Future polled after returning Poll::Ready");
+ }
+ State::Processing => {
+ panic!("Unexpected waker state");
+ }
+ }
+ }
+}
+
+impl<'w> Drop for WaitFuture<'w> {
+ fn drop(&mut self) {
+ let state = self.waiter.state.lock();
+
+ match *state {
+ State::Finished => {}
+ State::Processing => panic!("Unexpected waker state"),
+ State::Woken => {
+ mem::drop(state);
+
+ // We were woken but not polled. Wake up the next waiter.
+ (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, true);
+ }
+ _ => {
+ mem::drop(state);
+
+ // Not woken. No need to wake up any waiters.
+ (self.waiter.cancel)(self.waiter.cancel_data, self.waiter, false);
+ }
+ }
+ }
+}
+
+intrusive_adapter!(pub WaiterAdapter = Arc<Waiter>: Waiter { link: AtomicLink });
+
+pub type WaiterList = LinkedList<WaiterAdapter>;
diff --git a/cros_async/src/uring_executor.rs b/cros_async/src/uring_executor.rs
index c6af38c27..39bce9581 100644
--- a/cros_async/src/uring_executor.rs
+++ b/cros_async/src/uring_executor.rs
@@ -307,11 +307,12 @@ impl RawExecutor {
let raw = Arc::downgrade(self);
let schedule = move |runnable| {
if let Some(r) = raw.upgrade() {
- r.queue.schedule(runnable);
+ r.queue.push_back(runnable);
+ r.wake();
}
};
let (runnable, task) = async_task::spawn(f, schedule);
- self.queue.schedule(runnable);
+ runnable.schedule();
task
}
@@ -323,11 +324,12 @@ impl RawExecutor {
let raw = Arc::downgrade(self);
let schedule = move |runnable| {
if let Some(r) = raw.upgrade() {
- r.queue.schedule(runnable);
+ r.queue.push_back(runnable);
+ r.wake();
}
};
let (runnable, task) = async_task::spawn_local(f, schedule);
- self.queue.schedule(runnable);
+ runnable.schedule();
task
}
@@ -351,7 +353,6 @@ impl RawExecutor {
pin_mut!(done);
loop {
self.state.store(PROCESSING, Ordering::Release);
- self.queue.set_waker(cx.waker().clone());
for runnable in self.queue.iter() {
runnable.run();
}
@@ -360,10 +361,13 @@ impl RawExecutor {
return Ok(val);
}
- let oldstate = self
- .state
- .compare_and_swap(PROCESSING, WAITING, Ordering::Acquire);
- if oldstate != PROCESSING {
+ let oldstate = self.state.compare_exchange(
+ PROCESSING,
+ WAITING,
+ Ordering::Acquire,
+ Ordering::Acquire,
+ );
+ if let Err(oldstate) = oldstate {
debug_assert_eq!(oldstate, WOKEN);
// One or more futures have become runnable.
continue;
diff --git a/crosvm_plugin/src/lib.rs b/crosvm_plugin/src/lib.rs
index 2f8f975f5..86827a681 100644
--- a/crosvm_plugin/src/lib.rs
+++ b/crosvm_plugin/src/lib.rs
@@ -152,8 +152,12 @@ impl IdAllocator {
}
fn free(&self, id: u32) {
- self.0
- .compare_and_swap(id as usize + 1, id as usize, Ordering::Relaxed);
+ let _ = self.0.compare_exchange(
+ id as usize + 1,
+ id as usize,
+ Ordering::Relaxed,
+ Ordering::Relaxed,
+ );
}
}
diff --git a/data_model/src/endian.rs b/data_model/src/endian.rs
index 686b8a182..6d3645cce 100644
--- a/data_model/src/endian.rs
+++ b/data_model/src/endian.rs
@@ -39,7 +39,7 @@ use crate::DataInit;
macro_rules! endian_type {
($old_type:ident, $new_type:ident, $to_new:ident, $from_new:ident) => {
- /// An unsigned integer type of with an explicit endianness.
+ /// An integer type with an explicit endianness.
///
/// See module level documentation for examples.
#[derive(Copy, Clone, Eq, PartialEq, Debug, Default)]
@@ -71,9 +71,9 @@ macro_rules! endian_type {
}
}
- impl Into<$old_type> for $new_type {
- fn into(self) -> $old_type {
- $old_type::$from_new(self.0)
+ impl From<$new_type> for $old_type {
+ fn from(v: $new_type) -> $old_type {
+ $old_type::$from_new(v.0)
}
}
@@ -104,13 +104,21 @@ macro_rules! endian_type {
}
endian_type!(u16, Le16, to_le, from_le);
+endian_type!(i16, SLe16, to_le, from_le);
endian_type!(u32, Le32, to_le, from_le);
+endian_type!(i32, SLe32, to_le, from_le);
endian_type!(u64, Le64, to_le, from_le);
+endian_type!(i64, SLe64, to_le, from_le);
endian_type!(usize, LeSize, to_le, from_le);
+endian_type!(isize, SLeSize, to_le, from_le);
endian_type!(u16, Be16, to_be, from_be);
+endian_type!(i16, SBe16, to_be, from_be);
endian_type!(u32, Be32, to_be, from_be);
+endian_type!(i32, SBe32, to_be, from_be);
endian_type!(u64, Be64, to_be, from_be);
+endian_type!(i64, SBe64, to_be, from_be);
endian_type!(usize, BeSize, to_be, from_be);
+endian_type!(isize, SBeSize, to_be, from_be);
#[cfg(test)]
mod tests {
@@ -153,11 +161,19 @@ mod tests {
}
endian_test!(u16, Le16, test_le16, NATIVE_LITTLE);
+ endian_test!(i16, SLe16, test_sle16, NATIVE_LITTLE);
endian_test!(u32, Le32, test_le32, NATIVE_LITTLE);
+ endian_test!(i32, SLe32, test_sle32, NATIVE_LITTLE);
endian_test!(u64, Le64, test_le64, NATIVE_LITTLE);
+ endian_test!(i64, SLe64, test_sle64, NATIVE_LITTLE);
endian_test!(usize, LeSize, test_le_size, NATIVE_LITTLE);
+ endian_test!(isize, SLeSize, test_sle_size, NATIVE_LITTLE);
endian_test!(u16, Be16, test_be16, NATIVE_BIG);
+ endian_test!(i16, SBe16, test_sbe16, NATIVE_BIG);
endian_test!(u32, Be32, test_be32, NATIVE_BIG);
+ endian_test!(i32, SBe32, test_sbe32, NATIVE_BIG);
endian_test!(u64, Be64, test_be64, NATIVE_BIG);
+ endian_test!(i64, SBe64, test_sbe64, NATIVE_BIG);
endian_test!(usize, BeSize, test_be_size, NATIVE_BIG);
+ endian_test!(isize, SBeSize, test_sbe_size, NATIVE_BIG);
}
diff --git a/devices/Cargo.toml b/devices/Cargo.toml
index 3438f0bdb..f77035c57 100644
--- a/devices/Cargo.toml
+++ b/devices/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2018"
[features]
audio = []
+direct = []
gpu = ["gpu_display","rutabaga_gfx"]
tpm = ["protos/trunks", "tpm2"]
video-decoder = []
@@ -18,6 +19,7 @@ gfxstream = ["gpu", "rutabaga_gfx/gfxstream"]
[dependencies]
acpi_tables = {path = "../acpi_tables" }
audio_streams = { path = "../../adhd/audio_streams" } # ignored by ebuild
+base = { path = "../base" }
bit_field = { path = "../bit_field" }
cros_async = { path = "../cros_async" }
data_model = { path = "../data_model" }
@@ -33,8 +35,6 @@ libchromeos = { path = "../../libchromeos-rs" } # ignored by ebuild
libcras = { path = "../../adhd/cras/client/libcras" } # ignored by ebuild
linux_input_sys = { path = "../linux_input_sys" }
minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild
-msg_on_socket_derive = { path = "../msg_socket/msg_on_socket_derive" }
-msg_socket = { path = "../msg_socket" }
net_sys = { path = "../net_sys" }
net_util = { path = "../net_util" }
p9 = { path = "../../vm_tools/p9" }
@@ -43,21 +43,23 @@ protos = { path = "../protos", optional = true }
rand_ish = { path = "../rand_ish" }
remain = "*"
resources = { path = "../resources" }
+serde = { version = "1", features = [ "derive" ] }
+smallvec = "1.6.1"
sync = { path = "../sync" }
sys_util = { path = "../sys_util" }
-base = { path = "../base" }
-syscall_defines = { path = "../syscall_defines" }
thiserror = "1.0.20"
tpm2 = { path = "../tpm2", optional = true }
usb_util = { path = "../usb_util" }
vfio_sys = { path = "../vfio_sys" }
vhost = { path = "../vhost" }
+vmm_vhost = { version = "*", features = ["vhost-user-master"] }
virtio_sys = { path = "../virtio_sys" }
vm_control = { path = "../vm_control" }
vm_memory = { path = "../vm_memory" }
[dependencies.futures]
version = "*"
+features = ["std"]
default-features = false
[dev-dependencies]
diff --git a/devices/src/bat.rs b/devices/src/bat.rs
index 68c125a94..95f3cd70a 100644
--- a/devices/src/bat.rs
+++ b/devices/src/bat.rs
@@ -5,15 +5,14 @@
use crate::{BusAccessInfo, BusDevice};
use acpi_tables::{aml, aml::Aml};
use base::{
- error, warn, AsRawDescriptor, Descriptor, Event, PollToken, RawDescriptor, WaitContext,
+ error, warn, AsRawDescriptor, Descriptor, Event, PollToken, RawDescriptor, Tube, WaitContext,
};
-use msg_socket::{MsgReceiver, MsgSender};
use power_monitor::{BatteryStatus, CreatePowerMonitorFn};
use std::fmt::{self, Display};
use std::sync::Arc;
use std::thread;
use sync::Mutex;
-use vm_control::{BatControlCommand, BatControlResponseSocket, BatControlResult};
+use vm_control::{BatControlCommand, BatControlResult};
/// Errors for battery devices.
#[derive(Debug)]
@@ -106,7 +105,7 @@ pub struct GoldfishBattery {
activated: bool,
monitor_thread: Option<thread::JoinHandle<()>>,
kill_evt: Option<Event>,
- socket: Option<BatControlResponseSocket>,
+ tube: Option<Tube>,
create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>,
}
@@ -143,7 +142,7 @@ const BATTERY_STATUS_VAL_NOT_CHARGING: u32 = 3;
const BATTERY_HEALTH_VAL_UNKNOWN: u32 = 0;
fn command_monitor(
- socket: BatControlResponseSocket,
+ tube: Tube,
irq_evt: Event,
irq_resample_evt: Event,
kill_evt: Event,
@@ -151,7 +150,7 @@ fn command_monitor(
create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>,
) {
let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[
- (&Descriptor(socket.as_raw_descriptor()), Token::Commands),
+ (&Descriptor(tube.as_raw_descriptor()), Token::Commands),
(
&Descriptor(irq_resample_evt.as_raw_descriptor()),
Token::Resample,
@@ -202,7 +201,7 @@ fn command_monitor(
for event in events.iter().filter(|e| e.is_readable) {
match event.token {
Token::Commands => {
- let req = match socket.recv() {
+ let req = match tube.recv() {
Ok(req) => req,
Err(e) => {
error!("failed to receive request: {}", e);
@@ -232,7 +231,7 @@ fn command_monitor(
let _ = irq_evt.write(1);
}
- if let Err(e) = socket.send(&BatControlResult::Ok) {
+ if let Err(e) = tube.send(&BatControlResult::Ok) {
error!("failed to send response: {}", e);
}
}
@@ -309,7 +308,7 @@ impl GoldfishBattery {
irq_num: u32,
irq_evt: Event,
irq_resample_evt: Event,
- socket: BatControlResponseSocket,
+ tube: Tube,
create_power_monitor: Option<Box<dyn CreatePowerMonitorFn>>,
) -> Result<Self> {
if mmio_base + GOLDFISHBAT_MMIO_LEN - 1 > u32::MAX as u64 {
@@ -338,7 +337,7 @@ impl GoldfishBattery {
activated: false,
monitor_thread: None,
kill_evt: None,
- socket: Some(socket),
+ tube: Some(tube),
create_power_monitor,
})
}
@@ -350,8 +349,8 @@ impl GoldfishBattery {
self.irq_resample_evt.as_raw_descriptor(),
];
- if let Some(socket) = &self.socket {
- rds.push(socket.as_raw_descriptor());
+ if let Some(tube) = &self.tube {
+ rds.push(tube.as_raw_descriptor());
}
rds
@@ -375,7 +374,7 @@ impl GoldfishBattery {
}
};
- if let Some(socket) = self.socket.take() {
+ if let Some(tube) = self.tube.take() {
let irq_evt = self.irq_evt.try_clone().unwrap();
let irq_resample_evt = self.irq_resample_evt.try_clone().unwrap();
let bat_state = self.state.clone();
@@ -385,7 +384,7 @@ impl GoldfishBattery {
.name(self.debug_label())
.spawn(move || {
command_monitor(
- socket,
+ tube,
irq_evt,
irq_resample_evt,
kill_evt,
diff --git a/devices/src/bus.rs b/devices/src/bus.rs
index 73093665b..90360cd9d 100644
--- a/devices/src/bus.rs
+++ b/devices/src/bus.rs
@@ -10,12 +10,11 @@ use std::fmt::{self, Display};
use std::result;
use std::sync::Arc;
-use base::RawDescriptor;
-use msg_socket::MsgOnSocket;
+use serde::{Deserialize, Serialize};
use sync::Mutex;
/// Information about how a device was accessed.
-#[derive(Copy, Clone, Eq, PartialEq, Debug, MsgOnSocket)]
+#[derive(Copy, Clone, Eq, PartialEq, Debug, Serialize, Deserialize)]
pub struct BusAccessInfo {
/// Offset from base address that the device was accessed at.
pub offset: u64,
diff --git a/devices/src/direct_io.rs b/devices/src/direct_io.rs
new file mode 100644
index 000000000..2f80d7eee
--- /dev/null
+++ b/devices/src/direct_io.rs
@@ -0,0 +1,68 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use crate::{BusAccessInfo, BusDevice, BusDeviceSync};
+use std::fs::{File, OpenOptions};
+use std::io;
+use std::os::unix::prelude::FileExt;
+use std::path::Path;
+use std::sync::Mutex;
+
+pub struct DirectIo {
+ dev: Mutex<File>,
+ read_only: bool,
+}
+
+impl DirectIo {
+ /// Create simple direct I/O access device.
+ pub fn new(path: &Path, read_only: bool) -> Result<Self, io::Error> {
+ let dev = OpenOptions::new().read(true).write(!read_only).open(path)?;
+ Ok(DirectIo {
+ dev: Mutex::new(dev),
+ read_only,
+ })
+ }
+
+ fn iowr(&self, port: u64, data: &[u8]) {
+ if !self.read_only {
+ if let Ok(ref mut dev) = self.dev.lock() {
+ let _ = dev.write_all_at(data, port);
+ }
+ }
+ }
+
+ fn iord(&self, port: u64, data: &mut [u8]) {
+ if let Ok(ref mut dev) = self.dev.lock() {
+ let _ = dev.read_exact_at(data, port);
+ }
+ }
+}
+
+impl BusDevice for DirectIo {
+ fn debug_label(&self) -> String {
+ "direct-io".to_string()
+ }
+
+ /// Reads at `offset` from this device
+ fn read(&mut self, ai: BusAccessInfo, data: &mut [u8]) {
+ self.iord(ai.address, data);
+ }
+
+ /// Writes at `offset` into this device
+ fn write(&mut self, ai: BusAccessInfo, data: &[u8]) {
+ self.iowr(ai.address, data);
+ }
+}
+
+impl BusDeviceSync for DirectIo {
+ /// Reads at `offset` from this device
+ fn read(&self, ai: BusAccessInfo, data: &mut [u8]) {
+ self.iord(ai.address, data);
+ }
+
+ /// Writes at `offset` into this device
+ fn write(&self, ai: BusAccessInfo, data: &[u8]) {
+ self.iowr(ai.address, data);
+ }
+}
diff --git a/devices/src/direct_irq.rs b/devices/src/direct_irq.rs
new file mode 100644
index 000000000..fe44c4c4d
--- /dev/null
+++ b/devices/src/direct_irq.rs
@@ -0,0 +1,115 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use base::{ioctl_with_ref, AsRawDescriptor, Event, RawDescriptor};
+use data_model::vec_with_array_field;
+use std::fmt;
+use std::fs::{File, OpenOptions};
+use std::io;
+use std::mem::size_of;
+
+use vfio_sys::*;
+
+#[derive(Debug)]
+pub enum DirectIrqError {
+ Open(io::Error),
+ Enable,
+}
+
+impl fmt::Display for DirectIrqError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ DirectIrqError::Open(e) => write!(f, "failed to open /dev/plat-irq-forward: {}", e),
+ DirectIrqError::Enable => write!(f, "failed to enable direct irq"),
+ }
+ }
+}
+
+pub struct DirectIrq {
+ dev: File,
+ trigger: Event,
+ resample: Option<Event>,
+}
+
+impl DirectIrq {
+ /// Create a DirectIrq object to access hardware-triggered interrupts.
+ pub fn new(trigger: Event, resample: Option<Event>) -> Result<Self, DirectIrqError> {
+ let dev = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .open("/dev/plat-irq-forward")
+ .map_err(DirectIrqError::Open)?;
+ Ok(DirectIrq {
+ dev,
+ trigger,
+ resample,
+ })
+ }
+
+ /// Enable hardware triggered interrupt handling.
+ ///
+ /// Note: this feature is not part of VFIO, but provides
+ /// missing IRQ forwarding functionality.
+ ///
+ /// # Arguments
+ ///
+ /// * `irq_num` - host interrupt number (GSI).
+ ///
+ pub fn irq_enable(&self, irq_num: u32) -> Result<(), DirectIrqError> {
+ if let Some(resample) = &self.resample {
+ self.plat_irq_ioctl(
+ irq_num,
+ PLAT_IRQ_FORWARD_SET_LEVEL_TRIGGER_EVENTFD,
+ self.trigger.as_raw_descriptor(),
+ )?;
+ self.plat_irq_ioctl(
+ irq_num,
+ PLAT_IRQ_FORWARD_SET_LEVEL_UNMASK_EVENTFD,
+ resample.as_raw_descriptor(),
+ )?;
+ } else {
+ self.plat_irq_ioctl(
+ irq_num,
+ PLAT_IRQ_FORWARD_SET_EDGE_TRIGGER,
+ self.trigger.as_raw_descriptor(),
+ )?;
+ };
+
+ Ok(())
+ }
+
+ fn plat_irq_ioctl(
+ &self,
+ irq_num: u32,
+ action: u32,
+ fd: RawDescriptor,
+ ) -> Result<(), DirectIrqError> {
+ let count = 1;
+ let u32_size = size_of::<u32>();
+ let mut irq_set = vec_with_array_field::<plat_irq_forward_set, u32>(count);
+ irq_set[0].argsz = (size_of::<plat_irq_forward_set>() + count * u32_size) as u32;
+ irq_set[0].action_flags = action;
+ irq_set[0].count = count as u32;
+ irq_set[0].irq_number_host = irq_num;
+ // Safe as we are the owner of irq_set and the allocation provides enough space for
+ // the eventfd array.
+ let data = unsafe { irq_set[0].eventfd.as_mut_slice(count * u32_size) };
+ let (left, _right) = data.split_at_mut(u32_size);
+ left.copy_from_slice(&fd.to_ne_bytes()[..]);
+
+ // Safe as we are the owner of plat_irq_forward and irq_set, which hold valid values.
+ let ret = unsafe { ioctl_with_ref(self, PLAT_IRQ_FORWARD_SET(), &irq_set[0]) };
+ if ret < 0 {
+ Err(DirectIrqError::Enable)
+ } else {
+ Ok(())
+ }
+ }
+}
+
+impl AsRawDescriptor for DirectIrq {
+ fn as_raw_descriptor(&self) -> i32 {
+ self.dev.as_raw_descriptor()
+ }
+}
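A caller of DirectIrq creates the trigger (and optionally resample) events, opens the forwarding device, and enables a host GSI. A minimal usage sketch, assuming only the API shown above (the GSI value 36 is an arbitrary example):

// Sketch: enabling forwarding of one host interrupt through DirectIrq.
use base::Event;

fn forward_one_irq() -> Result<(), DirectIrqError> {
    let trigger = Event::new().expect("failed to create trigger event");
    let resample = Event::new().expect("failed to create resample event");
    // Passing Some(resample) takes the level-triggered path in irq_enable();
    // passing None would register the interrupt as edge-triggered instead.
    DirectIrq::new(trigger, Some(resample))?.irq_enable(36)
}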
diff --git a/devices/src/irqchip/ioapic.rs b/devices/src/irqchip/ioapic.rs
index 2062bac10..a2117489f 100644
--- a/devices/src/irqchip/ioapic.rs
+++ b/devices/src/irqchip/ioapic.rs
@@ -2,20 +2,24 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-// Implementation of an intel 82093AA Input/Output Advanced Programmable Interrupt Controller
-// See https://pdos.csail.mit.edu/6.828/2016/readings/ia32/ioapic.pdf for a specification.
+// Implementation of an Intel ICH10 Input/Output Advanced Programmable Interrupt Controller
+// See https://www.intel.com/content/dam/doc/datasheet/io-controller-hub-10-family-datasheet.pdf
+// for a specification.
use std::fmt::{self, Display};
use super::IrqEvent;
use crate::bus::BusAccessInfo;
use crate::BusDevice;
-use base::{error, warn, AsRawDescriptor, Error, Event, Result};
-use hypervisor::{IoapicState, MsiAddressMessage, MsiDataMessage, TriggerMode, NUM_IOAPIC_PINS};
-use msg_socket::{MsgError, MsgReceiver, MsgSender};
-use vm_control::{MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse};
-
-const IOAPIC_VERSION_ID: u32 = 0x00170011;
+use base::{error, warn, Error, Event, Result, Tube, TubeError};
+use hypervisor::{
+ IoapicRedirectionTableEntry, IoapicState, MsiAddressMessage, MsiDataMessage, TriggerMode,
+ NUM_IOAPIC_PINS,
+};
+use vm_control::{VmIrqRequest, VmIrqResponse};
+
+// ICH10 I/O APIC version: 0x20
+const IOAPIC_VERSION_ID: u32 = 0x00000020;
pub const IOAPIC_BASE_ADDRESS: u64 = 0xfec00000;
// The Intel manual does not specify this size, but KVM uses it.
pub const IOAPIC_MEM_LENGTH_BYTES: u64 = 0x100;
@@ -29,6 +33,8 @@ const IOAPIC_REG_ARBITRATION_ID: u8 = 0x02;
const IOREGSEL_OFF: u8 = 0x0;
const IOREGSEL_DUMMY_UPPER_32_BITS_OFF: u8 = 0x4;
const IOWIN_OFF: u8 = 0x10;
+const IOEOIR_OFF: u8 = 0x40;
+
const IOWIN_SCALE: u8 = 0x2;
/// Given an IRQ and whether or not the selector should refer to the high bits, return a selector
@@ -57,8 +63,12 @@ fn decode_irq_from_selector(selector: u8) -> (usize, bool) {
const RTC_IRQ: usize = 0x8;
pub struct Ioapic {
- /// State of the ioapic registers
- state: IoapicState,
+ /// Number of supported IO-APIC inputs / redirection entries.
+ num_pins: usize,
+ /// ioregsel register. Used for selecting which entry of the redirect table to read/write.
+ ioregsel: u8,
+ /// ioapicid register. Bits 24 - 27 contain the APIC ID for this device.
+ ioapicid: u32,
/// Remote IRR for Edge Triggered Real Time Clock interrupts, which allows the CMOS to know when
/// one of its interrupts is being coalesced.
rtc_remote_irr: bool,
@@ -67,8 +77,12 @@ pub struct Ioapic {
/// Events that should be triggered on an EOI. The outer Vec is indexed by GSI, and the inner
/// Vec is an unordered list of registered resample events for the GSI.
resample_events: Vec<Vec<Event>>,
- /// Socket used to route MSI irqs
- irq_socket: VmIrqRequestSocket,
+ /// Redirection settings for each irq line.
+ redirect_table: Vec<IoapicRedirectionTableEntry>,
+ /// Interrupt activation state.
+ interrupt_level: Vec<bool>,
+ /// Tube used to route MSI irqs.
+ irq_tube: Tube,
}
impl BusDevice for Ioapic {
@@ -85,9 +99,10 @@ impl BusDevice for Ioapic {
warn!("IOAPIC: Bad read from {}", info);
}
let out = match info.offset as u8 {
- IOREGSEL_OFF => self.state.ioregsel.into(),
+ IOREGSEL_OFF => self.ioregsel.into(),
IOREGSEL_DUMMY_UPPER_32_BITS_OFF => 0,
IOWIN_OFF => self.ioapic_read(),
+ IOEOIR_OFF => 0,
_ => {
warn!("IOAPIC: Bad read from {}", info);
return;
@@ -110,7 +125,7 @@ impl BusDevice for Ioapic {
warn!("IOAPIC: Bad write to {}", info);
}
match info.offset as u8 {
- IOREGSEL_OFF => self.state.ioregsel = data[0],
+ IOREGSEL_OFF => self.ioregsel = data[0],
IOREGSEL_DUMMY_UPPER_32_BITS_OFF => {} // Ignored.
IOWIN_OFF => {
if data.len() != 4 {
@@ -121,6 +136,7 @@ impl BusDevice for Ioapic {
let val = u32::from_ne_bytes(data_arr);
self.ioapic_write(val);
}
+ IOEOIR_OFF => self.end_of_interrupt(data[0]),
_ => {
warn!("IOAPIC: Bad write to {}", info);
}
@@ -129,28 +145,66 @@ impl BusDevice for Ioapic {
}
impl Ioapic {
- pub fn new(irq_socket: VmIrqRequestSocket) -> Result<Ioapic> {
- let mut state = IoapicState::default();
-
- for i in 0..NUM_IOAPIC_PINS {
- state.redirect_table[i].set_interrupt_mask(true);
- }
-
+ pub fn new(irq_tube: Tube, num_pins: usize) -> Result<Ioapic> {
+ let num_pins = num_pins.max(NUM_IOAPIC_PINS as usize);
+ let mut entry = IoapicRedirectionTableEntry::new();
+ entry.set_interrupt_mask(true);
Ok(Ioapic {
- state,
+ num_pins,
+ ioregsel: 0,
+ ioapicid: 0,
rtc_remote_irr: false,
- out_events: (0..NUM_IOAPIC_PINS).map(|_| None).collect(),
+ out_events: (0..num_pins).map(|_| None).collect(),
resample_events: Vec::new(),
- irq_socket,
+ redirect_table: (0..num_pins).map(|_| entry.clone()).collect(),
+ interrupt_level: (0..num_pins).map(|_| false).collect(),
+ irq_tube,
})
}
pub fn get_ioapic_state(&self) -> IoapicState {
- self.state
+ // Convert the first NUM_IOAPIC_PINS entries of the interrupt level vector into a u32 bitmap.
+ let level_bitmap = self
+ .interrupt_level
+ .iter()
+ .take(NUM_IOAPIC_PINS)
+ .rev()
+ .fold(0, |acc, &l| acc * 2 + l as u32);
+ let mut state = IoapicState {
+ base_address: IOAPIC_BASE_ADDRESS,
+ ioregsel: self.ioregsel,
+ ioapicid: self.ioapicid,
+ current_interrupt_level_bitmap: level_bitmap,
+ ..Default::default()
+ };
+ for (dst, src) in state
+ .redirect_table
+ .iter_mut()
+ .zip(self.redirect_table.iter())
+ {
+ *dst = *src;
+ }
+ state
}
pub fn set_ioapic_state(&mut self, state: &IoapicState) {
- self.state = *state
+ self.ioregsel = state.ioregsel;
+ self.ioapicid = state.ioapicid & 0x0f00_0000;
+ for (src, dst) in state
+ .redirect_table
+ .iter()
+ .zip(self.redirect_table.iter_mut())
+ {
+ *dst = *src;
+ }
+ for (i, level) in self
+ .interrupt_level
+ .iter_mut()
+ .take(NUM_IOAPIC_PINS)
+ .enumerate()
+ {
+ *level = state.current_interrupt_level_bitmap & (1 << i) != 0;
+ }
}
pub fn register_resample_events(&mut self, resample_events: Vec<Vec<Event>>) {
@@ -160,14 +214,14 @@ impl Ioapic {
// The ioapic must be informed about EOIs in order to avoid sending multiple interrupts of the
// same type at the same time.
pub fn end_of_interrupt(&mut self, vector: u8) {
- if self.state.redirect_table[RTC_IRQ].get_vector() == vector && self.rtc_remote_irr {
+ if self.redirect_table[RTC_IRQ].get_vector() == vector && self.rtc_remote_irr {
// Specifically clear RTC IRQ field
self.rtc_remote_irr = false;
}
- for i in 0..NUM_IOAPIC_PINS {
- if self.state.redirect_table[i].get_vector() == vector
- && self.state.redirect_table[i].get_trigger_mode() == TriggerMode::Level
+ for i in 0..self.num_pins {
+ if self.redirect_table[i].get_vector() == vector
+ && self.redirect_table[i].get_trigger_mode() == TriggerMode::Level
{
if self
.resample_events
@@ -182,34 +236,32 @@ impl Ioapic {
resample_evt.write(1).unwrap();
}
}
- self.state.redirect_table[i].set_remote_irr(false);
+ self.redirect_table[i].set_remote_irr(false);
}
// There is an inherent race condition in hardware if the OS is finished processing an
// interrupt and a new interrupt is delivered between issuing an EOI and the EOI being
// completed. When that happens the ioapic is supposed to re-inject the interrupt.
- if self.state.current_interrupt_level_bitmap & (1 << i) != 0 {
+ if self.interrupt_level[i] {
self.service_irq(i, true);
}
}
}
pub fn service_irq(&mut self, irq: usize, level: bool) -> bool {
- let entry = &mut self.state.redirect_table[irq];
+ let entry = &mut self.redirect_table[irq];
// De-assert the interrupt.
if !level {
- self.state.current_interrupt_level_bitmap &= !(1 << irq);
+ self.interrupt_level[irq] = false;
return true;
}
// If it's an edge-triggered interrupt that's already high we ignore it.
- if entry.get_trigger_mode() == TriggerMode::Edge
- && self.state.current_interrupt_level_bitmap & (1 << irq) != 0
- {
+ if entry.get_trigger_mode() == TriggerMode::Edge && self.interrupt_level[irq] {
return false;
}
- self.state.current_interrupt_level_bitmap |= 1 << irq;
+ self.interrupt_level[irq] = true;
// Interrupts are masked, so don't inject.
if entry.get_interrupt_mask() {
@@ -242,22 +294,22 @@ impl Ioapic {
}
fn ioapic_write(&mut self, val: u32) {
- match self.state.ioregsel {
+ match self.ioregsel {
IOAPIC_REG_VERSION => { /* read-only register */ }
- IOAPIC_REG_ID => self.state.ioapicid = (val >> 24) & 0xf,
+ IOAPIC_REG_ID => self.ioapicid = val & 0x0f00_0000,
IOAPIC_REG_ARBITRATION_ID => { /* read-only register */ }
_ => {
- if self.state.ioregsel < IOWIN_OFF {
+ if self.ioregsel < IOWIN_OFF {
// Invalid write; ignore.
return;
}
- let (index, is_high_bits) = decode_irq_from_selector(self.state.ioregsel);
- if index >= hypervisor::NUM_IOAPIC_PINS {
+ let (index, is_high_bits) = decode_irq_from_selector(self.ioregsel);
+ if index >= self.num_pins {
// Invalid write; ignore.
return;
}
- let entry = &mut self.state.redirect_table[index];
+ let entry = &mut self.redirect_table[index];
if is_high_bits {
entry.set(32, 32, val.into());
} else {
@@ -278,16 +330,16 @@ impl Ioapic {
// is the fix for this.
}
- if self.state.redirect_table[index].get_trigger_mode() == TriggerMode::Level
- && self.state.current_interrupt_level_bitmap & (1 << index) != 0
- && !self.state.redirect_table[index].get_interrupt_mask()
+ if self.redirect_table[index].get_trigger_mode() == TriggerMode::Level
+ && self.interrupt_level[index]
+ && !self.redirect_table[index].get_interrupt_mask()
{
self.service_irq(index, true);
}
let mut address = MsiAddressMessage::new();
let mut data = MsiDataMessage::new();
- let entry = &self.state.redirect_table[index];
+ let entry = &self.redirect_table[index];
address.set_destination_mode(entry.get_dest_mode());
address.set_destination_id(entry.get_dest_id());
address.set_always_0xfee(0xfee);
@@ -329,21 +381,22 @@ impl Ioapic {
evt.gsi
} else {
let event = Event::new().map_err(IoapicError::CreateEvent)?;
- let request = VmIrqRequest::AllocateOneMsi {
- irqfd: MaybeOwnedDescriptor::Borrowed(event.as_raw_descriptor()),
- };
- self.irq_socket
+ let request = VmIrqRequest::AllocateOneMsi { irqfd: event };
+ self.irq_tube
.send(&request)
.map_err(IoapicError::AllocateOneMsiSend)?;
match self
- .irq_socket
+ .irq_tube
.recv()
.map_err(IoapicError::AllocateOneMsiRecv)?
{
VmIrqResponse::AllocateOneMsi { gsi, .. } => {
self.out_events[index] = Some(IrqEvent {
gsi,
- event,
+ event: match request {
+ VmIrqRequest::AllocateOneMsi { irqfd } => irqfd,
+ _ => unreachable!(),
+ },
resample_event: None,
});
gsi
@@ -361,32 +414,28 @@ impl Ioapic {
msi_address,
msi_data,
};
- self.irq_socket
+ self.irq_tube
.send(&request)
.map_err(IoapicError::AddMsiRouteSend)?;
- if let VmIrqResponse::Err(e) = self
- .irq_socket
- .recv()
- .map_err(IoapicError::AddMsiRouteRecv)?
- {
+ if let VmIrqResponse::Err(e) = self.irq_tube.recv().map_err(IoapicError::AddMsiRouteRecv)? {
return Err(IoapicError::AddMsiRoute(e));
}
Ok(())
}
fn ioapic_read(&mut self) -> u32 {
- match self.state.ioregsel {
- IOAPIC_REG_VERSION => IOAPIC_VERSION_ID,
- IOAPIC_REG_ID | IOAPIC_REG_ARBITRATION_ID => (self.state.ioapicid & 0xf) << 24,
+ match self.ioregsel {
+ IOAPIC_REG_VERSION => ((self.num_pins - 1) as u32) << 16 | IOAPIC_VERSION_ID,
+ IOAPIC_REG_ID | IOAPIC_REG_ARBITRATION_ID => self.ioapicid,
_ => {
- if self.state.ioregsel < IOWIN_OFF {
+ if self.ioregsel < IOWIN_OFF {
// Invalid read; ignore and return 0.
0
} else {
- let (index, is_high_bits) = decode_irq_from_selector(self.state.ioregsel);
- if index < NUM_IOAPIC_PINS {
+ let (index, is_high_bits) = decode_irq_from_selector(self.ioregsel);
+ if index < self.num_pins {
let offset = if is_high_bits { 32 } else { 0 };
- self.state.redirect_table[index].get(offset, 32) as u32
+ self.redirect_table[index].get(offset, 32) as u32
} else {
!0 // Invalid index - return all 1s
}
@@ -399,11 +448,11 @@ impl Ioapic {
#[derive(Debug)]
enum IoapicError {
AddMsiRoute(Error),
- AddMsiRouteRecv(MsgError),
- AddMsiRouteSend(MsgError),
+ AddMsiRouteRecv(TubeError),
+ AddMsiRouteSend(TubeError),
AllocateOneMsi(Error),
- AllocateOneMsiRecv(MsgError),
- AllocateOneMsiSend(MsgError),
+ AllocateOneMsiRecv(TubeError),
+ AllocateOneMsiSend(TubeError),
CreateEvent(Error),
}
@@ -428,14 +477,14 @@ impl Display for IoapicError {
#[cfg(test)]
mod tests {
use super::*;
- use hypervisor::{DeliveryMode, DeliveryStatus, DestinationMode, IoapicRedirectionTableEntry};
+ use hypervisor::{DeliveryMode, DeliveryStatus, DestinationMode};
const DEFAULT_VECTOR: u8 = 0x3a;
const DEFAULT_DESTINATION_ID: u8 = 0x5f;
fn new() -> Ioapic {
- let (_, device_socket) = msg_socket::pair::<VmIrqResponse, VmIrqRequest>().unwrap();
- Ioapic::new(device_socket).unwrap()
+ let (_, irq_tube) = Tube::pair().unwrap();
+ Ioapic::new(irq_tube, NUM_IOAPIC_PINS).unwrap()
}
fn ioapic_bus_address(offset: u8) -> BusAccessInfo {
@@ -739,7 +788,7 @@ mod tests {
fn remote_irr_read_only() {
let (mut ioapic, irq) = set_up(TriggerMode::Level);
- ioapic.state.redirect_table[irq].set_remote_irr(true);
+ ioapic.redirect_table[irq].set_remote_irr(true);
let mut entry = read_entry(&mut ioapic, irq);
entry.set_remote_irr(false);
@@ -752,7 +801,7 @@ mod tests {
fn delivery_status_read_only() {
let (mut ioapic, irq) = set_up(TriggerMode::Level);
- ioapic.state.redirect_table[irq].set_delivery_status(DeliveryStatus::Pending);
+ ioapic.redirect_table[irq].set_delivery_status(DeliveryStatus::Pending);
let mut entry = read_entry(&mut ioapic, irq);
entry.set_delivery_status(DeliveryStatus::Idle);
@@ -768,7 +817,7 @@ mod tests {
fn level_to_edge_transition_clears_remote_irr() {
let (mut ioapic, irq) = set_up(TriggerMode::Level);
- ioapic.state.redirect_table[irq].set_remote_irr(true);
+ ioapic.redirect_table[irq].set_remote_irr(true);
let mut entry = read_entry(&mut ioapic, irq);
entry.set_trigger_mode(TriggerMode::Edge);
@@ -781,7 +830,7 @@ mod tests {
fn masking_preserves_remote_irr() {
let (mut ioapic, irq) = set_up(TriggerMode::Level);
- ioapic.state.redirect_table[irq].set_remote_irr(true);
+ ioapic.redirect_table[irq].set_remote_irr(true);
set_mask(&mut ioapic, irq, true);
set_mask(&mut ioapic, irq, false);
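get_ioapic_state() above folds the per-pin interrupt_level vector into the u32 current_interrupt_level_bitmap, with pin 0 in the least significant bit. A quick standalone check of that fold expression on a toy input:

// Sketch: the interrupt_level -> bitmap fold used by get_ioapic_state().
fn main() {
    let interrupt_level = vec![true, false, true, false]; // pins 0..=3
    let bitmap = interrupt_level
        .iter()
        .rev()
        .fold(0u32, |acc, &level| acc * 2 + level as u32);
    assert_eq!(bitmap, 0b0101); // pins 0 and 2 asserted
}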
diff --git a/devices/src/irqchip/kvm/x86_64.rs b/devices/src/irqchip/kvm/x86_64.rs
index 8b6d54627..613e1c4fd 100644
--- a/devices/src/irqchip/kvm/x86_64.rs
+++ b/devices/src/irqchip/kvm/x86_64.rs
@@ -18,8 +18,7 @@ use hypervisor::{
use kvm_sys::*;
use resources::SystemAllocator;
-use base::{error, Error, Event, Result};
-use vm_control::VmIrqRequestSocket;
+use base::{error, Error, Event, Result, Tube};
use crate::irqchip::{
Ioapic, IrqEvent, IrqEventIndex, Pic, VcpuRunState, IOAPIC_BASE_ADDRESS,
@@ -27,12 +26,12 @@ use crate::irqchip::{
};
use crate::{Bus, IrqChip, IrqChipCap, IrqChipX86_64, Pit, PitError};
/// PIT channel 0 timer is connected to IRQ 0
const PIT_CHANNEL0_IRQ: u32 = 0;
/// Default x86 routing table. Pins 0-7 go to primary pic and ioapic, pins 8-15 go to secondary
/// pic and ioapic, and pins 16-23 go only to the ioapic.
-fn kvm_default_irq_routing_table() -> Vec<IrqRoute> {
+fn kvm_default_irq_routing_table(ioapic_pins: usize) -> Vec<IrqRoute> {
let mut routes: Vec<IrqRoute> = Vec::new();
for i in 0..8 {
@@ -43,7 +42,7 @@ fn kvm_default_irq_routing_table() -> Vec<IrqRoute> {
routes.push(IrqRoute::pic_irq_route(IrqSourceChip::PicSecondary, i));
routes.push(IrqRoute::ioapic_irq_route(i));
}
- for i in 16..NUM_IOAPIC_PINS as u32 {
+ for i in 16..ioapic_pins as u32 {
routes.push(IrqRoute::ioapic_irq_route(i));
}
@@ -68,7 +67,7 @@ impl KvmKernelIrqChip {
Ok(KvmKernelIrqChip {
vm,
vcpus: Arc::new(Mutex::new((0..num_vcpus).map(|_| None).collect())),
- routes: Arc::new(Mutex::new(kvm_default_irq_routing_table())),
+ routes: Arc::new(Mutex::new(kvm_default_irq_routing_table(NUM_IOAPIC_PINS))),
})
}
/// Attempt to create a shallow clone of this x86_64 KvmKernelIrqChip instance.
@@ -146,6 +145,7 @@ pub struct KvmSplitIrqChip {
pit: Arc<Mutex<Pit>>,
pic: Arc<Mutex<Pic>>,
ioapic: Arc<Mutex<Ioapic>>,
+ ioapic_pins: usize,
/// Vec of ioapic irq events that have been delayed because the ioapic was locked when
/// service_irq was called on the irqchip. This prevents deadlocks when a Vcpu thread has
/// locked the ioapic and the ioapic sends a AddMsiRoute signal to the main thread (which
@@ -155,9 +155,9 @@ pub struct KvmSplitIrqChip {
irq_events: Arc<Mutex<Vec<Option<IrqEvent>>>>,
}
-fn kvm_dummy_msi_routes() -> Vec<IrqRoute> {
+fn kvm_dummy_msi_routes(ioapic_pins: usize) -> Vec<IrqRoute> {
let mut routes: Vec<IrqRoute> = Vec::new();
- for i in 0..NUM_IOAPIC_PINS {
+ for i in 0..ioapic_pins {
routes.push(
// Add dummy MSI routes to replace the default IRQChip routes.
IrqRoute {
@@ -174,9 +174,14 @@ fn kvm_dummy_msi_routes() -> Vec<IrqRoute> {
impl KvmSplitIrqChip {
/// Construct a new KvmSplitIrqChip.
- pub fn new(vm: KvmVm, num_vcpus: usize, irq_socket: VmIrqRequestSocket) -> Result<Self> {
- vm.enable_split_irqchip()?;
-
+ pub fn new(
+ vm: KvmVm,
+ num_vcpus: usize,
+ irq_tube: Tube,
+ ioapic_pins: Option<usize>,
+ ) -> Result<Self> {
+ let ioapic_pins = ioapic_pins.unwrap_or(hypervisor::NUM_IOAPIC_PINS);
+ vm.enable_split_irqchip(ioapic_pins)?;
let pit_evt = Event::new()?;
let pit = Arc::new(Mutex::new(
Pit::new(pit_evt.try_clone()?, Arc::new(Mutex::new(Clock::new()))).map_err(
@@ -197,15 +202,16 @@ impl KvmSplitIrqChip {
routes: Arc::new(Mutex::new(Vec::new())),
pit,
pic: Arc::new(Mutex::new(Pic::new())),
- ioapic: Arc::new(Mutex::new(Ioapic::new(irq_socket)?)),
+ ioapic: Arc::new(Mutex::new(Ioapic::new(irq_tube, ioapic_pins)?)),
+ ioapic_pins,
delayed_ioapic_irq_events: Arc::new(Mutex::new(Vec::new())),
irq_events: Arc::new(Mutex::new(Default::default())),
};
// Setup standard x86 irq routes
- let mut routes = kvm_default_irq_routing_table();
- // Add dummy MSI routes for the first 24 GSIs
- routes.append(&mut kvm_dummy_msi_routes());
+ let mut routes = kvm_default_irq_routing_table(ioapic_pins);
+ // Add dummy MSI routes for the first ioapic_pins GSIs
+ routes.append(&mut kvm_dummy_msi_routes(ioapic_pins));
// Set the routes so they get sent to KVM
chip.set_irq_routes(&routes)?;
@@ -252,16 +258,15 @@ impl KvmSplitIrqChip {
/// Check if the specified vcpu has any pending interrupts. Returns None for no interrupts,
/// otherwise Some(u32) should be the injected interrupt vector. For KvmSplitIrqChip
/// this calls get_external_interrupt on the pic.
- fn get_external_interrupt(&self, vcpu_id: usize) -> Result<Option<u32>> {
+ fn get_external_interrupt(&self, vcpu_id: usize) -> Option<u32> {
// Pic interrupts for the split irqchip only go to vcpu 0
if vcpu_id != 0 {
- return Ok(None);
- }
- if let Some(vector) = self.pic.lock().get_external_interrupt() {
- Ok(Some(vector as u32))
- } else {
- Ok(None)
+ return None;
}
+ self.pic
+ .lock()
+ .get_external_interrupt()
+ .map(|vector| vector as u32)
}
}
@@ -313,7 +318,7 @@ impl IrqChip for KvmSplitIrqChip {
irq_event: &Event,
resample_event: Option<&Event>,
) -> Result<Option<IrqEventIndex>> {
- if irq < NUM_IOAPIC_PINS as u32 {
+ if irq < self.ioapic_pins as u32 {
let mut evt = IrqEvent {
gsi: irq,
event: irq_event.try_clone()?,
@@ -336,7 +341,7 @@ impl IrqChip for KvmSplitIrqChip {
/// Unregister an event for a particular GSI.
fn unregister_irq_event(&mut self, irq: u32, irq_event: &Event) -> Result<()> {
- if irq < NUM_IOAPIC_PINS as u32 {
+ if irq < self.ioapic_pins as u32 {
let mut irq_events = self.irq_events.lock();
for (index, evt) in irq_events.iter().enumerate() {
if let Some(evt) = evt {
@@ -470,8 +475,8 @@ impl IrqChip for KvmSplitIrqChip {
return Ok(());
}
- if let Some(vector) = self.get_external_interrupt(vcpu_id)? {
- vcpu.interrupt(vector as u32)?;
+ if let Some(vector) = self.get_external_interrupt(vcpu_id) {
+ vcpu.interrupt(vector)?;
}
// The second interrupt request should be handled immediately, so ask vCPU to exit as soon as
@@ -524,6 +529,7 @@ impl IrqChip for KvmSplitIrqChip {
pit: self.pit.clone(),
pic: self.pic.clone(),
ioapic: self.ioapic.clone(),
+ ioapic_pins: self.ioapic_pins,
delayed_ioapic_irq_events: self.delayed_ioapic_irq_events.clone(),
irq_events: self.irq_events.clone(),
})
@@ -558,13 +564,13 @@ impl IrqChip for KvmSplitIrqChip {
// At this point, all of our devices have been created and they have registered their
// irq events, so we can clone our resample events
let mut ioapic_resample_events: Vec<Vec<Event>> =
- (0..NUM_IOAPIC_PINS).map(|_| Vec::new()).collect();
+ (0..self.ioapic_pins).map(|_| Vec::new()).collect();
let mut pic_resample_events: Vec<Vec<Event>> =
- (0..NUM_IOAPIC_PINS).map(|_| Vec::new()).collect();
+ (0..self.ioapic_pins).map(|_| Vec::new()).collect();
for evt in self.irq_events.lock().iter() {
if let Some(evt) = evt {
- if (evt.gsi as usize) >= NUM_IOAPIC_PINS {
+ if (evt.gsi as usize) >= self.ioapic_pins {
continue;
}
if let Some(resample_evt) = &evt.resample_event {
@@ -583,9 +589,9 @@ impl IrqChip for KvmSplitIrqChip {
.lock()
.register_resample_events(pic_resample_events);
- // Make sure all future irq numbers are >= NUM_IOAPIC_PINS
+ // Make sure all future irq numbers are beyond the IO-APIC pin range.
let mut irq_num = resources.allocate_irq().unwrap();
- while irq_num < NUM_IOAPIC_PINS as u32 {
+ while irq_num < self.ioapic_pins as u32 {
irq_num = resources.allocate_irq().unwrap();
}
@@ -701,7 +707,6 @@ mod tests {
use vm_memory::GuestMemory;
use hypervisor::{IoapicRedirectionTableEntry, PitRWMode, TriggerMode, Vm, VmX86_64};
- use vm_control::{VmIrqRequest, VmIrqResponse};
use super::super::super::tests::*;
use crate::IrqChip;
@@ -728,13 +733,13 @@ mod tests {
let mem = GuestMemory::new(&[]).unwrap();
let vm = KvmVm::new(&kvm, mem).expect("failed to instantiate vm");
- let (_, device_socket) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().expect("failed to create irq socket");
+ let (_, device_tube) = Tube::pair().expect("failed to create irq tube");
let mut chip = KvmSplitIrqChip::new(
vm.try_clone().expect("failed to clone vm"),
1,
- device_socket,
+ device_tube,
+ None,
)
.expect("failed to instantiate KvmKernelIrqChip");
@@ -940,7 +945,7 @@ mod tests {
.expect("failed to get external interrupt"),
// Vector is 9 because the interrupt vector base address is 0x08 and this is irq
// line 1 and 8+1 = 9
- Some(0x9)
+ 0x9
);
assert_eq!(
@@ -984,7 +989,7 @@ mod tests {
assert_eq!(
chip.get_external_interrupt(0)
.expect("failed to get external interrupt"),
- Some(0)
+ 0,
);
// interrupt is not requested twice
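kvm_default_irq_routing_table() partitions the GSI space as its comment states: pins 0-7 go to the primary PIC and the IO-APIC, pins 8-15 to the secondary PIC and the IO-APIC, and the remaining pins up to ioapic_pins to the IO-APIC only. A toy sketch of that partitioning (not the real IrqRoute construction):

// Sketch: which chips receive a given GSI in the default routing table.
fn route_targets(pin: u32, ioapic_pins: u32) -> &'static str {
    match pin {
        0..=7 => "primary PIC + IO-APIC",
        8..=15 => "secondary PIC + IO-APIC",
        p if p < ioapic_pins => "IO-APIC only",
        _ => "unrouted",
    }
}

fn main() {
    assert_eq!(route_targets(3, 24), "primary PIC + IO-APIC");
    assert_eq!(route_targets(20, 24), "IO-APIC only");
}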
diff --git a/devices/src/lib.rs b/devices/src/lib.rs
index 5299ef5a6..60b46934f 100644
--- a/devices/src/lib.rs
+++ b/devices/src/lib.rs
@@ -6,6 +6,10 @@
mod bus;
mod cmos;
+#[cfg(feature = "direct")]
+pub mod direct_io;
+#[cfg(feature = "direct")]
+pub mod direct_irq;
mod i8042;
pub mod irqchip;
mod pci;
@@ -27,8 +31,12 @@ pub mod virtio;
pub use self::acpi::ACPIPMResource;
pub use self::bat::{BatteryError, GoldfishBattery};
pub use self::bus::Error as BusError;
-pub use self::bus::{Bus, BusAccessInfo, BusDevice, BusRange, BusResumeDevice};
+pub use self::bus::{Bus, BusAccessInfo, BusDevice, BusDeviceSync, BusRange, BusResumeDevice};
pub use self::cmos::Cmos;
+#[cfg(feature = "direct")]
+pub use self::direct_io::DirectIo;
+#[cfg(feature = "direct")]
+pub use self::direct_irq::{DirectIrq, DirectIrqError};
pub use self::i8042::I8042Device;
pub use self::irqchip::*;
#[cfg(feature = "audio")]
diff --git a/devices/src/pci/ac97.rs b/devices/src/pci/ac97.rs
index b6517621d..c618b14d6 100644
--- a/devices/src/pci/ac97.rs
+++ b/devices/src/pci/ac97.rs
@@ -10,7 +10,7 @@ use std::str::FromStr;
use audio_streams::shm_streams::{NullShmStreamSource, ShmStreamSource};
use base::{error, Event, RawDescriptor};
-use libcras::{CrasClient, CrasClientType, CrasSocketType};
+use libcras::{CrasClient, CrasClientType, CrasSocketType, CrasSysError};
use resources::{Alloc, MmioType, SystemAllocator};
use vm_memory::GuestMemory;
@@ -22,9 +22,9 @@ use crate::pci::pci_configuration::{
};
use crate::pci::pci_device::{self, PciDevice, Result};
use crate::pci::{PciAddress, PciDeviceError, PciInterruptPin};
-#[cfg(not(target_os = "linux"))]
+#[cfg(not(any(target_os = "linux", target_os = "android")))]
use crate::virtio::snd::vios_backend::Error as VioSError;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", target_os = "android"))]
use crate::virtio::snd::vios_backend::VioSShmStreamSource;
// Use 82801AA because it's what qemu does.
@@ -84,6 +84,17 @@ pub struct Ac97Parameters {
pub backend: Ac97Backend,
pub capture: bool,
pub vios_server_path: Option<PathBuf>,
+ client_type: Option<CrasClientType>,
+}
+
+impl Ac97Parameters {
+ /// Set the CRAS client type from the given client type string.
+ ///
+ /// `client_type` - The client type string.
+ pub fn set_client_type(&mut self, client_type: &str) -> std::result::Result<(), CrasSysError> {
+ self.client_type = Some(client_type.parse()?);
+ Ok(())
+ }
}
pub struct Ac97Dev {
@@ -115,6 +126,7 @@ impl Ac97Dev {
PciHeaderType::Device,
0x8086, // Subsystem Vendor ID
0x1, // Subsystem ID.
+ 0, // Revision ID.
);
Self {
@@ -137,10 +149,10 @@ impl Ac97Dev {
"Ac97Dev: create_cras_audio_device: {}. Fallback to null audio device",
e
);
- Self::create_null_audio_device(mem)
+ Ok(Self::create_null_audio_device(mem))
}),
Ac97Backend::VIOS => Self::create_vios_audio_device(mem, param),
- Ac97Backend::NULL => Self::create_null_audio_device(mem),
+ Ac97Backend::NULL => Ok(Self::create_null_audio_device(mem)),
}
}
@@ -158,7 +170,11 @@ impl Ac97Dev {
CrasClient::with_type(CrasSocketType::Unified)
.map_err(pci_device::Error::CreateCrasClientFailed)?,
);
- server.set_client_type(CrasClientType::CRAS_CLIENT_TYPE_CROSVM);
+ server.set_client_type(
+ params
+ .client_type
+ .unwrap_or(CrasClientType::CRAS_CLIENT_TYPE_CROSVM),
+ );
if params.capture {
server.enable_cras_capture();
}
@@ -168,7 +184,7 @@ impl Ac97Dev {
}
fn create_vios_audio_device(mem: GuestMemory, param: Ac97Parameters) -> Result<Self> {
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", target_os = "android"))]
{
let server = Box::new(
// The presence of vios_server_path is checked during argument parsing
@@ -178,16 +194,15 @@ impl Ac97Dev {
let vios_audio = Self::new(mem, Ac97Backend::VIOS, server);
return Ok(vios_audio);
}
- #[cfg(not(target_os = "linux"))]
+ #[cfg(not(any(target_os = "linux", target_os = "android")))]
Err(pci_device::Error::CreateViosClientFailed(
VioSError::PlatformNotSupported,
))
}
- fn create_null_audio_device(mem: GuestMemory) -> Result<Self> {
+ fn create_null_audio_device(mem: GuestMemory) -> Self {
let server = Box::new(NullShmStreamSource::new());
- let null_audio = Self::new(mem, Ac97Backend::NULL, server);
- Ok(null_audio)
+ Self::new(mem, Ac97Backend::NULL, server)
}
fn read_mixer(&mut self, offset: u64, data: &mut [u8]) {
diff --git a/devices/src/pci/ac97_bus_master.rs b/devices/src/pci/ac97_bus_master.rs
index 321bf5cac..e121a6f7b 100644
--- a/devices/src/pci/ac97_bus_master.rs
+++ b/devices/src/pci/ac97_bus_master.rs
@@ -3,7 +3,6 @@
// found in the LICENSE file.
use std::collections::VecDeque;
-use std::convert::AsRef;
use std::convert::TryInto;
use std::fmt::{self, Display};
use std::sync::atomic::{AtomicBool, Ordering};
@@ -16,7 +15,8 @@ use audio_streams::{
BoxError, NoopStreamControl, SampleFormat, StreamControl, StreamDirection, StreamEffect,
};
use base::{
- self, error, set_rt_prio_limit, set_rt_round_robin, warn, AsRawDescriptor, Event, RawDescriptor,
+ self, error, set_rt_prio_limit, set_rt_round_robin, warn, AsRawDescriptors, Event,
+ RawDescriptor,
};
use sync::{Condvar, Mutex};
use vm_memory::{GuestAddress, GuestMemory};
@@ -70,22 +70,22 @@ impl Ac97BusMasterRegs {
}
}
- fn channel_count(&self, func: Ac97Function) -> usize {
- fn output_channel_count(glob_cnt: u32) -> usize {
+ fn tube_count(&self, func: Ac97Function) -> usize {
+ fn output_tube_count(glob_cnt: u32) -> usize {
let val = (glob_cnt & GLOB_CNT_PCM_246_MASK) >> 20;
match val {
0 => 2,
1 => 4,
2 => 6,
_ => {
- warn!("unknown channel_count: 0x{:x}", val);
+ warn!("unknown tube_count: 0x{:x}", val);
2
}
}
}
match func {
- Ac97Function::Output => output_channel_count(self.glob_cnt),
+ Ac97Function::Output => output_tube_count(self.glob_cnt),
_ => DEVICE_INPUT_CHANNEL_COUNT,
}
}
@@ -130,6 +130,8 @@ type GuestMemoryResult<T> = std::result::Result<T, GuestMemoryError>;
enum AudioError {
// Failed to create a new stream.
CreateStream(BoxError),
+ // Failed to get a region from guest memory.
+ GuestRegion(GuestMemoryError),
// Invalid buffer offset received from the audio server.
InvalidBufferOffset,
// Guest did not provide a buffer when needed.
@@ -150,6 +152,7 @@ impl Display for AudioError {
match self {
CreateStream(e) => write!(f, "Failed to create audio stream: {}.", e),
+ GuestRegion(e) => write!(f, "Failed to get guest memory region: {}.", e),
InvalidBufferOffset => write!(f, "Offset > max usize"),
NoBufferAvailable => write!(f, "No buffer was available from the Guest"),
ReadingGuestError(e) => write!(f, "Failed to read guest memory: {}.", e),
@@ -255,7 +258,7 @@ impl Ac97BusMaster {
/// Returns any file descriptors that need to be kept open when entering a jail.
pub fn keep_rds(&self) -> Option<Vec<RawDescriptor>> {
let mut rds = self.audio_server.keep_fds();
- rds.push(self.mem.as_raw_descriptor());
+ rds.append(&mut self.mem.as_raw_descriptors());
Some(rds)
}
@@ -341,7 +344,7 @@ impl Ac97BusMaster {
regs.po_regs.picb
} else {
// Estimate how many samples have been played since the last audio callback.
- let num_channels = regs.channel_count(Ac97Function::Output) as u64;
+ let num_channels = regs.tube_count(Ac97Function::Output) as u64;
let micros = regs.po_pointer_update_time.elapsed().subsec_micros();
// Round down to the next 10 millisecond boundary. The linux driver often
// assumes that two rapid reads from picb will return the same value.
@@ -562,7 +565,7 @@ impl Ac97BusMaster {
let locked_regs = self.regs.lock();
let sample_rate = self.current_sample_rate(func, mixer);
let buffer_samples = current_buffer_size(locked_regs.func_regs(func), &self.mem)?;
- let num_channels = locked_regs.channel_count(func);
+ let num_channels = locked_regs.tube_count(func);
let buffer_frames = buffer_samples / num_channels;
let mut pending_buffers = VecDeque::with_capacity(2);
@@ -588,7 +591,12 @@ impl Ac97BusMaster {
sample_rate,
buffer_frames,
&Self::stream_effects(func),
- self.mem.as_ref().inner(),
+ self.mem
+ .offset_region(starting_offsets[0])
+ .map_err(|e| {
+ AudioError::GuestRegion(GuestMemoryError::ReadingGuestBufferAddress(e))
+ })?
+ .inner(),
starting_offsets,
)
.map_err(AudioError::CreateStream)?;
@@ -712,7 +720,7 @@ fn next_guest_buffer(
// 0 h l n
// +++++++++......++++
(low > high && (low <= value || value <= high))
- };
+ }
// Check if
// * we're halted
@@ -729,7 +737,7 @@ fn next_guest_buffer(
let offset = get_buffer_offset(func_regs, mem, index)?
.try_into()
.map_err(|_| AudioError::InvalidBufferOffset)?;
- let frames = get_buffer_samples(func_regs, mem, index)? / regs.channel_count(func);
+ let frames = get_buffer_samples(func_regs, mem, index)? / regs.tube_count(func);
Ok(Some(GuestBuffer {
index,
@@ -1037,7 +1045,7 @@ mod test {
}
#[test]
- fn run_multi_channel_playback() {
+ fn run_multi_tube_playback() {
start_playback(2, 48000);
start_playback(4, 48000);
start_playback(6, 48000);
@@ -1087,7 +1095,7 @@ mod test {
bm.writeb(PO_LVI_15, LVI_MASK, &mixer);
assert_eq!(bm.readb(PO_CIV_14), 0);
// Set channel count and sample rate.
let mut cnt = bm.readl(GLOB_CNT_2C);
cnt &= !GLOB_CNT_PCM_246_MASK;
mixer.writew(MIXER_PCM_FRONT_DAC_RATE_2C, rate);
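The channel-count helper above decodes bits 20-21 of glob_cnt (the GLOB_CNT_PCM_246 field) into 2, 4, or 6 output channels. A standalone sketch of that decode, reusing the mask value defined in ac97_regs.rs:

// Sketch: GLOB_CNT PCM 4/6 enable bits -> output channel count.
const GLOB_CNT_PCM_246_MASK: u32 = 0x0030_0000; // GLOB_CNT_PCM_4 | GLOB_CNT_PCM_6

fn output_channels(glob_cnt: u32) -> usize {
    match (glob_cnt & GLOB_CNT_PCM_246_MASK) >> 20 {
        0 => 2,
        1 => 4,
        2 => 6,
        _ => 2, // reserved encoding; fall back to stereo
    }
}

fn main() {
    assert_eq!(output_channels(0x0000_0000), 2);
    assert_eq!(output_channels(0x0010_0000), 4);
    assert_eq!(output_channels(0x0020_0000), 6);
}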
diff --git a/devices/src/pci/ac97_mixer.rs b/devices/src/pci/ac97_mixer.rs
index 7ab918d5e..bf1fcda54 100644
--- a/devices/src/pci/ac97_mixer.rs
+++ b/devices/src/pci/ac97_mixer.rs
@@ -119,7 +119,7 @@ impl Ac97Mixer {
/// Returns the front sample rate (reg 0x2c).
pub fn get_sample_rate(&self) -> u16 {
// MIXER_PCM_FRONT_DAC_RATE_2C, MIXER_PCM_SURR_DAC_RATE_2E, and MIXER_PCM_LFE_DAC_RATE_30
// are updated to the same rate when playing back with 2, 4, and 6 channels.
self.pcm_front_dac_rate
}
diff --git a/devices/src/pci/ac97_regs.rs b/devices/src/pci/ac97_regs.rs
index 26ac45a5c..afe06998f 100644
--- a/devices/src/pci/ac97_regs.rs
+++ b/devices/src/pci/ac97_regs.rs
@@ -58,7 +58,7 @@ pub const MIXER_EI_SDAC: u16 = 0x0080; // PCM Surround DAC is available.
pub const MIXER_EI_LDAC: u16 = 0x0100; // PCM LFE DAC is available.
// Basic capabilities for MIXER_RESET_00
pub const BC_DEDICATED_MIC: u16 = 0x0001; /* Dedicated Mic PCM In Channel */
// Bus Master regs from ICH spec:
// 00h PI_BDBAR PCM In Buffer Descriptor list Base Address Register
@@ -93,14 +93,14 @@ pub const GLOB_CNT_WARM_RESET: u32 = 0x0000_0004;
pub const GLOB_CNT_STABLE_BITS: u32 = 0x0000_007f; // Bits not affected by reset.
// PCM 4/6 Enable bits
pub const GLOB_CNT_PCM_2: u32 = 0x0000_0000; // 2 channels
pub const GLOB_CNT_PCM_4: u32 = 0x0010_0000; // 4 channels
pub const GLOB_CNT_PCM_6: u32 = 0x0020_0000; // 6 channels
pub const GLOB_CNT_PCM_246_MASK: u32 = GLOB_CNT_PCM_4 | GLOB_CNT_PCM_6; // channel mask
// Global status
pub const GLOB_STA_30: u64 = 0x30;
// Primary codec ready set and turn on D20:21 to support 4 and 6 channels on PCM out.
pub const GLOB_STA_RESET_VAL: u32 = 0x0030_0100;
// glob_sta bits
diff --git a/devices/src/pci/msix.rs b/devices/src/pci/msix.rs
index 7240793c3..2852aefd6 100644
--- a/devices/src/pci/msix.rs
+++ b/devices/src/pci/msix.rs
@@ -3,11 +3,11 @@
// found in the LICENSE file.
use crate::pci::{PciCapability, PciCapabilityID};
-use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor};
-use msg_socket::{MsgError, MsgReceiver, MsgSender};
+use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor, Tube, TubeError};
+
use std::convert::TryInto;
use std::fmt::{self, Display};
-use vm_control::{MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse};
+use vm_control::{VmIrqRequest, VmIrqResponse};
use data_model::DataInit;
@@ -55,17 +55,17 @@ pub struct MsixConfig {
irq_vec: Vec<IrqfdGsi>,
masked: bool,
enabled: bool,
- msi_device_socket: VmIrqRequestSocket,
+ msi_device_socket: Tube,
msix_num: u16,
}
enum MsixError {
AddMsiRoute(SysError),
- AddMsiRouteRecv(MsgError),
- AddMsiRouteSend(MsgError),
+ AddMsiRouteRecv(TubeError),
+ AddMsiRouteSend(TubeError),
AllocateOneMsi(SysError),
- AllocateOneMsiRecv(MsgError),
- AllocateOneMsiSend(MsgError),
+ AllocateOneMsiRecv(TubeError),
+ AllocateOneMsiSend(TubeError),
}
impl Display for MsixError {
@@ -94,7 +94,7 @@ pub enum MsixStatus {
}
impl MsixConfig {
- pub fn new(msix_vectors: u16, vm_socket: VmIrqRequestSocket) -> Self {
+ pub fn new(msix_vectors: u16, vm_socket: Tube) -> Self {
assert!(msix_vectors <= MAX_MSIX_VECTORS_PER_DEVICE);
let mut table_entries: Vec<MsixTableEntry> = Vec::new();
@@ -235,10 +235,9 @@ impl MsixConfig {
self.irq_vec.clear();
for i in 0..self.msix_num {
let irqfd = Event::new().unwrap();
+ let request = VmIrqRequest::AllocateOneMsi { irqfd };
self.msi_device_socket
- .send(&VmIrqRequest::AllocateOneMsi {
- irqfd: MaybeOwnedDescriptor::Borrowed(irqfd.as_raw_descriptor()),
- })
+ .send(&request)
.map_err(MsixError::AllocateOneMsiSend)?;
let irq_num: u32;
match self
@@ -251,7 +250,10 @@ impl MsixConfig {
_ => unreachable!(),
}
self.irq_vec.push(IrqfdGsi {
- irqfd,
+ irqfd: match request {
+ VmIrqRequest::AllocateOneMsi { irqfd } => irqfd,
+ _ => unreachable!(),
+ },
gsi: irq_num,
});
@@ -498,7 +500,7 @@ impl MsixConfig {
/// Return the raw fd of the MSI device socket
pub fn get_msi_socket(&self) -> RawDescriptor {
- self.msi_device_socket.as_ref().as_raw_descriptor()
+ self.msi_device_socket.as_raw_descriptor()
}
/// Return irqfd of MSI-X Table entry
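Both the IO-APIC and MSI-X paths now move the Event itself into VmIrqRequest::AllocateOneMsi, send a borrow of the request over the Tube, and then take the Event back out of the request value. The ownership round-trip in isolation looks like this (toy types, not the real VmIrqRequest):

// Sketch: move a resource into a request enum, send &request, then recover it.
struct Event(u32); // stand-in for base::Event

enum Request {
    AllocateOneMsi { irqfd: Event },
}

fn main() {
    let request = Request::AllocateOneMsi { irqfd: Event(7) };
    // ... &request would be serialized and sent over the Tube here ...
    let irqfd = match request {
        Request::AllocateOneMsi { irqfd } => irqfd, // ownership returns to the caller
    };
    assert_eq!(irqfd.0, 7);
}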
diff --git a/devices/src/pci/pci_configuration.rs b/devices/src/pci/pci_configuration.rs
index 0ae8f585d..8654b0171 100644
--- a/devices/src/pci/pci_configuration.rs
+++ b/devices/src/pci/pci_configuration.rs
@@ -24,7 +24,7 @@ const BAR_MEM_MIN_SIZE: u64 = 16;
const NUM_BAR_REGS: usize = 6;
const CAPABILITY_LIST_HEAD_OFFSET: usize = 0x34;
const FIRST_CAPABILITY_OFFSET: usize = 0x40;
-const CAPABILITY_MAX_OFFSET: usize = 192;
+const CAPABILITY_MAX_OFFSET: usize = 255;
const INTERRUPT_LINE_PIN_REG: usize = 15;
@@ -292,6 +292,7 @@ impl PciConfiguration {
header_type: PciHeaderType,
subsystem_vendor_id: u16,
subsystem_id: u16,
+ revision_id: u8,
) -> Self {
let mut registers = [0u32; NUM_CONFIGURATION_REGISTERS];
let mut writable_bits = [0u32; NUM_CONFIGURATION_REGISTERS];
@@ -305,7 +306,8 @@ impl PciConfiguration {
};
registers[2] = u32::from(class_code.get_register_value()) << 24
| u32::from(subclass.get_register_value()) << 16
- | u32::from(pi) << 8;
+ | u32::from(pi) << 8
+ | u32::from(revision_id);
writable_bits[3] = 0x0000_00ff; // Cacheline size (r/w)
match header_type {
PciHeaderType::Device => {
@@ -471,11 +473,17 @@ impl PciConfiguration {
}
let (mask, lower_bits) = match config.region_type {
- PciBarRegionType::Memory32BitRegion | PciBarRegionType::Memory64BitRegion => (
- BAR_MEM_ADDR_MASK,
- config.prefetchable as u32 | config.region_type as u32,
- ),
- PciBarRegionType::IORegion => (BAR_IO_ADDR_MASK, config.region_type as u32),
+ PciBarRegionType::Memory32BitRegion | PciBarRegionType::Memory64BitRegion => {
+ self.registers[COMMAND_REG] |= COMMAND_REG_MEMORY_SPACE_MASK;
+ (
+ BAR_MEM_ADDR_MASK,
+ config.prefetchable as u32 | config.region_type as u32,
+ )
+ }
+ PciBarRegionType::IORegion => {
+ self.registers[COMMAND_REG] |= COMMAND_REG_IO_SPACE_MASK;
+ (BAR_IO_ADDR_MASK, config.region_type as u32)
+ }
};
self.registers[bar_idx] = ((config.addr as u32) & mask) | lower_bits;
@@ -673,6 +681,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
// Add two capabilities with different contents.
@@ -734,6 +743,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
let class_reg = cfg.read_reg(2);
@@ -756,6 +766,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
// Attempt to overwrite vendor ID and device ID, which are read-only
@@ -775,6 +786,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
// No BAR 0 has been configured, so these should return None or 0 as appropriate.
@@ -796,6 +808,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
cfg.add_pci_bar(
@@ -842,6 +855,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
cfg.add_pci_bar(
@@ -887,6 +901,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
cfg.add_pci_bar(
@@ -929,6 +944,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
// bar_num 0-1: 64-bit memory
@@ -1050,6 +1066,7 @@ mod tests {
PciHeaderType::Device,
0xABCD,
0x2468,
+ 0,
);
// I/O BAR with size 2 (too small)
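The new revision_id argument lands in the low byte of configuration register 2, below the class code, subclass, and programming interface. A quick check of that packing (the field values are arbitrary examples):

// Sketch: layout of PCI config register 2 with the revision ID in the low byte.
fn main() {
    let (class_code, subclass, pi, revision_id) = (0x04u32, 0x01u32, 0x00u32, 0x02u32);
    let reg2 = class_code << 24 | subclass << 16 | pi << 8 | revision_id;
    assert_eq!(reg2, 0x0401_0002);
}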
diff --git a/devices/src/pci/pci_device.rs b/devices/src/pci/pci_device.rs
index 25e37cdc3..e2a1b67e2 100644
--- a/devices/src/pci/pci_device.rs
+++ b/devices/src/pci/pci_device.rs
@@ -282,6 +282,7 @@ mod tests {
PciHeaderType::Device,
0x5678,
0xEF01,
+ 0,
),
};
diff --git a/devices/src/pci/pci_root.rs b/devices/src/pci/pci_root.rs
index d2f42cb10..4eb1b3407 100644
--- a/devices/src/pci/pci_root.rs
+++ b/devices/src/pci/pci_root.rs
@@ -145,6 +145,7 @@ impl PciRoot {
PciHeaderType::Device,
0,
0,
+ 0,
),
},
devices: BTreeMap::new(),
diff --git a/devices/src/pci/vfio_pci.rs b/devices/src/pci/vfio_pci.rs
index ae043d1bd..46ee5372b 100644
--- a/devices/src/pci/vfio_pci.rs
+++ b/devices/src/pci/vfio_pci.rs
@@ -6,17 +6,15 @@ use std::sync::Arc;
use std::u32;
use base::{
- error, AsRawDescriptor, Event, MappedRegion, MemoryMapping, MemoryMappingBuilder, RawDescriptor,
+ error, pagesize, AsRawDescriptor, Event, MappedRegion, MemoryMapping, MemoryMappingBuilder,
+ RawDescriptor, Tube,
};
use hypervisor::Datamatch;
-use msg_socket::{MsgReceiver, MsgSender};
+
use resources::{Alloc, MmioType, SystemAllocator};
use vfio_sys::*;
-use vm_control::{
- MaybeOwnedDescriptor, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse,
- VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse,
-};
+use vm_control::{VmIrqRequest, VmIrqResponse, VmMemoryRequest, VmMemoryResponse};
use crate::pci::msix::{
MsixConfig, BITS_PER_PBA_ENTRY, MSIX_PBA_ENTRIES_MODULO, MSIX_TABLE_ENTRIES_MODULO,
@@ -128,13 +126,13 @@ struct VfioMsiCap {
ctl: u16,
address: u64,
data: u16,
- vm_socket_irq: VmIrqRequestSocket,
+ vm_socket_irq: Tube,
irqfd: Option<Event>,
gsi: Option<u32>,
}
impl VfioMsiCap {
- fn new(config: &VfioPciConfig, msi_cap_start: u32, vm_socket_irq: VmIrqRequestSocket) -> Self {
+ fn new(config: &VfioPciConfig, msi_cap_start: u32, vm_socket_irq: Tube) -> Self {
let msi_ctl = config.read_config_word(msi_cap_start + PCI_MSI_FLAGS);
VfioMsiCap {
@@ -253,19 +251,27 @@ impl VfioMsiCap {
}
fn allocate_one_msi(&mut self) {
- if self.irqfd.is_none() {
- match Event::new() {
- Ok(fd) => self.irqfd = Some(fd),
+ let irqfd = match self.irqfd.take() {
+ Some(e) => e,
+ None => match Event::new() {
+ Ok(e) => e,
Err(e) => {
error!("failed to create event: {:?}", e);
return;
}
- };
- }
+ },
+ };
- if let Err(e) = self.vm_socket_irq.send(&VmIrqRequest::AllocateOneMsi {
- irqfd: MaybeOwnedDescriptor::Borrowed(self.irqfd.as_ref().unwrap().as_raw_descriptor()),
- }) {
+ let request = VmIrqRequest::AllocateOneMsi { irqfd };
+ let request_result = self.vm_socket_irq.send(&request);
+
+ // Stash the irqfd in self immediately because we used take above.
+ self.irqfd = match request {
+ VmIrqRequest::AllocateOneMsi { irqfd } => Some(irqfd),
+ _ => unreachable!(),
+ };
+
+ if let Err(e) = request_result {
error!("failed to send AllocateOneMsi request: {:?}", e);
return;
}
@@ -310,7 +316,7 @@ struct VfioMsixCap {
}
impl VfioMsixCap {
- fn new(config: &VfioPciConfig, msix_cap_start: u32, vm_socket_irq: VmIrqRequestSocket) -> Self {
+ fn new(config: &VfioPciConfig, msix_cap_start: u32, vm_socket_irq: Tube) -> Self {
let msix_ctl = config.read_config_word(msix_cap_start + PCI_MSIX_FLAGS);
let table_size = (msix_ctl & PCI_MSIX_FLAGS_QSIZE) + 1;
let table = config.read_config_dword(msix_cap_start + PCI_MSIX_TABLE);
@@ -441,7 +447,7 @@ pub struct VfioPciDevice {
msi_cap: Option<VfioMsiCap>,
msix_cap: Option<VfioMsixCap>,
irq_type: Option<VfioIrqType>,
- vm_socket_mem: VmMemoryControlRequestSocket,
+ vm_socket_mem: Tube,
device_data: Option<DeviceData>,
// scratch MemoryMapping to avoid unmapping before vm exit
@@ -452,9 +458,9 @@ impl VfioPciDevice {
/// Constructs a new Vfio Pci device for the given Vfio device
pub fn new(
device: VfioDevice,
- vfio_device_socket_msi: VmIrqRequestSocket,
- vfio_device_socket_msix: VmIrqRequestSocket,
- vfio_device_socket_mem: VmMemoryControlRequestSocket,
+ vfio_device_socket_msi: Tube,
+ vfio_device_socket_msix: Tube,
+ vfio_device_socket_mem: Tube,
) -> Self {
let dev = Arc::new(device);
let config = VfioPciConfig::new(Arc::clone(&dev));
@@ -543,22 +549,25 @@ impl VfioPciDevice {
if let Some(ref interrupt_evt) = self.interrupt_evt {
let mut fds = Vec::new();
fds.push(interrupt_evt);
- if let Err(e) = self.device.irq_enable(fds, VfioIrqType::Intx) {
+ if let Err(e) = self.device.irq_enable(fds, VFIO_PCI_INTX_IRQ_INDEX) {
error!("Intx enable failed: {}", e);
return;
}
if let Some(ref irq_resample_evt) = self.interrupt_resample_evt {
- if let Err(e) = self.device.irq_mask(VfioIrqType::Intx) {
+ if let Err(e) = self.device.irq_mask(VFIO_PCI_INTX_IRQ_INDEX) {
error!("Intx mask failed: {}", e);
self.disable_intx();
return;
}
- if let Err(e) = self.device.resample_virq_enable(irq_resample_evt) {
+ if let Err(e) = self
+ .device
+ .resample_virq_enable(irq_resample_evt, VFIO_PCI_INTX_IRQ_INDEX)
+ {
error!("resample enable failed: {}", e);
self.disable_intx();
return;
}
- if let Err(e) = self.device.irq_unmask(VfioIrqType::Intx) {
+ if let Err(e) = self.device.irq_unmask(VFIO_PCI_INTX_IRQ_INDEX) {
error!("Intx unmask failed: {}", e);
self.disable_intx();
return;
@@ -570,7 +579,7 @@ impl VfioPciDevice {
}
fn disable_intx(&mut self) {
- if let Err(e) = self.device.irq_disable(VfioIrqType::Intx) {
+ if let Err(e) = self.device.irq_disable(VFIO_PCI_INTX_IRQ_INDEX) {
error!("Intx disable failed: {}", e);
}
self.irq_type = None;
@@ -610,7 +619,7 @@ impl VfioPciDevice {
let mut fds = Vec::new();
fds.push(irqfd);
- if let Err(e) = self.device.irq_enable(fds, VfioIrqType::Msi) {
+ if let Err(e) = self.device.irq_enable(fds, VFIO_PCI_MSI_IRQ_INDEX) {
error!("failed to enable msi: {}", e);
self.enable_intx();
return;
@@ -620,7 +629,7 @@ impl VfioPciDevice {
}
fn disable_msi(&mut self) {
- if let Err(e) = self.device.irq_disable(VfioIrqType::Msi) {
+ if let Err(e) = self.device.irq_disable(VFIO_PCI_MSI_IRQ_INDEX) {
error!("failed to disable msi: {}", e);
return;
}
@@ -637,7 +646,7 @@ impl VfioPciDevice {
};
if let Some(descriptors) = irqfds {
- if let Err(e) = self.device.irq_enable(descriptors, VfioIrqType::Msix) {
+ if let Err(e) = self.device.irq_enable(descriptors, VFIO_PCI_MSIX_IRQ_INDEX) {
error!("failed to enable msix: {}", e);
self.enable_intx();
return;
@@ -651,7 +660,7 @@ impl VfioPciDevice {
}
fn disable_msix(&mut self) {
- if let Err(e) = self.device.irq_disable(VfioIrqType::Msix) {
+ if let Err(e) = self.device.irq_disable(VFIO_PCI_MSIX_IRQ_INDEX) {
error!("failed to disable msix: {}", e);
return;
}
@@ -681,10 +690,14 @@ impl VfioPciDevice {
let guest_map_start = bar_addr + mmap_offset;
let region_offset = self.device.get_region_offset(index);
let offset = region_offset + mmap_offset;
+ let descriptor = match self.device.device_file().try_clone() {
+ Ok(device_file) => device_file.into(),
+ Err(_) => break,
+ };
if self
.vm_socket_mem
.send(&VmMemoryRequest::RegisterMmapMemory {
- descriptor: MaybeOwnedDescriptor::Borrowed(self.device.as_raw_descriptor()),
+ descriptor,
size: mmap_size as usize,
offset,
gpa: guest_map_start,
@@ -694,7 +707,7 @@ impl VfioPciDevice {
break;
}
- let response = match self.vm_socket_mem.recv() {
+ let response: VmMemoryResponse = match self.vm_socket_mem.recv() {
Ok(res) => res,
Err(_) => break,
};
@@ -704,7 +717,7 @@ impl VfioPciDevice {
// device process doesn't have this mapping, but vfio_dma_map() needs it
// in the device process, so map it again here.
let mmap = match MemoryMappingBuilder::new(mmap_size as usize)
- .from_descriptor(self.device.as_ref())
+ .from_file(self.device.device_file())
.offset(offset)
.build()
{
@@ -712,10 +725,13 @@ impl VfioPciDevice {
Err(_e) => break,
};
let host = (&mmap).as_ptr() as u64;
+ let pgsz = pagesize() as u64;
+ let size = (mmap_size + pgsz - 1) / pgsz * pgsz;
// Safe because the given guest_map_start is a valid guest BAR address, and
// the host pointer is correct and valid, as guaranteed by the MemoryMapping interface.
- match unsafe { self.device.vfio_dma_map(guest_map_start, mmap_size, host) }
- {
+ // The size will be extended to a page-size-aligned value if it is not already,
+ // which is also safe because VFIO actually maps the BAR with a page-aligned size.
+ match unsafe { self.device.vfio_dma_map(guest_map_start, size, host) } {
Ok(_) => mem_map.push(mmap),
Err(e) => {
error!(
@@ -954,7 +970,7 @@ impl PciDevice for VfioPciDevice {
let mut config = self.config.read_config_dword(reg);
// Ignore IO bar
- if reg >= 0x10 && reg <= 0x24 {
+ if (0x10..=0x24).contains(&reg) {
for io_info in self.io_regions.iter() {
if io_info.bar_index * 4 + 0x10 == reg {
config = 0;
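The DMA-map path above rounds mmap_size up to a whole number of host pages before calling vfio_dma_map(), matching how VFIO maps the BAR. The rounding expression in isolation (a 4 KiB page size is assumed for the example):

// Sketch: round a mapping size up to the next page boundary.
fn main() {
    let pgsz: u64 = 4096; // pagesize() on most hosts; example value only
    let mmap_size: u64 = 0x1234;
    let size = (mmap_size + pgsz - 1) / pgsz * pgsz;
    assert_eq!(size, 0x2000);
}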
diff --git a/devices/src/proxy.rs b/devices/src/proxy.rs
index e044feb9b..5e22786d6 100644
--- a/devices/src/proxy.rs
+++ b/devices/src/proxy.rs
@@ -4,14 +4,14 @@
//! Runs hardware devices in child processes.
+use std::ffi::CString;
use std::fmt::{self, Display};
use std::time::Duration;
-use std::{self, io};
-use base::{error, net::UnixSeqpacket, AsRawDescriptor, RawDescriptor};
+use base::{error, AsRawDescriptor, RawDescriptor, Tube, TubeError};
use libc::{self, pid_t};
use minijail::{self, Minijail};
-use msg_socket::{MsgOnSocket, MsgReceiver, MsgSender, MsgSocket};
+use serde::{Deserialize, Serialize};
use crate::bus::ConfigWriteResult;
use crate::{BusAccessInfo, BusDevice};
@@ -20,7 +20,7 @@ use crate::{BusAccessInfo, BusDevice};
#[derive(Debug)]
pub enum Error {
ForkingJail(minijail::Error),
- Io(io::Error),
+ Tube(TubeError),
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -30,14 +30,14 @@ impl Display for Error {
match self {
ForkingJail(e) => write!(f, "Failed to fork jail process: {}", e),
- Io(e) => write!(f, "IO error configuring proxy device {}.", e),
+ Tube(e) => write!(f, "Failed to configure tube: {}.", e),
}
}
}
const SOCKET_TIMEOUT_MS: u64 = 2000;
-#[derive(Debug, MsgOnSocket)]
+#[derive(Debug, Serialize, Deserialize)]
enum Command {
Read {
len: u32,
@@ -57,8 +57,7 @@ enum Command {
},
Shutdown,
}
-
-#[derive(MsgOnSocket)]
+#[derive(Debug, Serialize, Deserialize)]
enum CommandResult {
Ok,
ReadResult([u8; 8]),
@@ -69,12 +68,11 @@ enum CommandResult {
},
}
-fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) {
+fn child_proc<D: BusDevice>(tube: Tube, device: &mut D) {
let mut running = true;
- let sock = MsgSocket::<CommandResult, Command>::new(sock);
while running {
- let cmd = match sock.recv() {
+ let cmd = match tube.recv() {
Ok(cmd) => cmd,
Err(err) => {
error!("child device process failed recv: {}", err);
@@ -86,7 +84,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) {
Command::Read { len, info } => {
let mut buffer = [0u8; 8];
device.read(info, &mut buffer[0..len as usize]);
- sock.send(&CommandResult::ReadResult(buffer))
+ tube.send(&CommandResult::ReadResult(buffer))
}
Command::Write { len, info, data } => {
let len = len as usize;
@@ -96,7 +94,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) {
}
Command::ReadConfig(idx) => {
let val = device.config_register_read(idx as usize);
- sock.send(&CommandResult::ReadConfigResult(val))
+ tube.send(&CommandResult::ReadConfigResult(val))
}
Command::WriteConfig {
reg_idx,
@@ -107,14 +105,14 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) {
let len = len as usize;
let res =
device.config_register_write(reg_idx as usize, offset as u64, &data[0..len]);
- sock.send(&CommandResult::WriteConfigResult {
+ tube.send(&CommandResult::WriteConfigResult {
mem_bus_new_state: res.mem_bus_new_state,
io_bus_new_state: res.io_bus_new_state,
})
}
Command::Shutdown => {
running = false;
- sock.send(&CommandResult::Ok)
+ tube.send(&CommandResult::Ok)
}
};
if let Err(e) = res {
@@ -128,7 +126,7 @@ fn child_proc<D: BusDevice>(sock: UnixSeqpacket, device: &mut D) {
/// Because forks are very unfriendly to destructors and all memory mappings and file descriptors
/// are inherited, this should be used as early as possible in the main process.
pub struct ProxyDevice {
- sock: MsgSocket<Command, CommandResult>,
+ tube: Tube,
pid: pid_t,
debug_label: String,
}
@@ -149,15 +147,23 @@ impl ProxyDevice {
mut keep_rds: Vec<RawDescriptor>,
) -> Result<ProxyDevice> {
let debug_label = device.debug_label();
- let (child_sock, parent_sock) = UnixSeqpacket::pair().map_err(Error::Io)?;
+ let (child_tube, parent_tube) = Tube::pair().map_err(Error::Tube)?;
- keep_rds.push(child_sock.as_raw_descriptor());
+ keep_rds.push(child_tube.as_raw_descriptor());
// Forking here is safe as long as the program is still single threaded.
let pid = unsafe {
match jail.fork(Some(&keep_rds)).map_err(Error::ForkingJail)? {
0 => {
+ let max_len = 15; // pthread_setname_np() limit on Linux
+ let debug_label_trimmed =
+ &debug_label.as_bytes()[..std::cmp::min(max_len, debug_label.len())];
+ let thread_name = CString::new(debug_label_trimmed).unwrap();
+ // TODO(crbug.com/1199487): remove this once libc provides the wrapper for all
+ // targets
+ #[cfg(all(target_os = "linux", target_env = "gnu"))]
+ let _ = libc::pthread_setname_np(libc::pthread_self(), thread_name.as_ptr());
device.on_sandboxed();
- child_proc(child_sock, &mut device);
+ child_proc(child_tube, &mut device);
// We're explicitly not using std::process::exit here to avoid the cleanup of
// stdout/stderr globals. This can cause cascading panics and SIGILL if a worker
@@ -173,14 +179,14 @@ impl ProxyDevice {
}
};
- parent_sock
- .set_write_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
- .map_err(Error::Io)?;
- parent_sock
- .set_read_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
- .map_err(Error::Io)?;
+ parent_tube
+ .set_send_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
+ .map_err(Error::Tube)?;
+ parent_tube
+ .set_recv_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
+ .map_err(Error::Tube)?;
Ok(ProxyDevice {
- sock: MsgSocket::<Command, CommandResult>::new(parent_sock),
+ tube: parent_tube,
pid,
debug_label,
})
@@ -192,7 +198,7 @@ impl ProxyDevice {
/// Send a command that does not expect a response from the child device process.
fn send_no_result(&self, cmd: &Command) {
- let res = self.sock.send(cmd);
+ let res = self.tube.send(cmd);
if let Err(e) = res {
error!(
"failed write to child device process {}: {}",
@@ -204,7 +210,7 @@ impl ProxyDevice {
/// Send a command and read its response from the child device process.
fn sync_send(&self, cmd: &Command) -> Option<CommandResult> {
self.send_no_result(cmd);
- match self.sock.recv() {
+ match self.tube.recv() {
Err(e) => {
error!(
"failed to read result of {:?} from child device process {}: {}",
diff --git a/devices/src/usb/host_backend/error.rs b/devices/src/usb/host_backend/error.rs
index 3f1fed4e1..8d089233a 100644
--- a/devices/src/usb/host_backend/error.rs
+++ b/devices/src/usb/host_backend/error.rs
@@ -5,7 +5,8 @@
use crate::usb::xhci::scatter_gather_buffer::Error as BufferError;
use crate::usb::xhci::xhci_transfer::Error as XhciTransferError;
use crate::utils::Error as UtilsError;
-use msg_socket::MsgError;
+
+use base::TubeError;
use std::fmt::{self, Display};
use usb_util::Error as UsbUtilError;
@@ -20,11 +21,12 @@ pub enum Error {
SetInterfaceAltSetting(UsbUtilError),
ClearHalt(UsbUtilError),
CreateTransfer(UsbUtilError),
+ Reset(UsbUtilError),
GetEndpointType,
- CreateControlSock(std::io::Error),
- SetupControlSock(std::io::Error),
- ReadControlSock(MsgError),
- WriteControlSock(MsgError),
+ CreateControlTube(TubeError),
+ SetupControlTube(TubeError),
+ ReadControlTube(TubeError),
+ WriteControlTube(TubeError),
GetXhciTransferType(XhciTransferError),
TransferComplete(XhciTransferError),
ReadBuffer(BufferError),
@@ -51,11 +53,12 @@ impl Display for Error {
SetInterfaceAltSetting(e) => write!(f, "failed to set interface alt setting: {:?}", e),
ClearHalt(e) => write!(f, "failed to clear halt: {:?}", e),
CreateTransfer(e) => write!(f, "failed to create transfer: {:?}", e),
+ Reset(e) => write!(f, "failed to reset: {:?}", e),
GetEndpointType => write!(f, "failed to get endpoint type"),
- CreateControlSock(e) => write!(f, "failed to create contro sock: {}", e),
- SetupControlSock(e) => write!(f, "failed to setup control sock: {}", e),
- ReadControlSock(e) => write!(f, "failed to read control sock: {}", e),
- WriteControlSock(e) => write!(f, "failed to write control sock: {}", e),
+ CreateControlTube(e) => write!(f, "failed to create control tube: {}", e),
+ SetupControlTube(e) => write!(f, "failed to setup control tube: {}", e),
+ ReadControlTube(e) => write!(f, "failed to read control tube: {}", e),
+ WriteControlTube(e) => write!(f, "failed to write control tube: {}", e),
GetXhciTransferType(e) => write!(f, "failed to get xhci transfer type: {}", e),
TransferComplete(e) => write!(f, "xhci transfer completed: {}", e),
ReadBuffer(e) => write!(f, "failed to read buffer: {}", e),
diff --git a/devices/src/usb/host_backend/host_backend_device_provider.rs b/devices/src/usb/host_backend/host_backend_device_provider.rs
index c33092a15..787bb8fe0 100644
--- a/devices/src/usb/host_backend/host_backend_device_provider.rs
+++ b/devices/src/usb/host_backend/host_backend_device_provider.rs
@@ -11,19 +11,14 @@ use crate::usb::xhci::usb_hub::UsbHub;
use crate::usb::xhci::xhci_backend_device_provider::XhciBackendDeviceProvider;
use crate::utils::AsyncJobQueue;
use crate::utils::{EventHandler, EventLoop, FailHandle};
-use base::net::UnixSeqpacket;
-use base::{
- error, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, WatchingEvents,
-};
-use msg_socket::{MsgReceiver, MsgSender, MsgSocket};
+use base::{error, AsRawDescriptor, Descriptor, RawDescriptor, Tube, WatchingEvents};
use std::collections::HashMap;
use std::mem;
use std::time::Duration;
use sync::Mutex;
use usb_util::Device;
use vm_control::{
- MaybeOwnedDescriptor, UsbControlAttachedDevice, UsbControlCommand, UsbControlResult,
- UsbControlSocket, USB_CONTROL_MAX_PORTS,
+ UsbControlAttachedDevice, UsbControlCommand, UsbControlResult, USB_CONTROL_MAX_PORTS,
};
const SOCKET_TIMEOUT_MS: u64 = 2000;
@@ -32,31 +27,27 @@ const SOCKET_TIMEOUT_MS: u64 = 2000;
/// devices.
pub enum HostBackendDeviceProvider {
// The provider is created but not yet started.
- Created {
- sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>,
- },
+ Created { control_tube: Mutex<Tube> },
// The provider is started on an event loop.
- Started {
- inner: Arc<ProviderInner>,
- },
+ Started { inner: Arc<ProviderInner> },
// The provider has failed.
Failed,
}
impl HostBackendDeviceProvider {
- pub fn new() -> Result<(UsbControlSocket, HostBackendDeviceProvider)> {
- let (child_sock, control_sock) = UnixSeqpacket::pair().map_err(Error::CreateControlSock)?;
- control_sock
- .set_write_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
- .map_err(Error::SetupControlSock)?;
- control_sock
- .set_read_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
- .map_err(Error::SetupControlSock)?;
+ pub fn new() -> Result<(Tube, HostBackendDeviceProvider)> {
+ let (child_tube, control_tube) = Tube::pair().map_err(Error::CreateControlTube)?;
+ control_tube
+ .set_send_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
+ .map_err(Error::SetupControlTube)?;
+ control_tube
+ .set_recv_timeout(Some(Duration::from_millis(SOCKET_TIMEOUT_MS)))
+ .map_err(Error::SetupControlTube)?;
let provider = HostBackendDeviceProvider::Created {
- sock: Mutex::new(MsgSocket::new(child_sock)),
+ control_tube: Mutex::new(child_tube),
};
- Ok((MsgSocket::new(control_sock), provider))
+ Ok((control_tube, provider))
}
fn start_helper(
@@ -66,20 +57,20 @@ impl HostBackendDeviceProvider {
hub: Arc<UsbHub>,
) -> Result<()> {
match mem::replace(self, HostBackendDeviceProvider::Failed) {
- HostBackendDeviceProvider::Created { sock } => {
+ HostBackendDeviceProvider::Created { control_tube } => {
let job_queue =
AsyncJobQueue::init(&event_loop).map_err(Error::StartAsyncJobQueue)?;
let inner = Arc::new(ProviderInner::new(
fail_handle,
job_queue,
event_loop.clone(),
- sock,
+ control_tube,
hub,
));
let handler: Arc<dyn EventHandler> = inner.clone();
event_loop
.add_event(
- &*inner.sock.lock(),
+ &*inner.control_tube.lock(),
WatchingEvents::empty().set_read(),
Arc::downgrade(&handler),
)
@@ -105,16 +96,15 @@ impl XhciBackendDeviceProvider for HostBackendDeviceProvider {
fail_handle: Arc<dyn FailHandle>,
event_loop: Arc<EventLoop>,
hub: Arc<UsbHub>,
- ) -> std::result::Result<(), ()> {
+ ) -> Result<()> {
self.start_helper(fail_handle, event_loop, hub)
- .map_err(|e| {
- error!("failed to start host backend device provider: {}", e);
- })
}
fn keep_rds(&self) -> Vec<RawDescriptor> {
match self {
- HostBackendDeviceProvider::Created { sock } => vec![sock.lock().as_raw_descriptor()],
+ HostBackendDeviceProvider::Created { control_tube } => {
+ vec![control_tube.lock().as_raw_descriptor()]
+ }
_ => {
error!(
"Trying to get keepfds when HostBackendDeviceProvider is not in created state"
@@ -130,7 +120,7 @@ pub struct ProviderInner {
fail_handle: Arc<dyn FailHandle>,
job_queue: Arc<AsyncJobQueue>,
event_loop: Arc<EventLoop>,
- sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>,
+ control_tube: Mutex<Tube>,
usb_hub: Arc<UsbHub>,
// Map of USB hub port number to per-device context.
@@ -147,14 +137,14 @@ impl ProviderInner {
fail_handle: Arc<dyn FailHandle>,
job_queue: Arc<AsyncJobQueue>,
event_loop: Arc<EventLoop>,
- sock: Mutex<MsgSocket<UsbControlResult, UsbControlCommand>>,
+ control_tube: Mutex<Tube>,
usb_hub: Arc<UsbHub>,
) -> ProviderInner {
ProviderInner {
fail_handle,
job_queue,
event_loop,
- sock,
+ control_tube,
usb_hub,
devices: Mutex::new(HashMap::new()),
}
@@ -162,18 +152,8 @@ impl ProviderInner {
/// Open a usbdevfs file to create a host USB device object.
/// `fd` should be an open file descriptor for a file in `/dev/bus/usb`.
- fn handle_attach_device(&self, fd: Option<MaybeOwnedDescriptor>) -> UsbControlResult {
- let usb_file = match fd {
- Some(MaybeOwnedDescriptor::Owned(file)) => file,
- _ => {
- error!("missing fd in UsbControlCommand::AttachDevice message");
- return UsbControlResult::FailedToOpenDevice;
- }
- };
-
- let raw_descriptor = usb_file.into_raw_descriptor();
- // Safe as it is valid to have multiple variables accessing the same fd.
- let device = match Device::new(unsafe { File::from_raw_descriptor(raw_descriptor) }) {
+ fn handle_attach_device(&self, usb_file: File) -> UsbControlResult {
+ let device = match Device::new(usb_file) {
Ok(d) => d,
Err(e) => {
error!("could not construct USB device from fd: {}", e);
@@ -181,6 +161,8 @@ impl ProviderInner {
}
};
+ let device_descriptor = Descriptor(device.as_raw_descriptor());
+
let arc_mutex_device = Arc::new(Mutex::new(device));
let event_handler: Arc<dyn EventHandler> = Arc::new(UsbUtilEventHandler {
@@ -188,7 +170,7 @@ impl ProviderInner {
});
if let Err(e) = self.event_loop.add_event(
- &MaybeOwnedDescriptor::Borrowed(raw_descriptor),
+ &device_descriptor,
WatchingEvents::empty().set_read().set_write(),
Arc::downgrade(&event_handler),
) {
@@ -233,12 +215,7 @@ impl ProviderInner {
let device = device_ctx.device.lock();
let fd = device.fd();
- if let Err(e) =
- self.event_loop
- .remove_event_for_fd(&MaybeOwnedDescriptor::Borrowed(
- fd.as_raw_descriptor(),
- ))
- {
+ if let Err(e) = self.event_loop.remove_event_for_fd(&*fd) {
error!(
"failed to remove poll change handler from event loop: {}",
e
@@ -276,16 +253,14 @@ impl ProviderInner {
}
fn on_event_helper(&self) -> Result<()> {
- let sock = self.sock.lock();
- let cmd = sock.recv().map_err(Error::ReadControlSock)?;
+ let tube = self.control_tube.lock();
+ let cmd = tube.recv().map_err(Error::ReadControlTube)?;
let result = match cmd {
- UsbControlCommand::AttachDevice { descriptor, .. } => {
- self.handle_attach_device(descriptor)
- }
+ UsbControlCommand::AttachDevice { file, .. } => self.handle_attach_device(file),
UsbControlCommand::DetachDevice { port } => self.handle_detach_device(port),
UsbControlCommand::ListDevice { ports } => self.handle_list_devices(ports),
};
- sock.send(&result).map_err(Error::WriteControlSock)?;
+ tube.send(&result).map_err(Error::WriteControlTube)?;
Ok(())
}
}
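The provider keeps one end of the Tube and hands the other back from `new()`, so every control exchange is a synchronous send on one side matched by a recv on the other. A minimal sketch of that round trip, using only the Tube calls visible in this change (`pair`, the timeout setters, `send`, `recv`); the `Ping`/`Pong` message types are hypothetical stand-ins for the real command/result enums:

    use base::Tube;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    // Hypothetical message types; Tube payloads are assumed to be serde-serializable.
    #[derive(Serialize, Deserialize)]
    struct Ping;
    #[derive(Serialize, Deserialize)]
    struct Pong;

    fn tube_round_trip() {
        // One end stays with the device provider, the other goes to the controller.
        let (device_side, control_side) = Tube::pair().expect("failed to create tube pair");
        control_side
            .set_send_timeout(Some(Duration::from_millis(2000)))
            .expect("failed to set send timeout");
        control_side
            .set_recv_timeout(Some(Duration::from_millis(2000)))
            .expect("failed to set recv timeout");

        // The controller sends a command; the device end receives it and replies.
        control_side.send(&Ping).expect("send failed");
        let _cmd: Ping = device_side.recv().expect("recv failed");
        device_side.send(&Pong).expect("send failed");
        let _resp: Pong = control_side.recv().expect("recv failed");
    }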
diff --git a/devices/src/usb/host_backend/host_device.rs b/devices/src/usb/host_backend/host_device.rs
index b3ff14b78..6a49ee5ea 100644
--- a/devices/src/usb/host_backend/host_device.rs
+++ b/devices/src/usb/host_backend/host_device.rs
@@ -468,10 +468,8 @@ impl XhciBackendDevice for HostDevice {
}
}
- fn submit_transfer(&mut self, transfer: XhciTransfer) -> std::result::Result<(), ()> {
- self.submit_transfer_helper(transfer).map_err(|e| {
- error!("failed to submit transfer: {}", e);
- })
+ fn submit_transfer(&mut self, transfer: XhciTransfer) -> Result<()> {
+ self.submit_transfer_helper(transfer)
}
fn set_address(&mut self, _address: UsbDeviceAddress) {
@@ -483,11 +481,8 @@ impl XhciBackendDevice for HostDevice {
);
}
- fn reset(&mut self) -> std::result::Result<(), ()> {
+ fn reset(&mut self) -> Result<()> {
usb_debug!("resetting host device");
- self.device
- .lock()
- .reset()
- .map_err(|e| error!("failed to reset device: {:?}", e))
+ self.device.lock().reset().map_err(Error::Reset)
}
}
diff --git a/devices/src/usb/xhci/device_slot.rs b/devices/src/usb/xhci/device_slot.rs
index 5e12327d1..e24776339 100644
--- a/devices/src/usb/xhci/device_slot.rs
+++ b/devices/src/usb/xhci/device_slot.rs
@@ -12,6 +12,7 @@ use super::xhci_abi::{
};
use super::xhci_regs::{valid_slot_id, MAX_PORTS, MAX_SLOTS};
use crate::register_space::Register;
+use crate::usb::host_backend::error::Error as HostBackendProviderError;
use crate::usb::xhci::ring_buffer_stop_cb::{fallible_closure, RingBufferStopCallback};
use crate::utils::{EventLoop, FailHandle};
use base::error;
@@ -37,6 +38,7 @@ pub enum Error {
BadInputContextAddr(GuestAddress),
BadDeviceContextAddr(GuestAddress),
CreateTransferController(TransferRingControllerError),
+ ResetPort(HostBackendProviderError),
}
type Result<T> = std::result::Result<T, Error>;
@@ -58,6 +60,7 @@ impl Display for Error {
BadInputContextAddr(addr) => write!(f, "bad input context address: {}", addr),
BadDeviceContextAddr(addr) => write!(f, "bad device context: {}", addr),
CreateTransferController(e) => write!(f, "failed to create transfer controller: {}", e),
+ ResetPort(e) => write!(f, "failed to reset port: {}", e),
}
}
}
@@ -127,10 +130,10 @@ impl DeviceSlots {
}
/// Reset the device connected to a specific port.
- pub fn reset_port(&self, port_id: u8) -> std::result::Result<(), ()> {
+ pub fn reset_port(&self, port_id: u8) -> Result<()> {
if let Some(port) = self.hub.get_port(port_id) {
if let Some(backend_device) = port.get_backend_device().as_mut() {
- backend_device.reset()?;
+ backend_device.reset().map_err(Error::ResetPort)?;
}
}
diff --git a/devices/src/usb/xhci/xhci.rs b/devices/src/usb/xhci/xhci.rs
index e2035980f..85f6de8f0 100644
--- a/devices/src/usb/xhci/xhci.rs
+++ b/devices/src/usb/xhci/xhci.rs
@@ -10,7 +10,10 @@ use super::ring_buffer_stop_cb::RingBufferStopCallback;
use super::usb_hub::UsbHub;
use super::xhci_backend_device_provider::XhciBackendDeviceProvider;
use super::xhci_regs::*;
-use crate::usb::host_backend::host_backend_device_provider::HostBackendDeviceProvider;
+use crate::usb::host_backend::{
+ error::Error as HostBackendProviderError,
+ host_backend_device_provider::HostBackendDeviceProvider,
+};
use crate::utils::{Error as UtilsError, EventLoop, FailHandle};
use base::{error, Event};
use std::fmt::{self, Display};
@@ -29,7 +32,7 @@ pub enum Error {
SetModeration(InterrupterError),
SetupEventRing(InterrupterError),
SetEventHandlerBusy(InterrupterError),
- StartProvider,
+ StartProvider(HostBackendProviderError),
RingDoorbell(DeviceSlotError),
CreateCommandRingController(CommandRingControllerError),
ResetPort,
@@ -50,7 +53,7 @@ impl Display for Error {
SetModeration(e) => write!(f, "failed to set interrupter moderation: {}", e),
SetupEventRing(e) => write!(f, "failed to setup event ring: {}", e),
SetEventHandlerBusy(e) => write!(f, "failed to set event handler busy: {}", e),
- StartProvider => write!(f, "failed to start backend provider"),
+ StartProvider(e) => write!(f, "failed to start backend provider: {}", e),
RingDoorbell(e) => write!(f, "failed to ring doorbell: {}", e),
CreateCommandRingController(e) => {
write!(f, "failed to create command ring controller: {}", e)
@@ -101,7 +104,7 @@ impl Xhci {
let mut device_provider = device_provider;
device_provider
.start(fail_handle.clone(), event_loop.clone(), hub.clone())
- .map_err(|_| Error::StartProvider)?;
+ .map_err(Error::StartProvider)?;
let device_slots = DeviceSlots::new(
fail_handle.clone(),
@@ -148,8 +151,7 @@ impl Xhci {
let xhci_weak = Arc::downgrade(xhci);
xhci.regs.crcr.set_write_cb(move |val: u64| {
let xhci = xhci_weak.upgrade().unwrap();
- let r = xhci.crcr_callback(val);
- xhci.handle_register_callback_result(r, 0)
+ xhci.crcr_callback(val)
});
for i in 0..xhci.regs.portsc.len() {
@@ -227,7 +229,7 @@ impl Xhci {
fn usbcmd_callback(&self, value: u32) -> Result<u32> {
if (value & USB_CMD_RESET) > 0 {
usb_debug!("xhci_controller: reset controller");
- self.reset()?;
+ self.reset();
return Ok(value & (!USB_CMD_RESET));
}
@@ -236,7 +238,7 @@ impl Xhci {
self.regs.usbsts.clear_bits(USB_STS_HALTED);
} else {
usb_debug!("xhci_controller: halt device");
- self.halt()?;
+ self.halt();
self.regs.crcr.clear_bits(CRCR_COMMAND_RING_RUNNING);
}
@@ -252,9 +254,9 @@ impl Xhci {
}
// Callback for crcr register write.
- fn crcr_callback(&self, value: u64) -> Result<u64> {
+ fn crcr_callback(&self, value: u64) -> u64 {
usb_debug!("xhci_controller: write to crcr {:x}", value);
- let value = if (self.regs.crcr.get_value() & CRCR_COMMAND_RING_RUNNING) == 0 {
+ if (self.regs.crcr.get_value() & CRCR_COMMAND_RING_RUNNING) == 0 {
self.command_ring_controller
.set_dequeue_pointer(GuestAddress(value & CRCR_COMMAND_RING_POINTER));
self.command_ring_controller
@@ -263,8 +265,7 @@ impl Xhci {
} else {
error!("Write to crcr while command ring is running");
self.regs.crcr.get_value()
- };
- Ok(value)
+ }
}
// Callback for portsc register write.
@@ -378,22 +379,20 @@ impl Xhci {
.map_err(Error::SetEventHandlerBusy)
}
- fn reset(&self) -> Result<()> {
+ fn reset(&self) {
self.regs.usbsts.set_bits(USB_STS_CONTROLLER_NOT_READY);
let usbsts = self.regs.usbsts.clone();
self.device_slots.stop_all_and_reset(move || {
usbsts.clear_bits(USB_STS_CONTROLLER_NOT_READY);
});
- Ok(())
}
- fn halt(&self) -> Result<()> {
+ fn halt(&self) {
let usbsts = self.regs.usbsts.clone();
self.device_slots
.stop_all(RingBufferStopCallback::new(move || {
usbsts.set_bits(USB_STS_HALTED);
}));
- Ok(())
}
}
diff --git a/devices/src/usb/xhci/xhci_backend_device.rs b/devices/src/usb/xhci/xhci_backend_device.rs
index a3d9e66c5..2b0dedb0f 100644
--- a/devices/src/usb/xhci/xhci_backend_device.rs
+++ b/devices/src/usb/xhci/xhci_backend_device.rs
@@ -3,6 +3,7 @@
// found in the LICENSE file.
use super::xhci_transfer::XhciTransfer;
+use crate::usb::host_backend::error::Result;
/// Address of this usb device, as in Set Address standard usb device request.
pub type UsbDeviceAddress = u32;
@@ -23,9 +24,9 @@ pub trait XhciBackendDevice: Send {
/// Get product id of this device.
fn get_pid(&self) -> u16;
/// Submit a xhci transfer to backend.
- fn submit_transfer(&mut self, transfer: XhciTransfer) -> std::result::Result<(), ()>;
+ fn submit_transfer(&mut self, transfer: XhciTransfer) -> Result<()>;
/// Set address of this backend.
fn set_address(&mut self, address: UsbDeviceAddress);
/// Reset the backend device.
- fn reset(&mut self) -> std::result::Result<(), ()>;
+ fn reset(&mut self) -> Result<()>;
}
diff --git a/devices/src/usb/xhci/xhci_backend_device_provider.rs b/devices/src/usb/xhci/xhci_backend_device_provider.rs
index aefcf5393..ca416cc02 100644
--- a/devices/src/usb/xhci/xhci_backend_device_provider.rs
+++ b/devices/src/usb/xhci/xhci_backend_device_provider.rs
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+use super::super::host_backend::error::Result;
use super::usb_hub::UsbHub;
use crate::utils::{EventLoop, FailHandle};
use base::RawDescriptor;
@@ -15,7 +16,7 @@ pub trait XhciBackendDeviceProvider: Send {
fail_handle: Arc<dyn FailHandle>,
event_loop: Arc<EventLoop>,
hub: Arc<UsbHub>,
- ) -> std::result::Result<(), ()>;
+ ) -> Result<()>;
/// Keep raw descriptors that should be kept open.
fn keep_rds(&self) -> Vec<RawDescriptor>;
diff --git a/devices/src/usb/xhci/xhci_controller.rs b/devices/src/usb/xhci/xhci_controller.rs
index 85c9d3240..0889b356a 100644
--- a/devices/src/usb/xhci/xhci_controller.rs
+++ b/devices/src/usb/xhci/xhci_controller.rs
@@ -111,6 +111,7 @@ impl XhciController {
PciHeaderType::Device,
0,
0,
+ 0,
);
XhciController {
config_regs,
diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs
index fcc306532..9fda0d65d 100644
--- a/devices/src/vfio.rs
+++ b/devices/src/vfio.rs
@@ -174,7 +174,7 @@ impl VfioContainer {
// Add all guest memory regions into vfio container's iommu table,
// then vfio kernel driver could access guest memory from gfn
- guest_mem.with_regions(|_index, guest_addr, size, host_addr, _fd_offset| {
+ guest_mem.with_regions(|_index, guest_addr, size, host_addr, _mmap, _fd_offset| {
// Safe because the guest regions are guaranteed not to overlap
unsafe { self.vfio_dma_map(guest_addr.0, size as u64, host_addr as u64) }
})?;
@@ -391,21 +391,13 @@ impl VfioDevice {
/// Enable vfio device's irq and associate Irqfd Event with device.
/// When MSIx is enabled, multi vectors will be supported, so descriptors is vector and the vector
/// length is the num of MSIx vectors
- pub fn irq_enable(
- &self,
- descriptors: Vec<&Event>,
- irq_type: VfioIrqType,
- ) -> Result<(), VfioError> {
+ pub fn irq_enable(&self, descriptors: Vec<&Event>, index: u32) -> Result<(), VfioError> {
let count = descriptors.len();
let u32_size = mem::size_of::<u32>();
let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(count);
irq_set[0].argsz = (mem::size_of::<vfio_irq_set>() + count * u32_size) as u32;
irq_set[0].flags = VFIO_IRQ_SET_DATA_EVENTFD | VFIO_IRQ_SET_ACTION_TRIGGER;
- match irq_type {
- VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX,
- VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX,
- VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX,
- }
+ irq_set[0].index = index;
irq_set[0].start = 0;
irq_set[0].count = count as u32;
@@ -421,7 +413,7 @@ impl VfioDevice {
}
// Safe as we are the owner of self and irq_set which are valid value
- let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
+ let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
if ret < 0 {
Err(VfioError::VfioIrqEnable(get_error()))
} else {
@@ -438,17 +430,17 @@ impl VfioDevice {
/// This function enable resample irqfd and let vfio kernel could get EOI notification.
///
/// descriptor: should be resample IrqFd.
- pub fn resample_virq_enable(&self, descriptor: &Event) -> Result<(), VfioError> {
+ pub fn resample_virq_enable(&self, descriptor: &Event, index: u32) -> Result<(), VfioError> {
let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(1);
irq_set[0].argsz = (mem::size_of::<vfio_irq_set>() + mem::size_of::<u32>()) as u32;
irq_set[0].flags = VFIO_IRQ_SET_DATA_EVENTFD | VFIO_IRQ_SET_ACTION_UNMASK;
- irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX;
+ irq_set[0].index = index;
irq_set[0].start = 0;
irq_set[0].count = 1;
{
- // irq_set.data could be none, bool or descriptor according to flags, so irq_set.data
- // is u8 default, here irq_set.data is descriptor as u32, so 4 default u8 are combined
+ // irq_set.data could be none, bool or descriptor according to flags, so irq_set.data is
+ // u8 default, here irq_set.data is descriptor as u32, so 4 default u8 are combined
// together as u32. It is safe as enough space is reserved through
// vec_with_array_field(u32)<1>.
let descriptors = unsafe { irq_set[0].data.as_mut_slice(4) };
@@ -456,7 +448,7 @@ impl VfioDevice {
}
// Safe as we are the owner of self and irq_set which are valid value
- let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
+ let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
if ret < 0 {
Err(VfioError::VfioIrqEnable(get_error()))
} else {
@@ -465,20 +457,16 @@ impl VfioDevice {
}
/// disable vfio device's irq and disconnect Irqfd Event with device
- pub fn irq_disable(&self, irq_type: VfioIrqType) -> Result<(), VfioError> {
+ pub fn irq_disable(&self, index: u32) -> Result<(), VfioError> {
let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0);
irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32;
irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_TRIGGER;
- match irq_type {
- VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX,
- VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX,
- VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX,
- }
+ irq_set[0].index = index;
irq_set[0].start = 0;
irq_set[0].count = 0;
// Safe as we are the owner of self and irq_set which are valid value
- let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
+ let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
if ret < 0 {
Err(VfioError::VfioIrqDisable(get_error()))
} else {
@@ -487,20 +475,16 @@ impl VfioDevice {
}
/// Unmask vfio device irq
- pub fn irq_unmask(&self, irq_type: VfioIrqType) -> Result<(), VfioError> {
+ pub fn irq_unmask(&self, index: u32) -> Result<(), VfioError> {
let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0);
irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32;
irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_UNMASK;
- match irq_type {
- VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX,
- VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX,
- VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX,
- }
+ irq_set[0].index = index;
irq_set[0].start = 0;
irq_set[0].count = 1;
// Safe as we are the owner of self and irq_set which are valid value
- let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
+ let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
if ret < 0 {
Err(VfioError::VfioIrqUnmask(get_error()))
} else {
@@ -509,20 +493,16 @@ impl VfioDevice {
}
/// Mask vfio device irq
- pub fn irq_mask(&self, irq_type: VfioIrqType) -> Result<(), VfioError> {
+ pub fn irq_mask(&self, index: u32) -> Result<(), VfioError> {
let mut irq_set = vec_with_array_field::<vfio_irq_set, u32>(0);
irq_set[0].argsz = mem::size_of::<vfio_irq_set>() as u32;
irq_set[0].flags = VFIO_IRQ_SET_DATA_NONE | VFIO_IRQ_SET_ACTION_MASK;
- match irq_type {
- VfioIrqType::Intx => irq_set[0].index = VFIO_PCI_INTX_IRQ_INDEX,
- VfioIrqType::Msi => irq_set[0].index = VFIO_PCI_MSI_IRQ_INDEX,
- VfioIrqType::Msix => irq_set[0].index = VFIO_PCI_MSIX_IRQ_INDEX,
- }
+ irq_set[0].index = index;
irq_set[0].start = 0;
irq_set[0].count = 1;
// Safe as we are the owner of self and irq_set which are valid value
- let ret = unsafe { ioctl_with_ref(self, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
+ let ret = unsafe { ioctl_with_ref(&self.dev, VFIO_DEVICE_SET_IRQS(), &irq_set[0]) };
if ret < 0 {
Err(VfioError::VfioIrqMask(get_error()))
} else {
@@ -802,7 +782,7 @@ impl VfioDevice {
/// get vfio device's descriptors which are passed into minijail process
pub fn keep_rds(&self) -> Vec<RawDescriptor> {
let mut rds = Vec::new();
- rds.push(self.as_raw_descriptor());
+ rds.push(self.dev.as_raw_descriptor());
rds.push(self.group_descriptor);
rds.push(self.container.lock().as_raw_descriptor());
rds
@@ -822,10 +802,9 @@ impl VfioDevice {
pub fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<(), VfioError> {
self.container.lock().vfio_dma_unmap(iova, size)
}
-}
-impl AsRawDescriptor for VfioDevice {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.dev.as_raw_descriptor()
+ /// Gets the vfio device backing `File`.
+ pub fn device_file(&self) -> &File {
+ &self.dev
}
}
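With the `VfioIrqType` match removed, `irq_enable`, `irq_disable`, `irq_mask` and `irq_unmask` now take the raw VFIO irq index, so callers pass the kernel-defined constants directly. A sketch of a call site under that assumption (the imports and the surrounding device setup are placeholders, not code from this patch):

    use base::Event;
    use crate::vfio::{VfioDevice, VfioError};
    use vfio_sys::VFIO_PCI_MSIX_IRQ_INDEX;

    // Enable a set of MSI-X vectors by passing the VFIO index constant directly,
    // where callers previously passed VfioIrqType::Msix.
    fn enable_msix(device: &VfioDevice, vectors: Vec<&Event>) -> Result<(), VfioError> {
        device.irq_enable(vectors, VFIO_PCI_MSIX_IRQ_INDEX)
    }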
diff --git a/devices/src/virtio/balloon.rs b/devices/src/virtio/balloon.rs
index 05962e92f..a76765d47 100644
--- a/devices/src/virtio/balloon.rs
+++ b/devices/src/virtio/balloon.rs
@@ -12,18 +12,15 @@ use futures::{channel::mpsc, pin_mut, StreamExt};
use remain::sorted;
use thiserror::Error as ThisError;
-use base::{self, error, info, warn, AsRawDescriptor, Event, RawDescriptor};
+use base::{self, error, info, warn, AsRawDescriptor, AsyncTube, Event, RawDescriptor, Tube};
use cros_async::{select6, EventAsync, Executor};
use data_model::{DataInit, Le16, Le32, Le64};
-use msg_socket::MsgSender;
-use vm_control::{
- BalloonControlCommand, BalloonControlResponseSocket, BalloonControlResult, BalloonStats,
-};
+use vm_control::{BalloonControlCommand, BalloonControlResult, BalloonStats};
use vm_memory::{GuestAddress, GuestMemory};
use super::{
- copy_config, descriptor_utils, DescriptorChain, Interrupt, Queue, Reader, VirtioDevice,
- TYPE_BALLOON,
+ copy_config, descriptor_utils, DescriptorChain, Interrupt, Queue, Reader, SignalableInterrupt,
+ VirtioDevice, TYPE_BALLOON,
};
#[sorted]
@@ -31,10 +28,10 @@ use super::{
pub enum BalloonError {
/// Failed to create async message receiver.
#[error("failed to create async message receiver: {0}")]
- CreatingMessageReceiver(msg_socket::MsgError),
+ CreatingMessageReceiver(base::TubeError),
/// Failed to receive command message.
#[error("failed to receive command message: {0}")]
- ReceivingCommand(msg_socket::MsgError),
+ ReceivingCommand(base::TubeError),
/// Failed to write config event.
#[error("failed to write config event: {0}")]
WritingConfigEvent(base::Error),
@@ -194,7 +191,7 @@ async fn handle_stats_queue(
mut queue: Queue,
mut queue_event: EventAsync,
mut stats_rx: mpsc::Receiver<()>,
- command_socket: &BalloonControlResponseSocket,
+ command_tube: &Tube,
config: Arc<BalloonConfig>,
interrupt: Rc<RefCell<Interrupt>>,
) {
@@ -229,13 +226,13 @@ async fn handle_stats_queue(
balloon_actual: actual_pages << VIRTIO_BALLOON_PFN_SHIFT,
stats,
};
- if let Err(e) = command_socket.send(&result) {
+ if let Err(e) = command_tube.send(&result) {
error!("failed to send stats result: {}", e);
}
// Wait for a request to read the stats again.
if stats_rx.next().await.is_none() {
- error!("stats signal channel was closed");
+ error!("stats signal tube was closed");
break;
}
@@ -247,18 +244,14 @@ async fn handle_stats_queue(
// Async task that handles the command socket. The command socket handles messages from the host
// requesting that the guest balloon be adjusted or to report guest memory statistics.
-async fn handle_command_socket(
- ex: &Executor,
- command_socket: &BalloonControlResponseSocket,
+async fn handle_command_tube(
+ command_tube: &AsyncTube,
interrupt: Rc<RefCell<Interrupt>>,
config: Arc<BalloonConfig>,
mut stats_tx: mpsc::Sender<()>,
) -> Result<()> {
- let mut async_messages = command_socket
- .async_receiver(ex)
- .map_err(BalloonError::CreatingMessageReceiver)?;
loop {
- match async_messages.next().await {
+ match command_tube.next().await {
Ok(command) => match command {
BalloonControlCommand::Adjust { num_bytes } => {
let num_pages = (num_bytes >> VIRTIO_BALLOON_PFN_SHIFT) as usize;
@@ -283,14 +276,20 @@ async fn handle_command_socket(
// Async task that resamples the status of the interrupt when the guest sends a request by
// signalling the resample event associated with the interrupt.
async fn handle_irq_resample(ex: &Executor, interrupt: Rc<RefCell<Interrupt>>) {
- let resample_evt = interrupt
- .borrow_mut()
- .get_resample_evt()
- .try_clone()
- .unwrap();
- let resample_evt = EventAsync::new(resample_evt.0, ex).unwrap();
- while resample_evt.next_val().await.is_ok() {
- interrupt.borrow_mut().do_interrupt_resample();
+ let resample_evt = if let Some(resample_evt) = interrupt.borrow_mut().get_resample_evt() {
+ let resample_evt = resample_evt.try_clone().unwrap();
+ let resample_evt = EventAsync::new(resample_evt.0, ex).unwrap();
+ Some(resample_evt)
+ } else {
+ None
+ };
+ if let Some(resample_evt) = resample_evt {
+ while resample_evt.next_val().await.is_ok() {
+ interrupt.borrow_mut().do_interrupt_resample();
+ }
+ } else {
+ // no resample event, park the future.
+ let () = futures::future::pending().await;
}
}
@@ -306,95 +305,98 @@ async fn wait_kill(kill_evt: EventAsync) {
fn run_worker(
mut queue_evts: Vec<Event>,
mut queues: Vec<Queue>,
- command_socket: &BalloonControlResponseSocket,
+ command_tube: Tube,
interrupt: Interrupt,
kill_evt: Event,
mem: GuestMemory,
config: Arc<BalloonConfig>,
-) {
+) -> Tube {
// Wrap the interrupt in a `RefCell` so it can be shared between async functions.
let interrupt = Rc::new(RefCell::new(interrupt));
let ex = Executor::new().unwrap();
+ let command_tube = command_tube.into_async_tube(&ex).unwrap();
+
+ // We need a block to release all references to command_tube at the end before returning it.
+ {
+ // The first queue is used for inflate messages
+ let inflate_event = EventAsync::new(queue_evts.remove(0).0, &ex)
+ .expect("failed to set up the inflate event");
+ let inflate = handle_queue(
+ &mem,
+ queues.remove(0),
+ inflate_event,
+ interrupt.clone(),
+ |guest_address, len| {
+ if let Err(e) = mem.remove_range(guest_address, len) {
+ warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
+ }
+ },
+ );
+ pin_mut!(inflate);
+
+ // The second queue is used for deflate messages
+ let deflate_event = EventAsync::new(queue_evts.remove(0).0, &ex)
+ .expect("failed to set up the deflate event");
+ let deflate = handle_queue(
+ &mem,
+ queues.remove(0),
+ deflate_event,
+ interrupt.clone(),
+ |_, _| {}, // Ignore these.
+ );
+ pin_mut!(deflate);
+
+ // The third queue is used for stats messages
+ let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
+ let stats_event =
+ EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the stats event");
+ let stats = handle_stats_queue(
+ &mem,
+ queues.remove(0),
+ stats_event,
+ stats_rx,
+ &command_tube,
+ config.clone(),
+ interrupt.clone(),
+ );
+ pin_mut!(stats);
- // The first queue is used for inflate messages
- let inflate_event =
- EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the inflate event");
- let inflate = handle_queue(
- &mem,
- queues.remove(0),
- inflate_event,
- interrupt.clone(),
- |guest_address, len| {
- if let Err(e) = mem.remove_range(guest_address, len) {
- warn!("Marking pages unused failed: {}, addr={}", e, guest_address);
- }
- },
- );
- pin_mut!(inflate);
-
- // The second queue is used for deflate messages
- let deflate_event =
- EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the deflate event");
- let deflate = handle_queue(
- &mem,
- queues.remove(0),
- deflate_event,
- interrupt.clone(),
- |_, _| {}, // Ignore these.
- );
- pin_mut!(deflate);
-
- // The third queue is used for stats messages
- let (stats_tx, stats_rx) = mpsc::channel::<()>(1);
- let stats_event =
- EventAsync::new(queue_evts.remove(0).0, &ex).expect("failed to set up the stats event");
- let stats = handle_stats_queue(
- &mem,
- queues.remove(0),
- stats_event,
- stats_rx,
- command_socket,
- config.clone(),
- interrupt.clone(),
- );
- pin_mut!(stats);
-
- // Future to handle command messages that resize the balloon.
- let command = handle_command_socket(&ex, command_socket, interrupt.clone(), config, stats_tx);
- pin_mut!(command);
-
- // Process any requests to resample the irq value.
- let resample = handle_irq_resample(&ex, interrupt.clone());
- pin_mut!(resample);
-
- // Exit if the kill event is triggered.
- let kill_evt = EventAsync::new(kill_evt.0, &ex).expect("failed to set up the kill event");
- let kill = wait_kill(kill_evt);
- pin_mut!(kill);
-
- if let Err(e) = ex.run_until(select6(inflate, deflate, stats, command, resample, kill)) {
- error!("error happened in executor: {}", e);
+ // Future to handle command messages that resize the balloon.
+ let command = handle_command_tube(&command_tube, interrupt.clone(), config, stats_tx);
+ pin_mut!(command);
+
+ // Process any requests to resample the irq value.
+ let resample = handle_irq_resample(&ex, interrupt.clone());
+ pin_mut!(resample);
+
+ // Exit if the kill event is triggered.
+ let kill_evt = EventAsync::new(kill_evt.0, &ex).expect("failed to set up the kill event");
+ let kill = wait_kill(kill_evt);
+ pin_mut!(kill);
+
+ if let Err(e) = ex.run_until(select6(inflate, deflate, stats, command, resample, kill)) {
+ error!("error happened in executor: {}", e);
+ }
}
+
+ command_tube.into()
}
/// Virtio device for memory balloon inflation/deflation.
pub struct Balloon {
- command_socket: Option<BalloonControlResponseSocket>,
+ command_tube: Option<Tube>,
config: Arc<BalloonConfig>,
features: u64,
kill_evt: Option<Event>,
- worker_thread: Option<thread::JoinHandle<BalloonControlResponseSocket>>,
+ worker_thread: Option<thread::JoinHandle<Tube>>,
}
impl Balloon {
/// Creates a new virtio balloon device.
- pub fn new(
- base_features: u64,
- command_socket: BalloonControlResponseSocket,
- ) -> Result<Balloon> {
+ pub fn new(base_features: u64, command_tube: Tube) -> Result<Balloon> {
Ok(Balloon {
- command_socket: Some(command_socket),
+ command_tube: Some(command_tube),
config: Arc::new(BalloonConfig {
num_pages: AtomicUsize::new(0),
actual_pages: AtomicUsize::new(0),
@@ -433,7 +435,7 @@ impl Drop for Balloon {
impl VirtioDevice for Balloon {
fn keep_rds(&self) -> Vec<RawDescriptor> {
- vec![self.command_socket.as_ref().unwrap().as_raw_descriptor()]
+ vec![self.command_tube.as_ref().unwrap().as_raw_descriptor()]
}
fn device_type(&self) -> u32 {
@@ -485,20 +487,19 @@ impl VirtioDevice for Balloon {
self.kill_evt = Some(self_kill_evt);
let config = self.config.clone();
- let command_socket = self.command_socket.take().unwrap();
+ let command_tube = self.command_tube.take().unwrap();
let worker_result = thread::Builder::new()
.name("virtio_balloon".to_string())
.spawn(move || {
run_worker(
queue_evts,
queues,
- &command_socket,
+ command_tube,
interrupt,
kill_evt,
mem,
config,
- );
- command_socket // Return the command socket so it can be re-used.
+ )
});
match worker_result {
@@ -525,8 +526,8 @@ impl VirtioDevice for Balloon {
error!("{}: failed to get back resources", self.debug_label());
return false;
}
- Ok(command_socket) => {
- self.command_socket = Some(command_socket);
+ Ok(command_tube) => {
+ self.command_tube = Some(command_tube);
return true;
}
}
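On the host side the same Tube drives the balloon: the VM process sends a `BalloonControlCommand` and the worker above picks it up in `handle_command_tube`. A sketch of the sending side, assuming `host_tube` is the peer end of the Tube passed to `Balloon::new` (names here are illustrative, not from the patch):

    use base::{error, Tube};
    use vm_control::BalloonControlCommand;

    // Ask the guest balloon to target `num_bytes`; the worker translates this
    // into a config update and signals the guest.
    fn adjust_balloon(host_tube: &Tube, num_bytes: u64) {
        if let Err(e) = host_tube.send(&BalloonControlCommand::Adjust { num_bytes }) {
            error!("failed to send balloon adjust command: {}", e);
        }
    }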
diff --git a/devices/src/virtio/block.rs b/devices/src/virtio/block.rs
index eba3b0126..e9a202427 100644
--- a/devices/src/virtio/block.rs
+++ b/devices/src/virtio/block.rs
@@ -15,19 +15,19 @@ use std::u32;
use base::Error as SysError;
use base::Result as SysResult;
use base::{
- error, info, iov_max, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, Timer,
+ error, info, iov_max, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, Timer, Tube,
WaitContext,
};
use data_model::{DataInit, Le16, Le32, Le64};
use disk::DiskFile;
-use msg_socket::{MsgReceiver, MsgSender};
+
use sync::Mutex;
-use vm_control::{DiskControlCommand, DiskControlResponseSocket, DiskControlResult};
+use vm_control::{DiskControlCommand, DiskControlResult};
use vm_memory::GuestMemory;
use super::{
- copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer,
- TYPE_BLOCK,
+ copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt,
+ VirtioDevice, Writer, TYPE_BLOCK,
};
const QUEUE_SIZE: u16 = 256;
@@ -257,7 +257,7 @@ struct Worker {
read_only: bool,
sparse: bool,
id: Option<BlockId>,
- control_socket: Option<DiskControlResponseSocket>,
+ control_tube: Option<Tube>,
}
impl Worker {
@@ -397,12 +397,17 @@ impl Worker {
let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[
(&flush_timer, Token::FlushTimer),
(&queue_evt, Token::QueueAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&kill_evt, Token::Kill),
])
+ .and_then(|wc| {
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wc.add(resample_evt, Token::InterruptResample)?;
+ }
+ Ok(wc)
+ })
.and_then(|pc| {
- if let Some(control_socket) = self.control_socket.as_ref() {
- pc.add(control_socket, Token::ControlRequest)?
+ if let Some(control_tube) = self.control_tube.as_ref() {
+ pc.add(control_tube, Token::ControlRequest)?
}
Ok(pc)
}) {
@@ -443,14 +448,14 @@ impl Worker {
self.process_queue(0, &mut flush_timer, &mut flush_timer_armed);
}
Token::ControlRequest => {
- let control_socket = match self.control_socket.as_ref() {
+ let control_tube = match self.control_tube.as_ref() {
Some(cs) => cs,
None => {
error!("received control socket request with no control socket");
break 'wait;
}
};
- let req = match control_socket.recv() {
+ let req = match control_tube.recv() {
Ok(req) => req,
Err(e) => {
error!("control socket failed recv: {}", e);
@@ -468,8 +473,8 @@ impl Worker {
}
};
- // We already know there is Some control_socket used to recv a request.
- if let Err(e) = self.control_socket.as_ref().unwrap().send(&resp) {
+ // We already know there is Some control_tube used to recv a request.
+ if let Err(e) = self.control_tube.as_ref().unwrap().send(&resp) {
error!("control socket failed send: {}", e);
break 'wait;
}
@@ -499,7 +504,7 @@ pub struct Block {
seg_max: u32,
block_size: u32,
id: Option<BlockId>,
- control_socket: Option<DiskControlResponseSocket>,
+ control_tube: Option<Tube>,
}
fn build_config_space(disk_size: u64, seg_max: u32, block_size: u32) -> virtio_blk_config {
@@ -527,7 +532,7 @@ impl Block {
sparse: bool,
block_size: u32,
id: Option<BlockId>,
- control_socket: Option<DiskControlResponseSocket>,
+ control_tube: Option<Tube>,
) -> SysResult<Block> {
if block_size % SECTOR_SIZE as u32 != 0 {
error!(
@@ -576,7 +581,7 @@ impl Block {
seg_max,
block_size,
id,
- control_socket,
+ control_tube,
})
}
@@ -751,8 +756,8 @@ impl VirtioDevice for Block {
keep_rds.extend(disk_image.as_raw_descriptors());
}
- if let Some(control_socket) = &self.control_socket {
- keep_rds.push(control_socket.as_raw_descriptor());
+ if let Some(control_tube) = &self.control_tube {
+ keep_rds.push(control_tube.as_raw_descriptor());
}
keep_rds
@@ -803,7 +808,7 @@ impl VirtioDevice for Block {
let disk_size = self.disk_size.clone();
let id = self.id.take();
if let Some(disk_image) = self.disk_image.take() {
- let control_socket = self.control_socket.take();
+ let control_tube = self.control_tube.take();
let worker_result =
thread::Builder::new()
.name("virtio_blk".to_string())
@@ -817,7 +822,7 @@ impl VirtioDevice for Block {
read_only,
sparse,
id,
- control_socket,
+ control_tube,
};
worker.run(queue_evts.remove(0), kill_evt);
worker
@@ -851,7 +856,7 @@ impl VirtioDevice for Block {
}
Ok(worker) => {
self.disk_image = Some(worker.disk_image);
- self.control_socket = worker.control_socket;
+ self.control_tube = worker.control_tube;
return true;
}
}
diff --git a/devices/src/virtio/block_async.rs b/devices/src/virtio/block_async.rs
index 0ea2e0265..dc04227c7 100644
--- a/devices/src/virtio/block_async.rs
+++ b/devices/src/virtio/block_async.rs
@@ -15,23 +15,26 @@ use std::u32;
use futures::pin_mut;
use futures::stream::{FuturesUnordered, StreamExt};
-use libchromeos::sync::Mutex as AsyncMutex;
use remain::sorted;
use thiserror::Error as ThisError;
use base::Error as SysError;
use base::Result as SysResult;
-use base::{error, info, iov_max, warn, AsRawDescriptor, Event, RawDescriptor, Timer};
-use cros_async::{select5, AsyncError, EventAsync, Executor, SelectResult, TimerAsync};
+use base::{
+ error, info, iov_max, warn, AsRawDescriptor, AsyncTube, Event, RawDescriptor, Timer, Tube,
+ TubeError,
+};
+use cros_async::{
+ select5, sync::Mutex as AsyncMutex, AsyncError, EventAsync, Executor, SelectResult, TimerAsync,
+};
use data_model::{DataInit, Le16, Le32, Le64};
use disk::{AsyncDisk, ToAsyncDisk};
-use msg_socket::{MsgError, MsgSender};
-use vm_control::{DiskControlCommand, DiskControlResponseSocket, DiskControlResult};
+use vm_control::{DiskControlCommand, DiskControlResult};
use vm_memory::GuestMemory;
use super::{
- copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer,
- TYPE_BLOCK,
+ copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt,
+ VirtioDevice, Writer, TYPE_BLOCK,
};
const QUEUE_SIZE: u16 = 256;
@@ -48,9 +51,17 @@ const MAX_WRITE_ZEROES_SEG: u32 = 32;
// but this should probably be based on cluster size for qcow.
const DISCARD_SECTOR_ALIGNMENT: u32 = 128;
+const ID_LEN: usize = 20;
+
+/// Virtio block device identifier.
+/// This is an ASCII string terminated by a \0, unless all 20 bytes are used,
+/// in which case the \0 terminator is omitted.
+pub type BlockId = [u8; ID_LEN];
+
const VIRTIO_BLK_T_IN: u32 = 0;
const VIRTIO_BLK_T_OUT: u32 = 1;
const VIRTIO_BLK_T_FLUSH: u32 = 4;
+const VIRTIO_BLK_T_GET_ID: u32 = 8;
const VIRTIO_BLK_T_DISCARD: u32 = 11;
const VIRTIO_BLK_T_WRITE_ZEROES: u32 = 13;
@@ -91,7 +102,7 @@ unsafe impl DataInit for virtio_blk_topology {}
#[derive(Copy, Clone, Debug, Default)]
#[repr(C, packed)]
-struct virtio_blk_config {
+pub(crate) struct virtio_blk_config {
capacity: Le64,
size_max: Le32,
seg_max: Le32,
@@ -100,7 +111,7 @@ struct virtio_blk_config {
topology: virtio_blk_topology,
writeback: u8,
unused0: u8,
- num_queues: Le16,
+ pub num_queues: Le16,
max_discard_sectors: Le32,
max_discard_seg: Le32,
discard_sector_alignment: Le32,
@@ -140,8 +151,8 @@ unsafe impl DataInit for virtio_blk_discard_write_zeroes {}
#[sorted]
#[derive(ThisError, Debug)]
enum ExecuteError {
- #[error("couldn't create a message receiver: {0}")]
- CreatingMessageReceiver(MsgError),
+ #[error("failed to copy ID string: {0}")]
+ CopyId(io::Error),
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
#[error("failed to perform discard or write zeroes; sector={sector} num_sectors={num_sectors} flags={flags}; {ioerr:?}")]
@@ -167,10 +178,10 @@ enum ExecuteError {
},
#[error("read only; request_type={request_type}")]
ReadOnly { request_type: u32 },
- #[error("failed to read command message: {0}")]
- ReceivingCommand(MsgError),
+ #[error("failed to recieve command message: {0}")]
+ ReceivingCommand(TubeError),
#[error("failed to send command response: {0}")]
- SendingResponse(MsgError),
+ SendingResponse(TubeError),
#[error("couldn't reset the timer: {0}")]
TimerReset(base::Error),
#[error("unsupported ({0})")]
@@ -188,7 +199,7 @@ enum ExecuteError {
impl ExecuteError {
fn status(&self) -> u8 {
match self {
- ExecuteError::CreatingMessageReceiver(_) => VIRTIO_BLK_S_IOERR,
+ ExecuteError::CopyId(_) => VIRTIO_BLK_S_IOERR,
ExecuteError::Descriptor(_) => VIRTIO_BLK_S_IOERR,
ExecuteError::DiscardWriteZeroes { .. } => VIRTIO_BLK_S_IOERR,
ExecuteError::Flush(_) => VIRTIO_BLK_S_IOERR,
@@ -227,6 +238,7 @@ struct DiskState {
disk_size: Arc<AtomicU64>,
read_only: bool,
sparse: bool,
+ id: Option<BlockId>,
}
async fn process_one_request(
@@ -294,7 +306,7 @@ async fn process_one_request_task(
let mut queue = queue.borrow_mut();
queue.add_used(&mem, descriptor_index, len as u32);
- queue.trigger_interrupt(&mem, &interrupt.borrow());
+ queue.trigger_interrupt(&mem, &*interrupt.borrow());
queue.update_int_required(&mem);
}
@@ -335,19 +347,28 @@ async fn handle_irq_resample(
ex: &Executor,
interrupt: Rc<RefCell<Interrupt>>,
) -> result::Result<(), OtherError> {
- let resample_evt = interrupt
- .borrow_mut()
- .get_resample_evt()
- .try_clone()
- .map_err(OtherError::CloneResampleEvent)?;
- let resample_evt =
- EventAsync::new(resample_evt.0, ex).map_err(OtherError::AsyncResampleCreate)?;
- loop {
- let _ = resample_evt
- .next_val()
- .await
- .map_err(OtherError::ReadResampleEvent)?;
- interrupt.borrow_mut().do_interrupt_resample();
+ let resample_evt = if let Some(resample_evt) = interrupt.borrow().get_resample_evt() {
+ let resample_evt = resample_evt
+ .try_clone()
+ .map_err(OtherError::CloneResampleEvent)?;
+ let resample_evt =
+ EventAsync::new(resample_evt.0, ex).map_err(OtherError::AsyncResampleCreate)?;
+ Some(resample_evt)
+ } else {
+ None
+ };
+ if let Some(resample_evt) = resample_evt {
+ loop {
+ let _ = resample_evt
+ .next_val()
+ .await
+ .map_err(OtherError::ReadResampleEvent)?;
+ interrupt.borrow().do_interrupt_resample();
+ }
+ } else {
+ // no resample event, park the future.
+ let () = futures::future::pending().await;
+ Ok(())
}
}
@@ -357,24 +378,20 @@ async fn wait_kill(kill_evt: EventAsync) {
let _ = kill_evt.next_val().await;
}
-async fn handle_command_socket(
- ex: &Executor,
- command_socket: &Option<DiskControlResponseSocket>,
+async fn handle_command_tube(
+ command_tube: &Option<AsyncTube>,
interrupt: Rc<RefCell<Interrupt>>,
disk_state: Rc<AsyncMutex<DiskState>>,
) -> Result<(), ExecuteError> {
- let command_socket = match command_socket {
+ let command_tube = match command_tube {
Some(c) => c,
None => {
let () = futures::future::pending().await;
return Ok(());
}
};
- let mut async_messages = command_socket
- .async_receiver(ex)
- .map_err(ExecuteError::CreatingMessageReceiver)?;
loop {
- match async_messages.next().await {
+ match command_tube.next().await {
Ok(command) => {
let resp = match command {
DiskControlCommand::Resize { new_size } => {
@@ -382,7 +399,7 @@ async fn handle_command_socket(
}
};
- command_socket
+ command_tube
.send(&resp)
.map_err(ExecuteError::SendingResponse)?;
if let DiskControlResult::Ok = resp {
@@ -462,7 +479,7 @@ fn run_worker(
queues: Vec<Queue>,
mem: GuestMemory,
disk_state: &Rc<AsyncMutex<DiskState>>,
- control_socket: &Option<DiskControlResponseSocket>,
+ control_tube: &Option<AsyncTube>,
queue_evts: Vec<Event>,
kill_evt: Event,
) -> Result<(), String> {
@@ -514,7 +531,7 @@ fn run_worker(
pin_mut!(disk_flush);
// Handles control requests.
- let control = handle_command_socket(&ex, control_socket, interrupt.clone(), disk_state.clone());
+ let control = handle_command_tube(control_tube, interrupt.clone(), disk_state.clone());
pin_mut!(control);
// Process any requests to resample the irq value.
@@ -546,8 +563,7 @@ fn run_worker(
/// Virtio device for exposing block level read/write operations on a host file.
pub struct BlockAsync {
kill_evt: Option<Event>,
- worker_thread:
- Option<thread::JoinHandle<(Box<dyn ToAsyncDisk>, Option<DiskControlResponseSocket>)>>,
+ worker_thread: Option<thread::JoinHandle<(Box<dyn ToAsyncDisk>, Option<Tube>)>>,
disk_image: Option<Box<dyn ToAsyncDisk>>,
disk_size: Arc<AtomicU64>,
avail_features: u64,
@@ -555,7 +571,8 @@ pub struct BlockAsync {
sparse: bool,
seg_max: u32,
block_size: u32,
- control_socket: Option<DiskControlResponseSocket>,
+ id: Option<BlockId>,
+ control_tube: Option<Tube>,
}
fn build_config_space(disk_size: u64, seg_max: u32, block_size: u32) -> virtio_blk_config {
@@ -583,7 +600,8 @@ impl BlockAsync {
read_only: bool,
sparse: bool,
block_size: u32,
- control_socket: Option<DiskControlResponseSocket>,
+ id: Option<BlockId>,
+ control_tube: Option<Tube>,
) -> SysResult<BlockAsync> {
if block_size % SECTOR_SIZE as u32 != 0 {
error!(
@@ -632,7 +650,8 @@ impl BlockAsync {
sparse,
seg_max,
block_size,
- control_socket,
+ id,
+ control_tube,
})
}
@@ -655,7 +674,7 @@ impl BlockAsync {
let req_type = req_header.req_type.to_native();
let sector = req_header.sector.to_native();
- if disk_state.read_only && req_type != VIRTIO_BLK_T_IN {
+ if disk_state.read_only && req_type != VIRTIO_BLK_T_IN && req_type != VIRTIO_BLK_T_GET_ID {
return Err(ExecuteError::ReadOnly {
request_type: req_type,
});
@@ -790,6 +809,13 @@ impl BlockAsync {
.await
.map_err(ExecuteError::Flush)?;
}
+ VIRTIO_BLK_T_GET_ID => {
+ if let Some(id) = disk_state.id {
+ writer.write_all(&id).map_err(ExecuteError::CopyId)?;
+ } else {
+ return Err(ExecuteError::Unsupported(req_type));
+ }
+ }
t => return Err(ExecuteError::Unsupported(t)),
};
Ok(())
@@ -817,8 +843,8 @@ impl VirtioDevice for BlockAsync {
keep_rds.extend(disk_image.as_raw_descriptors());
}
- if let Some(control_socket) = &self.control_socket {
- keep_rds.push(control_socket.as_raw_descriptor());
+ if let Some(control_tube) = &self.control_tube {
+ keep_rds.push(control_tube.as_raw_descriptor());
}
keep_rds
@@ -863,13 +889,16 @@ impl VirtioDevice for BlockAsync {
let read_only = self.read_only;
let sparse = self.sparse;
let disk_size = self.disk_size.clone();
+ let id = self.id.take();
if let Some(disk_image) = self.disk_image.take() {
- let control_socket = self.control_socket.take();
+ let control_tube = self.control_tube.take();
let worker_result =
thread::Builder::new()
.name("virtio_blk".to_string())
.spawn(move || {
let ex = Executor::new().expect("Failed to create an executor");
+ let async_control = control_tube
+ .map(|c| c.into_async_tube(&ex).expect("failed to create async tube"));
let async_image = match disk_image.to_async_disk(&ex) {
Ok(d) => d,
Err(e) => panic!("Failed to create async disk {}", e),
@@ -879,6 +908,7 @@ impl VirtioDevice for BlockAsync {
disk_size,
read_only,
sparse,
+ id,
}));
if let Err(err_string) = run_worker(
ex,
@@ -886,7 +916,7 @@ impl VirtioDevice for BlockAsync {
queues,
mem,
&disk_state,
- &control_socket,
+ &async_control,
queue_evts,
kill_evt,
) {
@@ -897,7 +927,10 @@ impl VirtioDevice for BlockAsync {
Ok(d) => d.into_inner(),
Err(_) => panic!("too many refs to the disk"),
};
- (disk_state.disk_image.into_inner(), control_socket)
+ (
+ disk_state.disk_image.into_inner(),
+ async_control.map(|c| c.into()),
+ )
});
match worker_result {
@@ -926,9 +959,9 @@ impl VirtioDevice for BlockAsync {
error!("{}: failed to get back resources", self.debug_label());
return false;
}
- Ok((disk_image, control_socket)) => {
+ Ok((disk_image, control_tube)) => {
self.disk_image = Some(disk_image);
- self.control_socket = control_socket;
+ self.control_tube = control_tube;
return true;
}
}
@@ -962,7 +995,7 @@ mod tests {
f.set_len(0x1000).unwrap();
let features = base_features(ProtectionType::Unprotected);
- let b = BlockAsync::new(features, Box::new(f), true, false, 512, None).unwrap();
+ let b = BlockAsync::new(features, Box::new(f), true, false, 512, None, None).unwrap();
let mut num_sectors = [0u8; 4];
b.read_config(0, &mut num_sectors);
// size is 0x1000, so num_sectors is 8 (4096/512).
@@ -982,7 +1015,7 @@ mod tests {
f.set_len(0x1000).unwrap();
let features = base_features(ProtectionType::Unprotected);
- let b = BlockAsync::new(features, Box::new(f), true, false, 4096, None).unwrap();
+ let b = BlockAsync::new(features, Box::new(f), true, false, 4096, None, None).unwrap();
let mut blk_size = [0u8; 4];
b.read_config(20, &mut blk_size);
// blk_size should be 4096 (0x1000).
@@ -999,7 +1032,7 @@ mod tests {
{
let f = File::create(&path).unwrap();
let features = base_features(ProtectionType::Unprotected);
- let b = BlockAsync::new(features, Box::new(f), false, true, 512, None).unwrap();
+ let b = BlockAsync::new(features, Box::new(f), false, true, 512, None, None).unwrap();
// writable device should set VIRTIO_BLK_F_FLUSH + VIRTIO_BLK_F_DISCARD
// + VIRTIO_BLK_F_WRITE_ZEROES + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE
// + VIRTIO_BLK_F_SEG_MAX + VIRTIO_BLK_F_MQ
@@ -1010,7 +1043,7 @@ mod tests {
{
let f = File::create(&path).unwrap();
let features = base_features(ProtectionType::Unprotected);
- let b = BlockAsync::new(features, Box::new(f), false, false, 512, None).unwrap();
+ let b = BlockAsync::new(features, Box::new(f), false, false, 512, None, None).unwrap();
// read-only device should set VIRTIO_BLK_F_FLUSH and VIRTIO_BLK_F_RO
// + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE + VIRTIO_BLK_F_SEG_MAX
// + VIRTIO_BLK_F_MQ
@@ -1021,7 +1054,7 @@ mod tests {
{
let f = File::create(&path).unwrap();
let features = base_features(ProtectionType::Unprotected);
- let b = BlockAsync::new(features, Box::new(f), true, true, 512, None).unwrap();
+ let b = BlockAsync::new(features, Box::new(f), true, true, 512, None, None).unwrap();
// read-only device should set VIRTIO_BLK_F_FLUSH and VIRTIO_BLK_F_RO
// + VIRTIO_F_VERSION_1 + VIRTIO_BLK_F_BLK_SIZE + VIRTIO_BLK_F_SEG_MAX
// + VIRTIO_BLK_F_MQ
@@ -1086,6 +1119,7 @@ mod tests {
disk_size: Arc::new(AtomicU64::new(disk_size)),
read_only: false,
sparse: true,
+ id: None,
}));
let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
@@ -1154,6 +1188,7 @@ mod tests {
disk_size: Arc::new(AtomicU64::new(disk_size)),
read_only: false,
sparse: true,
+ id: None,
}));
let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
@@ -1166,4 +1201,79 @@ mod tests {
let status = mem.read_obj_from_addr::<u8>(status_offset).unwrap();
assert_eq!(status, VIRTIO_BLK_S_IOERR);
}
+
+ #[test]
+ fn get_id() {
+ let ex = Executor::new().expect("creating an executor failed");
+
+ let tempdir = TempDir::new().unwrap();
+ let mut path = tempdir.path().to_owned();
+ path.push("disk_image");
+ let f = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .create(true)
+ .open(&path)
+ .unwrap();
+ let disk_size = 0x1000;
+ f.set_len(disk_size).unwrap();
+
+ let mem = GuestMemory::new(&[(GuestAddress(0u64), 4 * 1024 * 1024)])
+ .expect("Creating guest memory failed.");
+
+ let req_hdr = virtio_blk_req_header {
+ req_type: Le32::from(VIRTIO_BLK_T_GET_ID),
+ reserved: Le32::from(0),
+ sector: Le64::from(0),
+ };
+ mem.write_obj_at_addr(req_hdr, GuestAddress(0x1000))
+ .expect("writing req failed");
+
+ let avail_desc = create_descriptor_chain(
+ &mem,
+ GuestAddress(0x100), // Place descriptor chain at 0x100.
+ GuestAddress(0x1000), // Describe buffer at 0x1000.
+ vec![
+ // Request header
+ (DescriptorType::Readable, size_of_val(&req_hdr) as u32),
+ // I/O buffer (20 bytes for serial)
+ (DescriptorType::Writable, 20),
+ // Request status
+ (DescriptorType::Writable, 1),
+ ],
+ 0,
+ )
+ .expect("create_descriptor_chain failed");
+
+ let af = SingleFileDisk::new(f, &ex).expect("Failed to create SFD");
+ let timer = Timer::new().expect("Failed to create a timer");
+ let flush_timer = Rc::new(RefCell::new(
+ TimerAsync::new(timer.0, &ex).expect("Failed to create an async timer"),
+ ));
+ let flush_timer_armed = Rc::new(RefCell::new(false));
+
+ let id = b"a20-byteserialnumber";
+
+ let disk_state = Rc::new(AsyncMutex::new(DiskState {
+ disk_image: Box::new(af),
+ disk_size: Arc::new(AtomicU64::new(disk_size)),
+ read_only: false,
+ sparse: true,
+ id: Some(*id),
+ }));
+
+ let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
+
+ ex.run_until(fut)
+ .expect("running executor failed")
+ .expect("execute failed");
+
+ let status_offset = GuestAddress((0x1000 + size_of_val(&req_hdr) + 512) as u64);
+ let status = mem.read_obj_from_addr::<u8>(status_offset).unwrap();
+ assert_eq!(status, VIRTIO_BLK_S_OK);
+
+ let id_offset = GuestAddress(0x1000 + size_of_val(&req_hdr) as u64);
+ let returned_id = mem.read_obj_from_addr::<[u8; 20]>(id_offset).unwrap();
+ assert_eq!(returned_id, *id);
+ }
}
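The new `VIRTIO_BLK_T_GET_ID` path copies the 20-byte `BlockId` straight into the guest buffer, and the doc comment above defines the format: ASCII, NUL-terminated unless all 20 bytes are used. A sketch of building such an identifier from a shorter serial string under those rules (the helper name is made up for illustration):

    // Matches the ID_LEN alias introduced above.
    type BlockId = [u8; 20];

    // Pad a short ASCII serial out to the fixed 20-byte BlockId with NULs.
    fn block_id_from_str(serial: &str) -> Option<BlockId> {
        let bytes = serial.as_bytes();
        if bytes.len() > 20 || !serial.is_ascii() {
            return None;
        }
        let mut id: BlockId = [0u8; 20];
        id[..bytes.len()].copy_from_slice(bytes);
        Some(id)
    }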
diff --git a/devices/src/virtio/console.rs b/devices/src/virtio/console.rs
index 5eb7cbcf4..9aa789a5d 100644
--- a/devices/src/virtio/console.rs
+++ b/devices/src/virtio/console.rs
@@ -11,7 +11,8 @@ use data_model::{DataInit, Le16, Le32};
use vm_memory::GuestMemory;
use super::{
- base_features, copy_config, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_CONSOLE,
+ base_features, copy_config, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice,
+ Writer, TYPE_CONSOLE,
};
use crate::{ProtectionType, SerialDevice};
@@ -239,7 +240,6 @@ impl Worker {
(&transmit_evt, Token::TransmitQueueAvailable),
(&receive_evt, Token::ReceiveQueueAvailable),
(&in_avail_evt, Token::InputAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&kill_evt, Token::Kill),
]) {
Ok(pc) => pc,
@@ -248,6 +248,15 @@ impl Worker {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+ error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
let mut output: Box<dyn io::Write> = match self.output.take() {
Some(o) => o,
diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs
index a8be0e779..fc839b610 100644
--- a/devices/src/virtio/descriptor_utils.rs
+++ b/devices/src/virtio/descriptor_utils.rs
@@ -18,6 +18,7 @@ use base::{FileReadWriteAtVolatile, FileReadWriteVolatile};
use cros_async::MemRegion;
use data_model::{DataInit, Le16, Le32, Le64, VolatileMemoryError, VolatileSlice};
use disk::AsyncDisk;
+use smallvec::SmallVec;
use vm_memory::{GuestAddress, GuestMemory};
use super::DescriptorChain;
@@ -56,7 +57,7 @@ impl std::error::Error for Error {}
#[derive(Clone)]
struct DescriptorChainRegions {
- regions: Vec<MemRegion>,
+ regions: SmallVec<[MemRegion; 16]>,
current: usize,
bytes_consumed: usize,
}
@@ -87,7 +88,7 @@ impl DescriptorChainRegions {
/// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
/// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
/// calls to `get` with no intervening calls to `consume` will return the same data.
- fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> Vec<VolatileSlice<'mem>> {
+ fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
self.get_remaining_regions()
.iter()
.filter_map(|region| {
@@ -134,7 +135,7 @@ impl DescriptorChainRegions {
&self,
mem: &'mem GuestMemory,
count: usize,
- ) -> Vec<VolatileSlice<'mem>> {
+ ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
self.get_remaining_regions_with_count(count)
.iter()
.filter_map(|region| {
@@ -259,7 +260,7 @@ impl Reader {
len: desc.len.try_into().expect("u32 doesn't fit in usize"),
})
})
- .collect::<Result<Vec<MemRegion>>>()?;
+ .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
Ok(Reader {
mem,
regions: DescriptorChainRegions {
@@ -442,7 +443,7 @@ impl Reader {
/// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
/// Calling this method does not actually consume any data from the `Reader` and callers should
/// call `consume` to advance the `Reader`.
- pub fn get_remaining(&self) -> Vec<VolatileSlice> {
+ pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
self.regions.get_remaining(&self.mem)
}
@@ -527,7 +528,7 @@ impl Writer {
len: desc.len.try_into().expect("u32 doesn't fit in usize"),
})
})
- .collect::<Result<Vec<MemRegion>>>()?;
+ .collect::<Result<SmallVec<[MemRegion; 16]>>>()?;
Ok(Writer {
mem,
regions: DescriptorChainRegions {
diff --git a/devices/src/virtio/fs/caps.rs b/devices/src/virtio/fs/caps.rs
new file mode 100644
index 000000000..f4964c6a6
--- /dev/null
+++ b/devices/src/virtio/fs/caps.rs
@@ -0,0 +1,163 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::ffi::c_void;
+use std::io;
+use std::os::raw::c_int;
+
+#[allow(non_camel_case_types)]
+type cap_t = *mut c_void;
+
+#[allow(non_camel_case_types)]
+pub type cap_value_t = u32;
+
+#[allow(non_camel_case_types)]
+type cap_flag_t = u32;
+
+#[allow(non_camel_case_types)]
+type cap_flag_value_t = i32;
+
+#[link(name = "cap")]
+extern "C" {
+ fn cap_free(ptr: *mut c_void) -> c_int;
+
+ fn cap_set_flag(
+ c: cap_t,
+ f: cap_flag_t,
+ ncap: c_int,
+ caps: *const cap_value_t,
+ val: cap_flag_value_t,
+ ) -> c_int;
+
+ fn cap_get_proc() -> cap_t;
+ fn cap_set_proc(cap: cap_t) -> c_int;
+}
+
+#[repr(u32)]
+pub enum Capability {
+ Chown = 0,
+ DacOverride = 1,
+ DacReadSearch = 2,
+ Fowner = 3,
+ Fsetid = 4,
+ Kill = 5,
+ Setgid = 6,
+ Setuid = 7,
+ Setpcap = 8,
+ LinuxImmutable = 9,
+ NetBindService = 10,
+ NetBroadcast = 11,
+ NetAdmin = 12,
+ NetRaw = 13,
+ IpcLock = 14,
+ IpcOwner = 15,
+ SysModule = 16,
+ SysRawio = 17,
+ SysChroot = 18,
+ SysPtrace = 19,
+ SysPacct = 20,
+ SysAdmin = 21,
+ SysBoot = 22,
+ SysNice = 23,
+ SysResource = 24,
+ SysTime = 25,
+ SysTtyConfig = 26,
+ Mknod = 27,
+ Lease = 28,
+ AuditWrite = 29,
+ AuditControl = 30,
+ Setfcap = 31,
+ MacOverride = 32,
+ MacAdmin = 33,
+ Syslog = 34,
+ WakeAlarm = 35,
+ BlockSuspend = 36,
+ AuditRead = 37,
+ Last,
+}
+
+impl From<Capability> for cap_value_t {
+ fn from(c: Capability) -> cap_value_t {
+ c as cap_value_t
+ }
+}
+
+#[repr(u32)]
+pub enum Set {
+ Effective = 0,
+ Permitted = 1,
+ Inheritable = 2,
+}
+
+impl From<Set> for cap_flag_t {
+ fn from(s: Set) -> cap_flag_t {
+ s as cap_flag_t
+ }
+}
+
+#[repr(i32)]
+pub enum Value {
+ Clear = 0,
+ Set = 1,
+}
+
+impl From<Value> for cap_flag_value_t {
+ fn from(v: Value) -> cap_flag_value_t {
+ v as cap_flag_value_t
+ }
+}
+
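+/// An owned capability state (`cap_t`) for the current thread, freed with `cap_free` on drop.
+///
+/// A minimal usage sketch (the same pattern `drop_cap_fsetid` in passthrough.rs uses to clear
+/// `CAP_FSETID` from the effective set):
+///
+/// ```ignore
+/// let mut caps = Caps::for_current_thread()?;
+/// caps.update(&[Capability::Fsetid], Set::Effective, Value::Clear)?;
+/// caps.apply()?;
+/// ```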
+pub struct Caps(cap_t);
+
+impl Caps {
+ /// Get the capabilities for the current thread.
+ pub fn for_current_thread() -> io::Result<Caps> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let caps = unsafe { cap_get_proc() };
+ if caps.is_null() {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(Caps(caps))
+ }
+ }
+
+ /// Update the capabilities described by `self` by setting or clearing `caps` in `set`.
+ pub fn update(&mut self, caps: &[Capability], set: Set, value: Value) -> io::Result<()> {
+ // Safe because this only modifies the memory pointed to by `self.0` and we check the return
+ // value.
+ let ret = unsafe {
+ cap_set_flag(
+ self.0,
+ set.into(),
+ caps.len() as c_int,
+ // It's safe to cast this pointer because `Capability` is #[repr(u32)]
+ caps.as_ptr() as *const cap_value_t,
+ value.into(),
+ )
+ };
+
+ if ret == 0 {
+ Ok(())
+ } else {
+ Err(io::Error::last_os_error())
+ }
+ }
+
+ /// Apply the capabilities described by `self` to the current thread.
+ pub fn apply(&self) -> io::Result<()> {
+ if unsafe { cap_set_proc(self.0) } == 0 {
+ Ok(())
+ } else {
+ Err(io::Error::last_os_error())
+ }
+ }
+}
+
+impl Drop for Caps {
+ fn drop(&mut self) {
+ unsafe {
+ cap_free(self.0);
+ }
+ }
+}
diff --git a/devices/src/virtio/fs/mod.rs b/devices/src/virtio/fs/mod.rs
index 7ec90cbf4..bc8c28f01 100644
--- a/devices/src/virtio/fs/mod.rs
+++ b/devices/src/virtio/fs/mod.rs
@@ -8,11 +8,10 @@ use std::mem;
use std::sync::{Arc, Mutex};
use std::thread;
-use base::{error, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor};
+use base::{error, warn, AsRawDescriptor, Error as SysError, Event, RawDescriptor, Tube};
use data_model::{DataInit, Le32};
-use msg_socket::{MsgReceiver, MsgSender};
use resources::Alloc;
-use vm_control::{FsMappingRequest, FsMappingRequestSocket, VmResponse};
+use vm_control::{FsMappingRequest, VmResponse};
use vm_memory::GuestMemory;
use crate::pci::{
@@ -23,6 +22,7 @@ use crate::virtio::{
VirtioPciShmCap, TYPE_FS,
};
+mod caps;
mod multikey;
pub mod passthrough;
mod read_dir;
@@ -33,7 +33,7 @@ use passthrough::PassthroughFs;
use worker::Worker;
// The fs device does not have a fixed number of queues.
-const QUEUE_SIZE: u16 = 1024;
+pub const QUEUE_SIZE: u16 = 1024;
const FS_BAR_NUM: u8 = 4;
const FS_BAR_OFFSET: u64 = 0;
@@ -48,15 +48,15 @@ pub const FS_MAX_TAG_LEN: usize = 36;
/// kernel/include/uapi/linux/virtio_fs.h
#[repr(C, packed)]
#[derive(Clone, Copy)]
-struct Config {
+pub(crate) struct virtio_fs_config {
/// Filesystem name (UTF-8, not NUL-terminated, padded with NULs)
- tag: [u8; FS_MAX_TAG_LEN],
+ pub tag: [u8; FS_MAX_TAG_LEN],
/// Number of request queues
- num_queues: Le32,
+ pub num_request_queues: Le32,
}
// Safe because all members are plain old data and any value is valid.
-unsafe impl DataInit for Config {}
+unsafe impl DataInit for virtio_fs_config {}
/// Errors that may occur during the creation or operation of an Fs device.
#[derive(Debug)]
@@ -81,6 +81,10 @@ pub enum Error {
InvalidDescriptorChain(DescriptorError),
/// Error happened in FUSE.
FuseError(fuse::Error),
+ /// Failed to get the securebits for the worker thread.
+ GetSecurebits(io::Error),
+ /// Failed to set the securebits for the worker thread.
+ SetSecurebits(io::Error),
}
impl ::std::error::Error for Error {}
@@ -109,6 +113,12 @@ impl fmt::Display for Error {
SignalUsedQueue(err) => write!(f, "failed to signal used queue: {}", err),
InvalidDescriptorChain(err) => write!(f, "DescriptorChain is invalid: {}", err),
FuseError(err) => write!(f, "fuse error: {}", err),
+ GetSecurebits(err) => {
+ write!(f, "failed to get securebits for the worker thread: {}", err)
+ }
+ SetSecurebits(err) => {
+ write!(f, "failed to set securebits for the worker thread: {}", err)
+ }
}
}
}
@@ -116,13 +126,13 @@ impl fmt::Display for Error {
pub type Result<T> = ::std::result::Result<T, Error>;
pub struct Fs {
- cfg: Config,
+ cfg: virtio_fs_config,
fs: Option<PassthroughFs>,
queue_sizes: Box<[u16]>,
avail_features: u64,
acked_features: u64,
pci_bar: Option<Alloc>,
- socket: Option<FsMappingRequestSocket>,
+ tube: Option<Tube>,
workers: Vec<(Event, thread::JoinHandle<Result<()>>)>,
}
@@ -132,7 +142,7 @@ impl Fs {
tag: &str,
num_workers: usize,
fs_cfg: passthrough::Config,
- socket: FsMappingRequestSocket,
+ tube: Tube,
) -> Result<Fs> {
if tag.len() > FS_MAX_TAG_LEN {
return Err(Error::TagTooLong(tag.len()));
@@ -141,9 +151,9 @@ impl Fs {
let mut cfg_tag = [0u8; FS_MAX_TAG_LEN];
cfg_tag[..tag.len()].copy_from_slice(tag.as_bytes());
- let cfg = Config {
+ let cfg = virtio_fs_config {
tag: cfg_tag,
- num_queues: Le32::from(num_workers as u32),
+ num_request_queues: Le32::from(num_workers as u32),
};
let fs = PassthroughFs::new(fs_cfg).map_err(Error::CreateFs)?;
@@ -158,7 +168,7 @@ impl Fs {
avail_features: base_features,
acked_features: 0,
pci_bar: None,
- socket: Some(socket),
+ tube: Some(tube),
workers: Vec::with_capacity(num_workers + 1),
})
}
@@ -190,7 +200,7 @@ impl VirtioDevice for Fs {
.as_ref()
.map(PassthroughFs::keep_rds)
.unwrap_or_else(Vec::new);
- if let Some(rd) = self.socket.as_ref().map(|s| s.as_raw_descriptor()) {
+ if let Some(rd) = self.tube.as_ref().map(|s| s.as_raw_descriptor()) {
fds.push(rd);
}
@@ -240,7 +250,7 @@ impl VirtioDevice for Fs {
let server = Arc::new(Server::new(fs));
let irq = Arc::new(interrupt);
- let socket = self.socket.take().expect("missing mapping socket");
+ let socket = self.tube.take().expect("missing mapping socket");
let mut slot = 0;
// Set up shared memory for DAX.
diff --git a/devices/src/virtio/fs/passthrough.rs b/devices/src/virtio/fs/passthrough.rs
index 71415e101..dfeee627a 100644
--- a/devices/src/virtio/fs/passthrough.rs
+++ b/devices/src/virtio/fs/passthrough.rs
@@ -6,12 +6,11 @@ use std::borrow::Cow;
use std::cmp;
use std::collections::btree_map;
use std::collections::BTreeMap;
-use std::ffi::{c_void, CStr, CString};
+use std::ffi::{CStr, CString};
use std::fs::File;
use std::io;
use std::mem::{self, size_of, MaybeUninit};
use std::os::raw::{c_int, c_long};
-use std::ptr;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
@@ -27,10 +26,11 @@ use fuse::filesystem::{
IoctlReply, ListxattrReply, OpenOptions, RemoveMappingOne, SetattrValid, ZeroCopyReader,
ZeroCopyWriter, ROOT_ID,
};
+use fuse::sys::WRITE_KILL_PRIV;
use fuse::Mapper;
-use rand_ish::SimpleRng;
use sync::Mutex;
+use crate::virtio::fs::caps::{Capability, Caps, Set as CapSet, Value as CapValue};
use crate::virtio::fs::multikey::MultikeyBTreeMap;
use crate::virtio::fs::read_dir::ReadDir;
@@ -254,6 +254,62 @@ fn set_creds(
ScopedGid::new(gid, oldgid).and_then(|gid| Ok((ScopedUid::new(uid, olduid)?, gid)))
}
+struct ScopedUmask<'a> {
+ old: libc::mode_t,
+ mask: libc::mode_t,
+ _factory: &'a mut Umask,
+}
+
+impl<'a> Drop for ScopedUmask<'a> {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and always succeeds.
+ let previous = unsafe { libc::umask(self.old) };
+ debug_assert_eq!(
+ previous, self.mask,
+ "umask changed while holding ScopedUmask"
+ );
+ }
+}
+
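+// Token for changing the process-wide umask: `set` installs a new umask and returns a
+// `ScopedUmask` guard that restores the previous value when dropped. `PassthroughFs` keeps it
+// behind a `Mutex` so only one thread changes the umask at a time.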
+struct Umask;
+
+impl Umask {
+ fn set(&mut self, mask: libc::mode_t) -> ScopedUmask {
+ ScopedUmask {
+ // Safe because this doesn't modify any memory and always succeeds.
+ old: unsafe { libc::umask(mask) },
+ mask,
+ _factory: self,
+ }
+ }
+}
+
+struct ScopedFsetid(Caps);
+impl Drop for ScopedFsetid {
+ fn drop(&mut self) {
+ if let Err(e) = raise_cap_fsetid(&mut self.0) {
+ error!(
+ "Failed to restore CAP_FSETID: {}. Some operations may be broken.",
+ e
+ )
+ }
+ }
+}
+
+fn raise_cap_fsetid(c: &mut Caps) -> io::Result<()> {
+ c.update(&[Capability::Fsetid], CapSet::Effective, CapValue::Set)?;
+ c.apply()
+}
+
+// Drops CAP_FSETID from the effective set for the current thread and returns an RAII guard that
+// adds the capability back when it is dropped.
+fn drop_cap_fsetid() -> io::Result<ScopedFsetid> {
+ let mut caps = Caps::for_current_thread()?;
+ caps.update(&[Capability::Fsetid], CapSet::Effective, CapValue::Clear)?;
+ caps.apply()?;
+ Ok(ScopedFsetid(caps))
+}
+
fn ebadf() -> io::Error {
io::Error::from_raw_os_error(libc::EBADF)
}
@@ -442,6 +498,10 @@ pub struct PassthroughFs {
// process-wide CWD, we cannot allow more than one thread to do it at the same time.
chdir_mutex: Mutex<()>,
+ // Used when creating files / directories / nodes. Since the umask is process-wide, we can only
+ // allow one thread at a time to change it.
+ umask: Mutex<Umask>,
+
cfg: Config,
}
@@ -479,6 +539,7 @@ impl PassthroughFs {
zero_message_opendir: AtomicBool::new(false),
chdir_mutex: Mutex::new(()),
+ umask: Mutex::new(Umask),
cfg,
})
}
@@ -570,7 +631,7 @@ impl PassthroughFs {
}
// Creates a new entry for `f` or increases the refcount of the existing entry for `f`.
- fn add_entry(&self, f: File, st: libc::stat64, open_flags: libc::c_int) -> io::Result<Entry> {
+ fn add_entry(&self, f: File, st: libc::stat64, open_flags: libc::c_int) -> Entry {
let altkey = InodeAltKey {
ino: st.st_ino,
dev: st.st_dev,
@@ -603,13 +664,13 @@ impl PassthroughFs {
inode
};
- Ok(Entry {
+ Entry {
inode,
generation: 0,
attr: st,
attr_timeout: self.cfg.attr_timeout,
entry_timeout: self.cfg.entry_timeout,
- })
+ }
}
// Performs an ascii case insensitive lookup.
@@ -651,7 +712,7 @@ impl PassthroughFs {
// Safe because we just opened this fd.
let f = unsafe { File::from_raw_descriptor(fd) };
- self.add_entry(f, st, flags)
+ Ok(self.add_entry(f, st, flags))
}
fn do_open(&self, inode: Inode, flags: u32) -> io::Result<(Option<Handle>, OpenOptions)> {
@@ -684,68 +745,6 @@ impl PassthroughFs {
Ok((Some(handle), opts))
}
- fn do_tmpfile(
- &self,
- ctx: &Context,
- dir: &InodeData,
- flags: u32,
- mut mode: u32,
- umask: u32,
- ) -> io::Result<(File, libc::c_int)> {
- // We don't want to use `O_EXCL` with `O_TMPFILE` as it has a different meaning when used in
- // that combination.
- let mut tmpflags = (flags as i32 | libc::O_TMPFILE | libc::O_CLOEXEC | libc::O_NOFOLLOW)
- & !(libc::O_EXCL | libc::O_CREAT);
-
- // O_TMPFILE requires that we use O_RDWR or O_WRONLY.
- if flags as i32 & libc::O_ACCMODE == libc::O_RDONLY {
- tmpflags &= !libc::O_ACCMODE;
- tmpflags |= libc::O_RDWR;
- }
-
- // The presence of a default posix acl xattr in the parent directory completely changes the
- // meaning of the mode parameter so only apply the umask if it doesn't have one.
- if !self.has_default_posix_acl(&dir)? {
- mode &= !umask;
- }
-
- // Safe because this is a valid c string.
- let current_dir = unsafe { CStr::from_bytes_with_nul_unchecked(b".\0") };
-
- // Safe because this doesn't modify any memory and we check the return value.
- let fd = unsafe {
- libc::openat(
- dir.as_raw_descriptor(),
- current_dir.as_ptr(),
- tmpflags,
- mode,
- )
- };
- if fd < 0 {
- return Err(io::Error::last_os_error());
- }
-
- // Safe because we just opened this fd.
- let tmpfile = unsafe { File::from_raw_descriptor(fd) };
-
- // We need to respect the setgid bit in the parent directory if it is set.
- let st = stat(dir)?;
- let gid = if st.st_mode & libc::S_ISGID != 0 {
- st.st_gid
- } else {
- ctx.gid
- };
-
- // Now set the uid and gid for the file. Safe because this doesn't modify any memory and we
- // check the return value.
- let ret = unsafe { libc::fchown(tmpfile.as_raw_descriptor(), ctx.uid, gid) };
- if ret < 0 {
- return Err(io::Error::last_os_error());
- }
-
- Ok((tmpfile, tmpflags))
- }
-
fn do_release(&self, inode: Inode, handle: Handle) -> io::Result<()> {
let mut handles = self.handles.lock();
@@ -868,21 +867,6 @@ impl PassthroughFs {
}
}
- // Checks whether `inode` has a default posix acl xattr.
- fn has_default_posix_acl(&self, inode: &InodeData) -> io::Result<bool> {
- // Safe because this is a valid c string with no interior nul-bytes.
- let acl = unsafe { CStr::from_bytes_with_nul_unchecked(b"system.posix_acl_default\0") };
-
- if let Err(e) = self.do_getxattr(inode, acl, &mut []) {
- match e.raw_os_error() {
- Some(libc::ENODATA) | Some(libc::EOPNOTSUPP) => Ok(false),
- _ => Err(e),
- }
- } else {
- Ok(true)
- }
- }
-
fn get_encryption_policy_ex<R: io::Read>(
&self,
inode: Inode,
@@ -1015,8 +999,8 @@ fn forget_one(
// Synchronizes with the acquire load in `do_lookup`.
if data
.refcount
- .compare_and_swap(refcount, new_count, Ordering::Release)
- == refcount
+ .compare_exchange_weak(refcount, new_count, Ordering::Release, Ordering::Relaxed)
+ .is_ok()
{
if new_count == 0 {
// We just removed the last refcount for this inode. There's no need for an
@@ -1062,133 +1046,6 @@ fn strip_xattr_prefix(buf: &mut Vec<u8>) {
}
}
-// Like mkdtemp but also takes a mode parameter rather than always using 0o700. This is needed
-// because if the parent has a default posix acl set then the meaning of the mode parameter in the
-// mkdir call completely changes: the actual mode is inherited from the default acls set in the
-// parent and the mode is treated like a umask (the real umask is ignored in this case).
-// Additionally, this only happens when the inode is first created and not on subsequent fchmod
-// calls so we really need to use the requested mode from the very beginning and not the default
-// 0o700 mode that mkdtemp uses.
-fn create_temp_dir<D: AsRawDescriptor>(parent: &D, mode: libc::mode_t) -> io::Result<CString> {
- const MAX_ATTEMPTS: usize = 64;
- let mut seed = 0u64.to_ne_bytes();
- // Safe because this will only modify `seed` and we check the return value.
- let ret = unsafe {
- libc::syscall(
- libc::SYS_getrandom,
- seed.as_mut_ptr() as *mut c_void,
- seed.len(),
- 0,
- )
- };
- if ret < 0 {
- return Err(io::Error::last_os_error());
- }
-
- let mut rng = SimpleRng::new(u64::from_ne_bytes(seed));
-
- // Set an upper bound so that we don't end up spinning here forever.
- for _ in 0..MAX_ATTEMPTS {
- let mut name = String::from(".");
- name.push_str(&rng.str(6));
- let name = CString::new(name).expect("SimpleRng produced string with nul-bytes");
-
- // Safe because this doesn't modify any memory and we check the return value.
- let ret = unsafe { libc::mkdirat(parent.as_raw_descriptor(), name.as_ptr(), mode) };
- if ret == 0 {
- return Ok(name);
- }
-
- let e = io::Error::last_os_error();
- if let Some(libc::EEXIST) = e.raw_os_error() {
- continue;
- } else {
- return Err(e);
- }
- }
-
- Err(io::Error::from_raw_os_error(libc::EAGAIN))
-}
-
-// A temporary directory that is automatically deleted when dropped unless `into_inner()` is called.
-// This isn't a general-purpose temporary directory and is only intended to be used to ensure that
-// there are no leaks when initializing a newly created directory with the correct metadata (see the
-// implementation of `mkdir()` below). The directory is removed via a call to `unlinkat` so callers
-// are not allowed to actually populate this temporary directory with any entries (or else deleting
-// the directory will fail).
-struct TempDir<'a, D: AsRawDescriptor> {
- parent: &'a D,
- name: CString,
- file: File,
-}
-
-impl<'a, D: AsRawDescriptor> TempDir<'a, D> {
- // Creates a new temporary directory in `parent` with a randomly generated name. `parent` must
- // be a directory.
- fn new(parent: &'a D, mode: libc::mode_t) -> io::Result<Self> {
- let name = create_temp_dir(parent, mode)?;
-
- // Safe because this doesn't modify any memory and we check the return value.
- let raw_descriptor = unsafe {
- libc::openat(
- parent.as_raw_descriptor(),
- name.as_ptr(),
- libc::O_DIRECTORY | libc::O_CLOEXEC,
- )
- };
- if raw_descriptor < 0 {
- return Err(io::Error::last_os_error());
- }
-
- Ok(TempDir {
- parent,
- name,
- // Safe because we just opened this descriptor.
- file: unsafe { File::from_raw_descriptor(raw_descriptor) },
- })
- }
-
- fn basename(&self) -> &CStr {
- &self.name
- }
-
- // Consumes the `TempDir`, returning the inner `File` without deleting the temporary
- // directory.
- fn into_inner(self) -> (CString, File) {
- // Safe because this is a valid pointer and we are going to call `mem::forget` on `self` so
- // we will not be aliasing memory.
- let _parent = unsafe { ptr::read(&self.parent) };
- let name = unsafe { ptr::read(&self.name) };
- let file = unsafe { ptr::read(&self.file) };
- mem::forget(self);
-
- (name, file)
- }
-}
-
-impl<'a, D: AsRawDescriptor> AsRawDescriptor for TempDir<'a, D> {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.file.as_raw_descriptor()
- }
-}
-
-impl<'a, D: AsRawDescriptor> Drop for TempDir<'a, D> {
- fn drop(&mut self) {
- // Safe because this doesn't modify any memory and we check the return value.
- let ret = unsafe {
- libc::unlinkat(
- self.parent.as_raw_descriptor(),
- self.name.as_ptr(),
- libc::AT_REMOVEDIR,
- )
- };
- if ret < 0 {
- println!("Failed to remove tempdir: {}", io::Error::last_os_error());
- error!("Failed to remove tempdir: {}", io::Error::last_os_error());
- }
- }
-}
-
impl FileSystem for PassthroughFs {
type Inode = Inode;
type Handle = Handle;
@@ -1331,65 +1188,24 @@ impl FileSystem for PassthroughFs {
ctx: Context,
parent: Inode,
name: &CStr,
- mut mode: u32,
+ mode: u32,
umask: u32,
) -> io::Result<Entry> {
- // This method has the same issues as `create()`: namely that the kernel may have allowed a
- // process to make a directory due to one of its supplementary groups but that information
- // is not forwarded to us. However, there is no `O_TMPDIR` equivalent for directories so
- // instead we create a "hidden" directory with a randomly generated name in the parent
- // directory, modify the uid/gid and mode to the proper values, and then rename it to the
- // requested name. This ensures that even in the case of a power loss the directory is not
- // visible in the filesystem with the requested name but incorrect metadata. The only thing
- // left would be a empty hidden directory with a random name.
let data = self.find_inode(parent)?;
- // The presence of a default posix acl xattr in the parent directory completely changes the
- // meaning of the mode parameter so only apply the umask if it doesn't have one.
- if !self.has_default_posix_acl(&data)? {
- mode &= !umask;
- }
-
- let tmpdir = TempDir::new(&*data, mode)?;
-
- // We need to respect the setgid bit in the parent directory if it is set.
- let st = stat(&data.file.lock().0)?;
- let gid = if st.st_mode & libc::S_ISGID != 0 {
- st.st_gid
- } else {
- ctx.gid
- };
-
- // Set the uid and gid for the directory. Safe because this doesn't modify any memory and we
- // check the return value.
- let ret = unsafe { libc::fchown(tmpdir.as_raw_descriptor(), ctx.uid, gid) };
- if ret < 0 {
- return Err(io::Error::last_os_error());
- }
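+        // Create the directory with the caller's credentials: `set_creds` switches this thread's
+        // uid/gid for the duration of the call, and the scoped umask below makes `mkdirat` honor
+        // the caller's umask.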
+ let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
+ let res = {
+ let mut um = self.umask.lock();
+ let _scoped_umask = um.set(umask);
- // Now rename it into place. Safe because this doesn't modify any memory and we check the
- // return value. TODO: Switch to libc::renameat2 once
- // https://github.com/rust-lang/libc/pull/1508 lands and we have glibc 2.28.
- let ret = unsafe {
- libc::syscall(
- libc::SYS_renameat2,
- data.as_raw_descriptor(),
- tmpdir.basename().as_ptr(),
- data.as_raw_descriptor(),
- name.as_ptr(),
- libc::RENAME_NOREPLACE,
- )
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe { libc::mkdirat(data.as_raw_descriptor(), name.as_ptr(), mode) }
};
- if ret < 0 {
- return Err(io::Error::last_os_error());
+ if res == 0 {
+ self.do_lookup(&data, name)
+ } else {
+ Err(io::Error::last_os_error())
}
-
- // Now that we've moved the directory make sure we don't try to delete the now non-existent
- // `tmpdir`.
- let (_, dir) = tmpdir.into_inner();
-
- let st = stat(&dir)?;
- self.add_entry(dir, st, libc::O_DIRECTORY | libc::O_CLOEXEC)
}
fn rmdir(&self, _ctx: Context, parent: Inode, name: &CStr) -> io::Result<()> {
@@ -1458,10 +1274,36 @@ impl FileSystem for PassthroughFs {
) -> io::Result<Entry> {
let data = self.find_inode(parent)?;
- let (tmpfile, flags) = self.do_tmpfile(&ctx, &data, 0, mode, umask)?;
+ let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
+
+ let tmpflags = libc::O_RDWR | libc::O_TMPFILE | libc::O_CLOEXEC | libc::O_NOFOLLOW;
+
+ // Safe because this is a valid c string.
+ let current_dir = unsafe { CStr::from_bytes_with_nul_unchecked(b".\0") };
+
+ let fd = {
+ let mut um = self.umask.lock();
+ let _scoped_umask = um.set(umask);
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::openat(
+ data.as_raw_descriptor(),
+ current_dir.as_ptr(),
+ tmpflags,
+ mode,
+ )
+ }
+ };
+ if fd < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ // Safe because we just opened this fd.
+ let tmpfile = unsafe { File::from_raw_descriptor(fd) };
let st = stat(&tmpfile)?;
- self.add_entry(tmpfile, st, flags)
+ Ok(self.add_entry(tmpfile, st, tmpflags))
}
fn create(
@@ -1473,40 +1315,31 @@ impl FileSystem for PassthroughFs {
flags: u32,
umask: u32,
) -> io::Result<(Entry, Option<Handle>, OpenOptions)> {
- // The `Context` may not contain all the information we need to create the file here. For
- // example, a process may be part of several groups, one of which gives it permission to
- // create a file in `parent`, but is not the gid of the process. This information is not
- // forwarded to the server so we don't know when this is happening. Instead, we just rely on
- // the access checks in the kernel driver: if we received this request then the kernel has
- // determined that the process is allowed to create the file and we shouldn't reject it now
- // based on acls.
- //
- // To ensure that the file is created atomically with the proper uid/gid we use `O_TMPFILE`
- // + `linkat` as described in the `open(2)` manpage.
let data = self.find_inode(parent)?;
- let (tmpfile, tmpflags) = self.do_tmpfile(&ctx, &data, flags, mode, umask)?;
+ let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
- let proc_path = CString::new(format!("self/fd/{}", tmpfile.as_raw_descriptor()))
- .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
+ let create_flags =
+ (flags as i32 | libc::O_CREAT | libc::O_CLOEXEC | libc::O_NOFOLLOW) & !libc::O_DIRECT;
- // Finally link it into the file system tree so that it's visible to other processes. Safe
- // because this doesn't modify any memory and we check the return value.
- let ret = unsafe {
- libc::linkat(
- self.proc.as_raw_descriptor(),
- proc_path.as_ptr(),
- data.as_raw_descriptor(),
- name.as_ptr(),
- libc::AT_SYMLINK_FOLLOW,
- )
+ let fd = {
+ let mut um = self.umask.lock();
+ let _scoped_umask = um.set(umask);
+
+ // Safe because this doesn't modify any memory and we check the return value. We don't
+ // really check `flags` because if the kernel can't handle poorly specified flags then
+ // we have much bigger problems.
+ unsafe { libc::openat(data.as_raw_descriptor(), name.as_ptr(), create_flags, mode) }
};
- if ret < 0 {
+ if fd < 0 {
return Err(io::Error::last_os_error());
}
- let st = stat(&tmpfile)?;
- let entry = self.add_entry(tmpfile, st, tmpflags)?;
+ // Safe because we just opened this fd.
+ let file = unsafe { File::from_raw_descriptor(fd) };
+
+ let st = stat(&file)?;
+ let entry = self.add_entry(file, st, create_flags);
let (handle, opts) = if self.zero_message_open.load(Ordering::Relaxed) {
(None, OpenOptions::KEEP_CACHE)
@@ -1570,7 +1403,7 @@ impl FileSystem for PassthroughFs {
fn write<R: io::Read + ZeroCopyReader>(
&self,
- ctx: Context,
+ _ctx: Context,
inode: Inode,
handle: Handle,
mut r: R,
@@ -1578,11 +1411,15 @@ impl FileSystem for PassthroughFs {
offset: u64,
_lock_owner: Option<u64>,
_delayed_write: bool,
- _flags: u32,
+ flags: u32,
) -> io::Result<usize> {
- // We need to change credentials during a write so that the kernel will remove setuid or
- // setgid bits from the file if it was written to by someone other than the owner.
- let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
+ // When the WRITE_KILL_PRIV flag is set, drop CAP_FSETID so that the kernel will
+ // automatically clear the setuid and setgid bits for us.
+ let _fsetid = if flags & WRITE_KILL_PRIV != 0 {
+ Some(drop_cap_fsetid()?)
+ } else {
+ None
+ };
if self.zero_message_open.load(Ordering::Relaxed) {
let data = self.find_inode(inode)?;
@@ -1788,28 +1625,27 @@ impl FileSystem for PassthroughFs {
ctx: Context,
parent: Inode,
name: &CStr,
- mut mode: u32,
+ mode: u32,
rdev: u32,
umask: u32,
) -> io::Result<Entry> {
- let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
-
let data = self.find_inode(parent)?;
- // The presence of a default posix acl xattr in the parent directory completely changes the
- // meaning of the mode parameter so only apply the umask if it doesn't have one.
- if !self.has_default_posix_acl(&data)? {
- mode &= !umask;
- }
+ let (_uid, _gid) = set_creds(ctx.uid, ctx.gid)?;
- // Safe because this doesn't modify any memory and we check the return value.
- let res = unsafe {
- libc::mknodat(
- data.as_raw_descriptor(),
- name.as_ptr(),
- mode as libc::mode_t,
- rdev as libc::dev_t,
- )
+ let res = {
+ let mut um = self.umask.lock();
+ let _scoped_umask = um.set(umask);
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::mknodat(
+ data.as_raw_descriptor(),
+ name.as_ptr(),
+ mode as libc::mode_t,
+ rdev as libc::dev_t,
+ )
+ }
};
if res < 0 {
@@ -1912,7 +1748,7 @@ impl FileSystem for PassthroughFs {
// behavior by doing the same thing (dup-ing the fd and then immediately closing it). Safe
// because this doesn't modify any memory and we check the return values.
unsafe {
- let newfd = libc::dup(data.as_raw_descriptor());
+ let newfd = libc::fcntl(data.as_raw_descriptor(), libc::F_DUPFD_CLOEXEC, 0);
if newfd < 0 {
return Err(io::Error::last_os_error());
@@ -2380,72 +2216,6 @@ impl FileSystem for PassthroughFs {
mod tests {
use super::*;
- use std::env;
- use std::os::unix::ffi::OsStringExt;
-
- #[test]
- fn create_temp_dir() {
- let testdir = CString::new(env::temp_dir().into_os_string().into_vec())
- .expect("env::temp_dir() is not a valid c-string");
- let fd = unsafe {
- libc::openat(
- libc::AT_FDCWD,
- testdir.as_ptr(),
- libc::O_PATH | libc::O_CLOEXEC,
- )
- };
- assert!(fd >= 0, "Failed to open env::temp_dir()");
- let parent = unsafe { File::from_raw_descriptor(fd) };
- let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory");
-
- let basename = t.basename().to_string_lossy();
- let path = env::temp_dir().join(&*basename);
- assert!(path.exists());
- assert!(path.is_dir());
- }
-
- #[test]
- fn remove_temp_dir() {
- let testdir = CString::new(env::temp_dir().into_os_string().into_vec())
- .expect("env::temp_dir() is not a valid c-string");
- let fd = unsafe {
- libc::openat(
- libc::AT_FDCWD,
- testdir.as_ptr(),
- libc::O_PATH | libc::O_CLOEXEC,
- )
- };
- assert!(fd >= 0, "Failed to open env::temp_dir()");
- let parent = unsafe { File::from_raw_descriptor(fd) };
- let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory");
-
- let basename = t.basename().to_string_lossy();
- let path = env::temp_dir().join(&*basename);
- mem::drop(t);
- assert!(!path.exists());
- }
-
- #[test]
- fn temp_dir_into_inner() {
- let testdir = CString::new(env::temp_dir().into_os_string().into_vec())
- .expect("env::temp_dir() is not a valid c-string");
- let fd = unsafe {
- libc::openat(
- libc::AT_FDCWD,
- testdir.as_ptr(),
- libc::O_PATH | libc::O_CLOEXEC,
- )
- };
- assert!(fd >= 0, "Failed to open env::temp_dir()");
- let parent = unsafe { File::from_raw_descriptor(fd) };
- let t = TempDir::new(&parent, 0o755).expect("Failed to create temporary directory");
-
- let (basename_cstr, _) = t.into_inner();
- let basename = basename_cstr.to_string_lossy();
- let path = env::temp_dir().join(&*basename);
- assert!(path.exists());
- }
-
#[test]
fn rewrite_xattr_names() {
let cfg = Config {
diff --git a/devices/src/virtio/fs/worker.rs b/devices/src/virtio/fs/worker.rs
index ded93b2c7..65a47a9c6 100644
--- a/devices/src/virtio/fs/worker.rs
+++ b/devices/src/virtio/fs/worker.rs
@@ -2,20 +2,19 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use std::convert::TryInto;
+use std::convert::{TryFrom, TryInto};
use std::fs::File;
use std::io;
use std::os::unix::io::AsRawFd;
use std::sync::{Arc, Mutex};
-use base::{error, Event, PollToken, WaitContext};
+use base::{error, Event, PollToken, SafeDescriptor, Tube, WaitContext};
use fuse::filesystem::{FileSystem, ZeroCopyReader, ZeroCopyWriter};
-use msg_socket::{MsgReceiver, MsgSender};
-use vm_control::{FsMappingRequest, FsMappingRequestSocket, MaybeOwnedDescriptor, VmResponse};
+use vm_control::{FsMappingRequest, VmResponse};
use vm_memory::GuestMemory;
use crate::virtio::fs::{Error, Result};
-use crate::virtio::{Interrupt, Queue, Reader, Writer};
+use crate::virtio::{Interrupt, Queue, Reader, SignalableInterrupt, Writer};
impl fuse::Reader for Reader {}
@@ -46,27 +45,27 @@ impl ZeroCopyWriter for Writer {
}
struct Mapper {
- socket: Arc<Mutex<FsMappingRequestSocket>>,
+ tube: Arc<Mutex<Tube>>,
slot: u32,
}
impl Mapper {
- fn new(socket: Arc<Mutex<FsMappingRequestSocket>>, slot: u32) -> Self {
- Self { socket, slot }
+ fn new(tube: Arc<Mutex<Tube>>, slot: u32) -> Self {
+ Self { tube, slot }
}
fn process_request(&self, request: &FsMappingRequest) -> io::Result<()> {
- let socket = self.socket.lock().map_err(|e| {
- error!("failed to lock socket: {}", e);
+ let tube = self.tube.lock().map_err(|e| {
+ error!("failed to lock tube: {}", e);
io::Error::from_raw_os_error(libc::EINVAL)
})?;
- socket.send(request).map_err(|e| {
+ tube.send(request).map_err(|e| {
error!("failed to send request {:?}: {}", request, e);
io::Error::from_raw_os_error(libc::EINVAL)
})?;
- match socket.recv() {
+ match tube.recv() {
Ok(VmResponse::Ok) => Ok(()),
Ok(VmResponse::Err(e)) => Err(e.into()),
r => {
@@ -91,9 +90,11 @@ impl fuse::Mapper for Mapper {
io::Error::from_raw_os_error(libc::EINVAL)
})?;
+ let fd = SafeDescriptor::try_from(fd)?;
+
let request = FsMappingRequest::CreateMemoryMapping {
slot: self.slot,
- fd: MaybeOwnedDescriptor::Borrowed(fd.as_raw_fd()),
+ fd,
size,
file_offset,
prot,
@@ -128,7 +129,7 @@ pub struct Worker<F: FileSystem + Sync> {
queue: Queue,
server: Arc<fuse::Server<F>>,
irq: Arc<Interrupt>,
- socket: Arc<Mutex<FsMappingRequestSocket>>,
+ tube: Arc<Mutex<Tube>>,
slot: u32,
}
@@ -138,7 +139,7 @@ impl<F: FileSystem + Sync> Worker<F> {
queue: Queue,
server: Arc<fuse::Server<F>>,
irq: Arc<Interrupt>,
- socket: Arc<Mutex<FsMappingRequestSocket>>,
+ tube: Arc<Mutex<Tube>>,
slot: u32,
) -> Worker<F> {
Worker {
@@ -146,7 +147,7 @@ impl<F: FileSystem + Sync> Worker<F> {
queue,
server,
irq,
- socket,
+ tube,
slot,
}
}
@@ -154,7 +155,7 @@ impl<F: FileSystem + Sync> Worker<F> {
fn process_queue(&mut self) -> Result<()> {
let mut needs_interrupt = false;
- let mapper = Mapper::new(Arc::clone(&self.socket), self.slot);
+ let mapper = Mapper::new(Arc::clone(&self.tube), self.slot);
while let Some(avail_desc) = self.queue.pop(&self.mem) {
let reader = Reader::new(self.mem.clone(), avail_desc.clone())
.map_err(Error::InvalidDescriptorChain)?;
@@ -182,6 +183,29 @@ impl<F: FileSystem + Sync> Worker<F> {
kill_evt: Event,
watch_resample_event: bool,
) -> Result<()> {
+ // We need to set the no setuid fixup secure bit so that we don't drop capabilities when
+ // changing the thread uid/gid. Without this, creating new entries can fail in some corner
+ // cases.
+ const SECBIT_NO_SETUID_FIXUP: i32 = 1 << 2;
+
+ // TODO(crbug.com/1199487): Remove this once libc provides the wrapper for all targets.
+ #[cfg(target_os = "linux")]
+ {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let mut securebits = unsafe { libc::prctl(libc::PR_GET_SECUREBITS) };
+ if securebits < 0 {
+ return Err(Error::GetSecurebits(io::Error::last_os_error()));
+ }
+
+ securebits |= SECBIT_NO_SETUID_FIXUP;
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::prctl(libc::PR_SET_SECUREBITS, securebits) };
+ if ret < 0 {
+ return Err(Error::SetSecurebits(io::Error::last_os_error()));
+ }
+ }
+
#[derive(PollToken)]
enum Token {
// A request is ready on the queue.
@@ -197,9 +221,11 @@ impl<F: FileSystem + Sync> Worker<F> {
.map_err(Error::CreateWaitContext)?;
if watch_resample_event {
- wait_ctx
- .add(self.irq.get_resample_evt(), Token::InterruptResample)
- .map_err(Error::CreateWaitContext)?;
+ if let Some(resample_evt) = self.irq.get_resample_evt() {
+ wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .map_err(Error::CreateWaitContext)?;
+ }
}
loop {
diff --git a/devices/src/virtio/gpu/mod.rs b/devices/src/virtio/gpu/mod.rs
index 9646612e8..1d77e4ff0 100644
--- a/devices/src/virtio/gpu/mod.rs
+++ b/devices/src/virtio/gpu/mod.rs
@@ -3,6 +3,8 @@
// found in the LICENSE file.
mod protocol;
+mod udmabuf;
+mod udmabuf_bindings;
mod virtio_gpu;
use std::cell::RefCell;
@@ -19,22 +21,15 @@ use std::thread;
use std::time::Duration;
use base::{
- debug, error, warn, AsRawDescriptor, Event, ExternalMapping, PollToken, RawDescriptor,
- WaitContext,
+ debug, error, warn, AsRawDescriptor, AsRawDescriptors, Event, ExternalMapping, PollToken,
+ RawDescriptor, Tube, WaitContext,
};
use data_model::*;
pub use gpu_display::EventDevice;
use gpu_display::*;
-use rutabaga_gfx::{
- DrmFormat, GfxstreamFlags, ResourceCreate3D, ResourceCreateBlob, RutabagaBuilder,
- RutabagaChannel, RutabagaComponentType, RutabagaFenceData, Transfer3D, VirglRendererFlags,
- RUTABAGA_CHANNEL_TYPE_CAMERA, RUTABAGA_CHANNEL_TYPE_WAYLAND, RUTABAGA_PIPE_BIND_RENDER_TARGET,
- RUTABAGA_PIPE_TEXTURE_2D,
-};
-
-use msg_socket::{MsgReceiver, MsgSender};
+use rutabaga_gfx::*;
use resources::Alloc;
@@ -42,8 +37,8 @@ use sync::Mutex;
use vm_memory::{GuestAddress, GuestMemory};
use super::{
- copy_config, resource_bridge::*, DescriptorChain, Interrupt, Queue, Reader, VirtioDevice,
- Writer, TYPE_GPU,
+ copy_config, resource_bridge::*, DescriptorChain, Interrupt, Queue, Reader,
+ SignalableInterrupt, VirtioDevice, Writer, TYPE_GPU,
};
use super::{PciCapabilityType, VirtioPciShmCap};
@@ -55,15 +50,13 @@ use crate::pci::{
PciAddress, PciBarConfiguration, PciBarPrefetchable, PciBarRegionType, PciCapability,
};
-use vm_control::VmMemoryControlRequestSocket;
-
pub const DEFAULT_DISPLAY_WIDTH: u32 = 1280;
pub const DEFAULT_DISPLAY_HEIGHT: u32 = 1024;
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum GpuMode {
Mode2D,
- Mode3D,
+ ModeVirglRenderer,
ModeGfxstream,
}
@@ -77,7 +70,8 @@ pub struct GpuParameters {
pub renderer_use_surfaceless: bool,
pub gfxstream_use_guest_angle: bool,
pub gfxstream_use_syncfd: bool,
- pub gfxstream_support_vulkan: bool,
+ pub use_vulkan: bool,
+ pub udmabuf: bool,
pub mode: GpuMode,
pub cache_path: Option<String>,
pub cache_size: Option<String>,
@@ -103,10 +97,11 @@ impl Default for GpuParameters {
renderer_use_surfaceless: true,
gfxstream_use_guest_angle: false,
gfxstream_use_syncfd: true,
- gfxstream_support_vulkan: true,
- mode: GpuMode::Mode3D,
+ use_vulkan: false,
+ mode: GpuMode::ModeVirglRenderer,
cache_path: None,
cache_size: None,
+ udmabuf: false,
}
}
}
@@ -127,10 +122,11 @@ fn build(
display_height: u32,
rutabaga_builder: RutabagaBuilder,
event_devices: Vec<EventDevice>,
- gpu_device_socket: VmMemoryControlRequestSocket,
+ gpu_device_tube: Tube,
pci_bar: Alloc,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
external_blob: bool,
+ udmabuf: bool,
) -> Option<VirtioGpu> {
let mut display_opt = None;
for display in possible_displays {
@@ -157,10 +153,11 @@ fn build(
display_height,
rutabaga_builder,
event_devices,
- gpu_device_socket,
+ gpu_device_tube,
pci_bar,
map_request,
external_blob,
+ udmabuf,
)
}
@@ -228,7 +225,7 @@ impl Frontend {
self.virtio_gpu.process_display()
}
- fn process_resource_bridge(&mut self, resource_bridge: &ResourceResponseSocket) {
+ fn process_resource_bridge(&mut self, resource_bridge: &Tube) {
let response = match resource_bridge.recv() {
Ok(ResourceRequest::GetBuffer { id }) => self.virtio_gpu.export_resource(id),
Ok(ResourceRequest::GetFence { seqno }) => {
@@ -685,7 +682,7 @@ struct Worker {
ctrl_evt: Event,
cursor_queue: Queue,
cursor_evt: Event,
- resource_bridges: Vec<ResourceResponseSocket>,
+ resource_bridges: Vec<Tube>,
kill_evt: Event,
state: Frontend,
}
@@ -706,7 +703,6 @@ impl Worker {
(&self.ctrl_evt, Token::CtrlQueue),
(&self.cursor_evt, Token::CursorQueue),
(&*self.state.display().borrow(), Token::Display),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&self.kill_evt, Token::Kill),
]) {
Ok(pc) => pc,
@@ -715,6 +711,15 @@ impl Worker {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+                error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
for (index, bridge) in self.resource_bridges.iter().enumerate() {
if let Err(e) = wait_ctx.add(bridge, Token::ResourceBridge { index }) {
@@ -864,8 +869,8 @@ impl DisplayBackend {
pub struct Gpu {
exit_evt: Event,
- gpu_device_socket: Option<VmMemoryControlRequestSocket>,
- resource_bridges: Vec<ResourceResponseSocket>,
+ gpu_device_tube: Option<Tube>,
+ resource_bridges: Vec<Tube>,
event_devices: Vec<EventDevice>,
kill_evt: Option<Event>,
config_event: bool,
@@ -880,14 +885,16 @@ pub struct Gpu {
external_blob: bool,
rutabaga_component: RutabagaComponentType,
base_features: u64,
+ mem: GuestMemory,
+ udmabuf: bool,
}
impl Gpu {
pub fn new(
exit_evt: Event,
- gpu_device_socket: Option<VmMemoryControlRequestSocket>,
+ gpu_device_tube: Option<Tube>,
num_scanouts: NonZeroU8,
- resource_bridges: Vec<ResourceResponseSocket>,
+ resource_bridges: Vec<Tube>,
display_backends: Vec<DisplayBackend>,
gpu_parameters: &GpuParameters,
event_devices: Vec<EventDevice>,
@@ -895,13 +902,15 @@ impl Gpu {
external_blob: bool,
base_features: u64,
channels: BTreeMap<String, PathBuf>,
+ mem: GuestMemory,
) -> Gpu {
let virglrenderer_flags = VirglRendererFlags::new()
.use_egl(gpu_parameters.renderer_use_egl)
.use_gles(gpu_parameters.renderer_use_gles)
.use_glx(gpu_parameters.renderer_use_glx)
.use_surfaceless(gpu_parameters.renderer_use_surfaceless)
- .use_external_blob(external_blob);
+ .use_external_blob(external_blob)
+ .use_venus(gpu_parameters.use_vulkan);
let gfxstream_flags = GfxstreamFlags::new()
.use_egl(gpu_parameters.renderer_use_egl)
.use_gles(gpu_parameters.renderer_use_gles)
@@ -909,7 +918,7 @@ impl Gpu {
.use_surfaceless(gpu_parameters.renderer_use_surfaceless)
.use_guest_angle(gpu_parameters.gfxstream_use_guest_angle)
.use_syncfd(gpu_parameters.gfxstream_use_syncfd)
- .support_vulkan(gpu_parameters.gfxstream_support_vulkan);
+ .use_vulkan(gpu_parameters.use_vulkan);
let mut rutabaga_channels: Vec<RutabagaChannel> = Vec::new();
for (channel_name, path) in &channels {
@@ -929,7 +938,7 @@ impl Gpu {
let rutabaga_channels_opt = Some(rutabaga_channels);
let component = match gpu_parameters.mode {
GpuMode::Mode2D => RutabagaComponentType::Rutabaga2D,
- GpuMode::Mode3D => RutabagaComponentType::VirglRenderer,
+ GpuMode::ModeVirglRenderer => RutabagaComponentType::VirglRenderer,
GpuMode::ModeGfxstream => RutabagaComponentType::Gfxstream,
};
@@ -942,7 +951,7 @@ impl Gpu {
Gpu {
exit_evt,
- gpu_device_socket,
+ gpu_device_tube,
num_scanouts,
resource_bridges,
event_devices,
@@ -958,6 +967,8 @@ impl Gpu {
external_blob,
rutabaga_component: component,
base_features,
+ mem,
+ udmabuf: gpu_parameters.udmabuf,
}
}
@@ -970,8 +981,10 @@ impl Gpu {
let num_capsets = match self.rutabaga_component {
RutabagaComponentType::Rutabaga2D => 0,
_ => {
+ let mut num_capsets = 0;
+
// Cross-domain (like virtio_wl with llvmpipe) is always available.
- let mut num_capsets = 1;
+ num_capsets += 1;
// Three capsets for virgl_renderer
#[cfg(feature = "virgl_renderer")]
@@ -1021,14 +1034,19 @@ impl VirtioDevice for Gpu {
keep_rds.push(libc::STDERR_FILENO);
}
- if let Some(ref gpu_device_socket) = self.gpu_device_socket {
- keep_rds.push(gpu_device_socket.as_raw_descriptor());
+ if self.udmabuf {
+ keep_rds.append(&mut self.mem.as_raw_descriptors());
+ }
+
+ if let Some(ref gpu_device_tube) = self.gpu_device_tube {
+ keep_rds.push(gpu_device_tube.as_raw_descriptor());
}
keep_rds.push(self.exit_evt.as_raw_descriptor());
for bridge in &self.resource_bridges {
keep_rds.push(bridge.as_raw_descriptor());
}
+
keep_rds
}
@@ -1044,10 +1062,19 @@ impl VirtioDevice for Gpu {
let rutabaga_features = match self.rutabaga_component {
RutabagaComponentType::Rutabaga2D => 0,
_ => {
- 1 << VIRTIO_GPU_F_VIRGL
+ let mut features_3d = 0;
+
+ features_3d |= 1 << VIRTIO_GPU_F_VIRGL
| 1 << VIRTIO_GPU_F_RESOURCE_UUID
| 1 << VIRTIO_GPU_F_RESOURCE_BLOB
| 1 << VIRTIO_GPU_F_CONTEXT_INIT
+ | 1 << VIRTIO_GPU_F_RESOURCE_SYNC;
+
+ if self.udmabuf {
+ features_3d |= 1 << VIRTIO_GPU_F_CREATE_GUEST_HANDLE;
+ }
+
+ features_3d
}
};
@@ -1110,8 +1137,9 @@ impl VirtioDevice for Gpu {
let event_devices = self.event_devices.split_off(0);
let map_request = Arc::clone(&self.map_request);
let external_blob = self.external_blob;
- if let (Some(gpu_device_socket), Some(pci_bar), Some(rutabaga_builder)) = (
- self.gpu_device_socket.take(),
+ let udmabuf = self.udmabuf;
+ if let (Some(gpu_device_tube), Some(pci_bar), Some(rutabaga_builder)) = (
+ self.gpu_device_tube.take(),
self.pci_bar.take(),
self.rutabaga_builder.take(),
) {
@@ -1125,10 +1153,11 @@ impl VirtioDevice for Gpu {
display_height,
rutabaga_builder,
event_devices,
- gpu_device_socket,
+ gpu_device_tube,
pci_bar,
map_request,
external_blob,
+ udmabuf,
) {
Some(backend) => backend,
None => return,
diff --git a/devices/src/virtio/gpu/protocol.rs b/devices/src/virtio/gpu/protocol.rs
index 77ab43a64..c2f9b3eb8 100644
--- a/devices/src/virtio/gpu/protocol.rs
+++ b/devices/src/virtio/gpu/protocol.rs
@@ -16,18 +16,21 @@ use std::str::from_utf8;
use super::super::DescriptorError;
use super::{Reader, Writer};
use base::Error as SysError;
-use base::ExternalMappingError;
+use base::{ExternalMappingError, TubeError};
use data_model::{DataInit, Le32, Le64};
use gpu_display::GpuDisplayError;
-use msg_socket::MsgError;
use rutabaga_gfx::RutabagaError;
+use crate::virtio::gpu::udmabuf::UdmabufError;
+
pub const VIRTIO_GPU_F_VIRGL: u32 = 0;
pub const VIRTIO_GPU_F_EDID: u32 = 1;
pub const VIRTIO_GPU_F_RESOURCE_UUID: u32 = 2;
pub const VIRTIO_GPU_F_RESOURCE_BLOB: u32 = 3;
/* The following capabilities are not upstreamed. */
pub const VIRTIO_GPU_F_CONTEXT_INIT: u32 = 4;
+pub const VIRTIO_GPU_F_CREATE_GUEST_HANDLE: u32 = 5;
+pub const VIRTIO_GPU_F_RESOURCE_SYNC: u32 = 6;
pub const VIRTIO_GPU_UNDEFINED: u32 = 0x0;
@@ -88,6 +91,8 @@ pub const VIRTIO_GPU_BLOB_MEM_HOST3D_GUEST: u32 = 0x0003;
pub const VIRTIO_GPU_BLOB_FLAG_USE_MAPPABLE: u32 = 0x0001;
pub const VIRTIO_GPU_BLOB_FLAG_USE_SHAREABLE: u32 = 0x0002;
pub const VIRTIO_GPU_BLOB_FLAG_USE_CROSS_DEVICE: u32 = 0x0004;
+/* Create an OS-specific handle from guest memory (not upstreamed). */
+pub const VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE: u32 = 0x0008;
pub const VIRTIO_GPU_SHM_ID_NONE: u8 = 0x0000;
pub const VIRTIO_GPU_SHM_ID_HOST_VISIBLE: u8 = 0x0001;
@@ -800,7 +805,7 @@ pub enum GpuResponse {
map_info: u32,
},
ErrUnspec,
- ErrMsg(MsgError),
+ ErrMsg(TubeError),
ErrSys(SysError),
ErrRutabaga(RutabagaError),
ErrDisplay(GpuDisplayError),
@@ -813,10 +818,11 @@ pub enum GpuResponse {
ErrInvalidResourceId,
ErrInvalidContextId,
ErrInvalidParameter,
+ ErrUdmabuf(UdmabufError),
}
-impl From<MsgError> for GpuResponse {
- fn from(e: MsgError) -> GpuResponse {
+impl From<TubeError> for GpuResponse {
+ fn from(e: TubeError) -> GpuResponse {
GpuResponse::ErrMsg(e)
}
}
@@ -839,6 +845,12 @@ impl From<ExternalMappingError> for GpuResponse {
}
}
+impl From<UdmabufError> for GpuResponse {
+ fn from(e: UdmabufError) -> GpuResponse {
+ GpuResponse::ErrUdmabuf(e)
+ }
+}
+
impl Display for GpuResponse {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
use self::GpuResponse::*;
@@ -848,6 +860,7 @@ impl Display for GpuResponse {
ErrRutabaga(e) => write!(f, "renderer error: {}", e),
ErrDisplay(e) => write!(f, "display error: {}", e),
ErrScanout { num_scanouts } => write!(f, "non-zero scanout: {}", num_scanouts),
+ ErrUdmabuf(e) => write!(f, "udmabuf error: {}", e),
_ => Ok(()),
}
}
@@ -1024,6 +1037,7 @@ impl GpuResponse {
GpuResponse::ErrRutabaga(_) => VIRTIO_GPU_RESP_ERR_UNSPEC,
GpuResponse::ErrDisplay(_) => VIRTIO_GPU_RESP_ERR_UNSPEC,
GpuResponse::ErrMapping(_) => VIRTIO_GPU_RESP_ERR_UNSPEC,
+ GpuResponse::ErrUdmabuf(_) => VIRTIO_GPU_RESP_ERR_UNSPEC,
GpuResponse::ErrScanout { num_scanouts: _ } => VIRTIO_GPU_RESP_ERR_UNSPEC,
GpuResponse::ErrOutOfMemory => VIRTIO_GPU_RESP_ERR_OUT_OF_MEMORY,
GpuResponse::ErrInvalidScanoutId => VIRTIO_GPU_RESP_ERR_INVALID_SCANOUT_ID,
diff --git a/devices/src/virtio/gpu/udmabuf.rs b/devices/src/virtio/gpu/udmabuf.rs
new file mode 100644
index 000000000..812080bcc
--- /dev/null
+++ b/devices/src/virtio/gpu/udmabuf.rs
@@ -0,0 +1,266 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+#![allow(dead_code)]
+
+use std::fs::{File, OpenOptions};
+use std::os::raw::c_uint;
+
+use std::path::Path;
+use std::{fmt, io};
+
+use base::{
+ ioctl_iow_nr, ioctl_with_ptr, pagesize, AsRawDescriptor, FromRawDescriptor, MappedRegion,
+ SafeDescriptor,
+};
+
+use data_model::{FlexibleArray, FlexibleArrayWrapper};
+
+use rutabaga_gfx::{RutabagaHandle, RUTABAGA_MEM_HANDLE_TYPE_DMABUF};
+
+use super::udmabuf_bindings::*;
+
+use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError};
+
+const UDMABUF_IOCTL_BASE: c_uint = 0x75;
+
+ioctl_iow_nr!(UDMABUF_CREATE, UDMABUF_IOCTL_BASE, 0x42, udmabuf_create);
+ioctl_iow_nr!(
+ UDMABUF_CREATE_LIST,
+ UDMABUF_IOCTL_BASE,
+ 0x43,
+ udmabuf_create_list
+);
+
+// It's possible to make the flexible array trait implementation a macro one day...
+impl FlexibleArray<udmabuf_create_item> for udmabuf_create_list {
+ fn set_len(&mut self, len: usize) {
+ self.count = len as u32;
+ }
+
+ fn get_len(&self) -> usize {
+ self.count as usize
+ }
+
+ fn get_slice(&self, len: usize) -> &[udmabuf_create_item] {
+ unsafe { self.list.as_slice(len) }
+ }
+
+ fn get_mut_slice(&mut self, len: usize) -> &mut [udmabuf_create_item] {
+ unsafe { self.list.as_mut_slice(len) }
+ }
+}
+
+type UdmabufCreateList = FlexibleArrayWrapper<udmabuf_create_list, udmabuf_create_item>;
+
+#[derive(Debug)]
+pub enum UdmabufError {
+ DriverOpenFailed(io::Error),
+ NotPageAligned,
+ InvalidOffset(GuestMemoryError),
+ DmabufCreationFail(io::Error),
+}
+
+impl fmt::Display for UdmabufError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::UdmabufError::*;
+ match self {
+ DriverOpenFailed(e) => write!(f, "failed to open udmabuf driver: {:?}", e),
+            NotPageAligned => write!(f, "all guest addresses must be aligned to the page size"),
+ InvalidOffset(e) => write!(f, "failed to get region offset: {:?}", e),
+ DmabufCreationFail(e) => write!(f, "failed to create buffer: {:?}", e),
+ }
+ }
+}
+
+/// The result of an operation in this file.
+pub type UdmabufResult<T> = std::result::Result<T, UdmabufError>;
+
+// Returns the absolute offset within guest memory's backing shared memory region that
+// corresponds to a particular guest address. This offset is not relative to a particular
+// mapping.
+
+// # Examples
+//
+// # fn test_memory_offsets() {
+// #     let start_addr1 = GuestAddress(0x100);
+// #     let start_addr2 = GuestAddress(0x1100);
+// #     let mem = GuestMemory::new(&vec![(start_addr1, 0x1000), (start_addr2, 0x1000)])?;
+// #     assert_eq!(memory_offset(&mem, GuestAddress(0x1100), 0x1000).unwrap(), 0x1000);
+// # }
+fn memory_offset(mem: &GuestMemory, guest_addr: GuestAddress, len: u64) -> UdmabufResult<u64> {
+ mem.do_in_region(guest_addr, move |mapping, map_offset, memfd_offset| {
+ let map_offset = map_offset as u64;
+ if map_offset
+ .checked_add(len)
+ .map_or(true, |a| a > mapping.size() as u64)
+ {
+ return Err(GuestMemoryError::InvalidGuestAddress(guest_addr));
+ }
+
+ return Ok(memfd_offset + map_offset);
+ })
+ .map_err(UdmabufError::InvalidOffset)
+}
+
+/// A convenience wrapper for the Linux kernel's udmabuf driver.
+///
+/// udmabuf is a kernel driver that turns memfd pages into dmabufs. It can be used for
+/// zero-copy buffer sharing between the guest and host when guest memory is backed by
+/// memfd pages.
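+///
+/// A minimal usage sketch (mirroring the tests below; assumes `mem` is a memfd-backed
+/// `GuestMemory` and that the host kernel exposes `/dev/udmabuf`):
+///
+/// ```ignore
+/// let driver = UdmabufDriver::new()?;
+/// let handle = driver.create_udmabuf(&mem, &[(GuestAddress(0x1000), 0x1000)])?;
+/// assert_eq!(handle.handle_type, RUTABAGA_MEM_HANDLE_TYPE_DMABUF);
+/// ```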
+pub struct UdmabufDriver {
+ driver_fd: File,
+}
+
+impl UdmabufDriver {
+ /// Opens the udmabuf device on success.
+ pub fn new() -> UdmabufResult<UdmabufDriver> {
+ const UDMABUF_PATH: &str = "/dev/udmabuf";
+ let path = Path::new(UDMABUF_PATH);
+ let fd = OpenOptions::new()
+ .read(true)
+ .write(true)
+ .open(path)
+ .map_err(UdmabufError::DriverOpenFailed)?;
+
+ Ok(UdmabufDriver { driver_fd: fd })
+ }
+
+ /// Creates a dma-buf fd for the given scatter-gather list of guest memory pages (`iovecs`).
+ pub fn create_udmabuf(
+ &self,
+ mem: &GuestMemory,
+ iovecs: &[(GuestAddress, usize)],
+ ) -> UdmabufResult<RutabagaHandle> {
+ let pgsize = pagesize();
+
+ let mut list = UdmabufCreateList::new(iovecs.len() as usize);
+ let mut items = list.mut_entries_slice();
+ for (i, &(addr, len)) in iovecs.iter().enumerate() {
+ let offset = memory_offset(mem, addr, len as u64)?;
+
+ if offset as usize % pgsize != 0 || len % pgsize != 0 {
+ return Err(UdmabufError::NotPageAligned);
+ }
+
+            // `unwrap` can't panic if `memory_offset` above succeeds.
+ items[i].memfd = mem.shm_region(addr).unwrap().as_raw_descriptor() as u32;
+ items[i].__pad = 0;
+ items[i].offset = offset;
+ items[i].size = len as u64;
+ }
+
+ // Safe because we always allocate enough space for `udmabuf_create_list`.
+ let fd = unsafe {
+ let create_list = list.as_mut_ptr();
+ (*create_list).flags = UDMABUF_FLAGS_CLOEXEC;
+ ioctl_with_ptr(&self.driver_fd, UDMABUF_CREATE_LIST(), create_list)
+ };
+
+ if fd < 0 {
+ return Err(UdmabufError::DmabufCreationFail(io::Error::last_os_error()));
+ }
+
+ // Safe because we validated the file exists.
+ let os_handle = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
+ Ok(RutabagaHandle {
+ os_handle,
+ handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use base::kernel_has_memfd;
+ use vm_memory::GuestAddress;
+
+ #[test]
+ fn test_memory_offsets() {
+ if !kernel_has_memfd() {
+ return;
+ }
+
+ let start_addr1 = GuestAddress(0x100);
+ let start_addr2 = GuestAddress(0x1100);
+ let start_addr3 = GuestAddress(0x2100);
+
+ let mem = GuestMemory::new(&vec![
+ (start_addr1, 0x1000),
+ (start_addr2, 0x1000),
+ (start_addr3, 0x1000),
+ ])
+ .unwrap();
+
+ assert_eq!(memory_offset(&mem, GuestAddress(0x300), 1).unwrap(), 0x200);
+ assert_eq!(
+ memory_offset(&mem, GuestAddress(0x1200), 1).unwrap(),
+ 0x1100
+ );
+ assert_eq!(
+ memory_offset(&mem, GuestAddress(0x1100), 0x1000).unwrap(),
+ 0x1000
+ );
+ assert!(memory_offset(&mem, GuestAddress(0x1100), 0x1001).is_err());
+ }
+
+ #[test]
+ fn test_udmabuf_create() {
+ if !kernel_has_memfd() {
+ return;
+ }
+
+ let driver_result = UdmabufDriver::new();
+
+ // Most kernels will not have udmabuf support.
+ if driver_result.is_err() {
+ return;
+ }
+
+ let driver = driver_result.unwrap();
+
+ let start_addr1 = GuestAddress(0x100);
+ let start_addr2 = GuestAddress(0x1100);
+ let start_addr3 = GuestAddress(0x2100);
+
+ let sg_list = vec![
+ (start_addr1, 0x1000),
+ (start_addr2, 0x1000),
+ (start_addr3, 0x1000),
+ ];
+
+ let mem = GuestMemory::new(&sg_list[..]).unwrap();
+
+ let mut udmabuf_create_list = vec![
+ (start_addr3, 0x1000 as usize),
+ (start_addr2, 0x1000 as usize),
+ (start_addr1, 0x1000 as usize),
+ (GuestAddress(0x4000), 0x1000 as usize),
+ ];
+
+ let result = driver.create_udmabuf(&mem, &udmabuf_create_list[..]);
+    assert!(result.is_err());
+
+ udmabuf_create_list.pop();
+
+ let rutabaga_handle1 = driver
+ .create_udmabuf(&mem, &udmabuf_create_list[..])
+ .unwrap();
+ assert_eq!(
+ rutabaga_handle1.handle_type,
+ RUTABAGA_MEM_HANDLE_TYPE_DMABUF
+ );
+
+ udmabuf_create_list.pop();
+
+ // Multiple udmabufs with the same memory backing are allowed.
+ let rutabaga_handle2 = driver
+ .create_udmabuf(&mem, &udmabuf_create_list[..])
+ .unwrap();
+ assert_eq!(
+ rutabaga_handle2.handle_type,
+ RUTABAGA_MEM_HANDLE_TYPE_DMABUF
+ );
+ }
+}
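
The test above only verifies that creation succeeds. As a rough sketch (not part of this change), the returned handle could be exercised further by mapping the dma-buf, assuming the kernel's udmabuf exporter allows mmap; the helper below is illustrative only.

fn map_one_page(handle: &RutabagaHandle) {
    let fd = handle.os_handle.as_raw_descriptor();
    // Map a single, in-bounds page of the dma-buf read-only; the mapping is
    // intentionally leaked in this sketch.
    let ptr = unsafe {
        libc::mmap(
            std::ptr::null_mut(),
            4096,
            libc::PROT_READ,
            libc::MAP_SHARED,
            fd,
            0,
        )
    };
    assert_ne!(ptr, libc::MAP_FAILED);
}
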
diff --git a/devices/src/virtio/gpu/udmabuf_bindings.rs b/devices/src/virtio/gpu/udmabuf_bindings.rs
new file mode 100644
index 000000000..dcd89a6f5
--- /dev/null
+++ b/devices/src/virtio/gpu/udmabuf_bindings.rs
@@ -0,0 +1,74 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/* automatically generated by rust-bindgen, though the exact command remains
+ * lost to history. Should be easy to duplicate, if needed.
+ */
+
+#![allow(dead_code)]
+#![allow(non_camel_case_types)]
+
+#[repr(C)]
+#[derive(Default)]
+pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
+impl<T> __IncompleteArrayField<T> {
+ #[inline]
+ pub fn new() -> Self {
+ __IncompleteArrayField(::std::marker::PhantomData, [])
+ }
+ #[inline]
+ pub unsafe fn as_ptr(&self) -> *const T {
+ ::std::mem::transmute(self)
+ }
+ #[inline]
+ pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
+ ::std::mem::transmute(self)
+ }
+ #[inline]
+ pub unsafe fn as_slice(&self, len: usize) -> &[T] {
+ ::std::slice::from_raw_parts(self.as_ptr(), len)
+ }
+ #[inline]
+ pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
+ ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
+ }
+}
+impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> {
+ fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+ fmt.write_str("__IncompleteArrayField")
+ }
+}
+impl<T> ::std::clone::Clone for __IncompleteArrayField<T> {
+ #[inline]
+ fn clone(&self) -> Self {
+ Self::new()
+ }
+}
+pub const UDMABUF_FLAGS_CLOEXEC: u32 = 1;
+pub type __u32 = ::std::os::raw::c_uint;
+pub type __u64 = ::std::os::raw::c_ulonglong;
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct udmabuf_create {
+ pub memfd: __u32,
+ pub flags: __u32,
+ pub offset: __u64,
+ pub size: __u64,
+}
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct udmabuf_create_item {
+ pub memfd: __u32,
+ pub __pad: __u32,
+ pub offset: __u64,
+ pub size: __u64,
+}
+#[repr(C)]
+#[repr(align(8))]
+#[derive(Debug, Default)]
+pub struct udmabuf_create_list {
+ pub flags: __u32,
+ pub count: __u32,
+ pub list: __IncompleteArrayField<udmabuf_create_item>,
+}
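
The driver code in udmabuf.rs relies on a UdmabufCreateList helper (defined elsewhere in crosvm) to provide storage for the flexible-array tail declared above. A hedged sketch of how such storage could be allocated and filled with these bindings, using u64-backed storage to satisfy the #[repr(align(8))] requirement; the function names are illustrative and the real helper may differ.

use std::mem::size_of;

fn alloc_create_list_storage(count: usize) -> Vec<u64> {
    let bytes = size_of::<udmabuf_create_list>() + count * size_of::<udmabuf_create_item>();
    // Zeroed, 8-byte-aligned backing for the header plus `count` trailing items.
    vec![0u64; (bytes + 7) / 8]
}

fn fill_create_list(storage: &mut [u64], count: usize) {
    let list = storage.as_mut_ptr() as *mut udmabuf_create_list;
    // Safe because `storage` is large enough and suitably aligned for the
    // header and `count` trailing items, and it outlives these accesses.
    unsafe {
        (*list).flags = UDMABUF_FLAGS_CLOEXEC;
        (*list).count = count as __u32;
        for item in (*list).list.as_mut_slice(count) {
            item.__pad = 0;
        }
    }
}
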
diff --git a/devices/src/virtio/gpu/virtio_gpu.rs b/devices/src/virtio/gpu/virtio_gpu.rs
index e1e29064d..00086903a 100644
--- a/devices/src/virtio/gpu/virtio_gpu.rs
+++ b/devices/src/virtio/gpu/virtio_gpu.rs
@@ -10,7 +10,7 @@ use std::result::Result;
use std::sync::Arc;
use crate::virtio::resource_bridge::{BufferInfo, PlaneInfo, ResourceInfo, ResourceResponse};
-use base::{error, AsRawDescriptor, ExternalMapping};
+use base::{error, AsRawDescriptor, ExternalMapping, Tube};
use data_model::VolatileSlice;
@@ -20,21 +20,22 @@ use rutabaga_gfx::{
RutabagaIovec, Transfer3D,
};
-use msg_socket::{MsgReceiver, MsgSender};
-
use libc::c_void;
use resources::Alloc;
-use super::protocol::{GpuResponse::*, GpuResponsePlaneInfo, VirtioGpuResult};
+use super::protocol::{
+ GpuResponse::{self, *},
+ GpuResponsePlaneInfo, VirtioGpuResult, VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE,
+ VIRTIO_GPU_BLOB_MEM_HOST3D,
+};
+use super::udmabuf::UdmabufDriver;
use super::VirtioScanoutBlobData;
use sync::Mutex;
use vm_memory::{GuestAddress, GuestMemory};
-use vm_control::{
- MaybeOwnedDescriptor, MemSlot, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse,
-};
+use vm_control::{MemSlot, VmMemoryRequest, VmMemoryResponse};
struct VirtioGpuResource {
resource_id: u32,
@@ -78,12 +79,13 @@ pub struct VirtioGpu {
cursor_surface_id: Option<u32>,
// Maps event devices to scanout number.
event_devices: Map<u32, u32>,
- gpu_device_socket: VmMemoryControlRequestSocket,
+ gpu_device_tube: Tube,
pci_bar: Alloc,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
rutabaga: Rutabaga,
resources: Map<u32, VirtioGpuResource>,
external_blob: bool,
+ udmabuf_driver: Option<UdmabufDriver>,
}
fn sglist_to_rutabaga_iovecs(
@@ -116,15 +118,26 @@ impl VirtioGpu {
display_height: u32,
rutabaga_builder: RutabagaBuilder,
event_devices: Vec<EventDevice>,
- gpu_device_socket: VmMemoryControlRequestSocket,
+ gpu_device_tube: Tube,
pci_bar: Alloc,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
external_blob: bool,
+ udmabuf: bool,
) -> Option<VirtioGpu> {
let rutabaga = rutabaga_builder
.build()
.map_err(|e| error!("failed to build rutabaga {}", e))
.ok()?;
+
+ let mut udmabuf_driver = None;
+ if udmabuf {
+ udmabuf_driver = Some(
+ UdmabufDriver::new()
+ .map_err(|e| error!("failed to initialize udmabuf: {}", e))
+ .ok()?,
+ );
+ }
+
let mut virtio_gpu = VirtioGpu {
display: Rc::new(RefCell::new(display)),
display_width,
@@ -134,12 +147,13 @@ impl VirtioGpu {
scanout_surface_id: None,
cursor_resource_id: None,
cursor_surface_id: None,
- gpu_device_socket,
+ gpu_device_tube,
pci_bar,
map_request,
rutabaga,
resources: Default::default(),
external_blob,
+ udmabuf_driver,
};
for event_device in event_devices {
@@ -424,7 +438,7 @@ impl VirtioGpu {
/// If supported, export the resource with the given `resource_id` to a file.
pub fn export_resource(&mut self, resource_id: u32) -> ResourceResponse {
let file = match self.rutabaga.export_blob(resource_id) {
- Ok(handle) => handle.os_handle,
+ Ok(handle) => handle.os_handle.into(),
Err(_) => return ResourceResponse::Invalid,
};
@@ -453,6 +467,7 @@ impl VirtioGpu {
stride: q.strides[3],
},
],
+ modifier: q.modifier,
}))
}
@@ -460,7 +475,7 @@ impl VirtioGpu {
pub fn export_fence(&self, fence_id: u32) -> ResourceResponse {
match self.rutabaga.export_fence(fence_id) {
Ok(handle) => ResourceResponse::Resource(ResourceInfo::Fence {
- file: handle.os_handle,
+ file: handle.os_handle.into(),
}),
Err(_) => ResourceResponse::Invalid,
}
@@ -517,7 +532,7 @@ impl VirtioGpu {
// Rely on rutabaga to check for duplicate resource ids.
self.resources.insert(resource_id, resource);
- self.result_from_query(resource_id)
+ Ok(self.result_from_query(resource_id))
}
/// Attaches backing memory to the given resource, represented by a `Vec` of `(address, size)`
@@ -588,19 +603,32 @@ impl VirtioGpu {
vecs: Vec<(GuestAddress, usize)>,
mem: &GuestMemory,
) -> VirtioGpuResult {
- let rutabaga_iovecs = sglist_to_rutabaga_iovecs(&vecs[..], mem).map_err(|_| ErrUnspec)?;
+ let mut rutabaga_handle = None;
+ let mut rutabaga_iovecs = None;
+
+ if resource_create_blob.blob_flags & VIRTIO_GPU_BLOB_FLAG_CREATE_GUEST_HANDLE != 0 {
+ rutabaga_handle = match self.udmabuf_driver {
+ Some(ref driver) => Some(driver.create_udmabuf(mem, &vecs[..])?),
+ None => return Err(ErrUnspec),
+ }
+ } else if resource_create_blob.blob_mem != VIRTIO_GPU_BLOB_MEM_HOST3D {
+ rutabaga_iovecs =
+ Some(sglist_to_rutabaga_iovecs(&vecs[..], mem).map_err(|_| ErrUnspec)?);
+ }
+
self.rutabaga.resource_create_blob(
ctx_id,
resource_id,
resource_create_blob,
rutabaga_iovecs,
+ rutabaga_handle,
)?;
let resource = VirtioGpuResource::new(resource_id, 0, 0, resource_create_blob.size);
// Rely on rutabaga to check for duplicate resource ids.
self.resources.insert(resource_id, resource);
- self.result_from_query(resource_id)
+ Ok(self.result_from_query(resource_id))
}
/// Uses the hypervisor to map the rutabaga blob resource.
@@ -611,15 +639,28 @@ impl VirtioGpu {
.ok_or(ErrInvalidResourceId)?;
let map_info = self.rutabaga.map_info(resource_id).map_err(|_| ErrUnspec)?;
+ let vulkan_info_opt = self.rutabaga.vulkan_info(resource_id).ok();
+
let export = self.rutabaga.export_blob(resource_id);
let request = match export {
- Ok(ref export) => VmMemoryRequest::RegisterFdAtPciBarOffset(
- self.pci_bar,
- MaybeOwnedDescriptor::Borrowed(export.os_handle.as_raw_descriptor()),
- resource.size as usize,
- offset,
- ),
+ Ok(export) => match vulkan_info_opt {
+ Some(vulkan_info) => VmMemoryRequest::RegisterVulkanMemoryAtPciBarOffset {
+ alloc: self.pci_bar,
+ descriptor: export.os_handle,
+ handle_type: export.handle_type,
+ memory_idx: vulkan_info.memory_idx,
+ physical_device_idx: vulkan_info.physical_device_idx,
+ offset,
+ size: resource.size,
+ },
+ None => VmMemoryRequest::RegisterFdAtPciBarOffset(
+ self.pci_bar,
+ export.os_handle,
+ resource.size as usize,
+ offset,
+ ),
+ },
Err(_) => {
if self.external_blob {
return Err(ErrUnspec);
@@ -638,8 +679,8 @@ impl VirtioGpu {
}
};
- self.gpu_device_socket.send(&request)?;
- let response = self.gpu_device_socket.recv()?;
+ self.gpu_device_tube.send(&request)?;
+ let response = self.gpu_device_tube.recv()?;
match response {
VmMemoryResponse::RegisterMemory { pfn: _, slot } => {
@@ -660,8 +701,8 @@ impl VirtioGpu {
let slot = resource.slot.ok_or(ErrUnspec)?;
let request = VmMemoryRequest::UnregisterMemory(slot);
- self.gpu_device_socket.send(&request)?;
- let response = self.gpu_device_socket.recv()?;
+ self.gpu_device_tube.send(&request)?;
+ let response = self.gpu_device_tube.recv()?;
match response {
VmMemoryResponse::Ok => {
@@ -704,7 +745,7 @@ impl VirtioGpu {
}
// Non-public function -- no doc comment needed!
- fn result_from_query(&mut self, resource_id: u32) -> VirtioGpuResult {
+ fn result_from_query(&mut self, resource_id: u32) -> GpuResponse {
match self.rutabaga.query(resource_id) {
Ok(query) => {
let mut plane_info = Vec::with_capacity(4);
@@ -715,12 +756,12 @@ impl VirtioGpu {
});
}
let format_modifier = query.modifier;
- Ok(OkResourcePlaneInfo {
+ OkResourcePlaneInfo {
format_modifier,
plane_info,
- })
+ }
}
- Err(_) => Ok(OkNoData),
+ Err(_) => OkNoData,
}
}
}
diff --git a/devices/src/virtio/input/event_source.rs b/devices/src/virtio/input/event_source.rs
index 650b19398..e17812298 100644
--- a/devices/src/virtio/input/event_source.rs
+++ b/devices/src/virtio/input/event_source.rs
@@ -242,7 +242,7 @@ mod tests {
use std::cmp::min;
use std::io::{Read, Write};
- use data_model::{DataInit, Le16, Le32};
+ use data_model::{DataInit, Le16, SLe32};
use linux_input_sys::InputEventDecoder;
use crate::virtio::input::event_source::{input_event, virtio_input_event, EventSourceImpl};
@@ -317,7 +317,11 @@ mod tests {
timestamp_fields: [0, 0],
type_: 3 * (idx as u16) + 1,
code: 3 * (idx as u16) + 2,
- value: 3 * (idx as u32) + 3,
+ value: if idx % 2 == 0 {
+ 3 * (idx as i32) + 3
+ } else {
+ -3 * (idx as i32) - 3
+ },
});
}
ret
@@ -326,7 +330,7 @@ mod tests {
fn assert_events_match(e1: &virtio_input_event, e2: &input_event) {
assert_eq!(e1.type_, Le16::from(e2.type_), "type should match");
assert_eq!(e1.code, Le16::from(e2.code), "code should match");
- assert_eq!(e1.value, Le32::from(e2.value), "value should match");
+ assert_eq!(e1.value, SLe32::from(e2.value), "value should match");
}
#[test]
diff --git a/devices/src/virtio/input/mod.rs b/devices/src/virtio/input/mod.rs
index 3848431cf..30cd724b2 100644
--- a/devices/src/virtio/input/mod.rs
+++ b/devices/src/virtio/input/mod.rs
@@ -16,8 +16,8 @@ use vm_memory::GuestMemory;
use self::event_source::{EvdevEventSource, EventSource, SocketEventSource};
use super::{
- copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer,
- TYPE_INPUT,
+ copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt,
+ VirtioDevice, Writer, TYPE_INPUT,
};
use linux_input_sys::{virtio_input_event, InputEventDecoder};
use std::collections::BTreeMap;
@@ -433,7 +433,7 @@ impl<T: EventSource> Worker<T> {
Ok(count) => count,
Err(e) => {
error!("Input: failed to read events from virtqueue: {}", e);
- break;
+ return Err(e);
}
};
@@ -463,7 +463,6 @@ impl<T: EventSource> Worker<T> {
(&event_queue_evt, Token::EventQAvailable),
(&status_queue_evt, Token::StatusQAvailable),
(&self.event_source, Token::InputEventsAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&kill_evt, Token::Kill),
]) {
Ok(wait_ctx) => wait_ctx,
@@ -472,6 +471,15 @@ impl<T: EventSource> Worker<T> {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+ error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
'wait: loop {
let wait_events = match wait_ctx.wait() {
diff --git a/devices/src/virtio/interrupt.rs b/devices/src/virtio/interrupt.rs
index 2ddff08f6..5b52dc505 100644
--- a/devices/src/virtio/interrupt.rs
+++ b/devices/src/virtio/interrupt.rs
@@ -9,31 +9,35 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use sync::Mutex;
+pub trait SignalableInterrupt {
+ /// Writes to the irqfd to have the VMM deliver a virtual interrupt to the guest.
+ fn signal(&self, vector: u16, interrupt_status_mask: u32);
+
+ /// Notify the driver that buffers have been placed in the used queue.
+ fn signal_used_queue(&self, vector: u16) {
+ self.signal(vector, INTERRUPT_STATUS_USED_RING)
+ }
+
+ /// Notify the driver that the device configuration has changed.
+ fn signal_config_changed(&self);
+
+ /// Get the event used to signal that resampling is needed, if one exists.
+ fn get_resample_evt(&self) -> Option<&Event>;
+
+ /// Reads the status and writes to the interrupt event. This doesn't read the resample event;
+ /// it assumes the resample has already been requested.
+ fn do_interrupt_resample(&self);
+}
+
pub struct Interrupt {
interrupt_status: Arc<AtomicUsize>,
interrupt_evt: Event,
interrupt_resample_evt: Event,
- pub msix_config: Option<Arc<Mutex<MsixConfig>>>,
+ msix_config: Option<Arc<Mutex<MsixConfig>>>,
config_msix_vector: u16,
}
-impl Interrupt {
- pub fn new(
- interrupt_status: Arc<AtomicUsize>,
- interrupt_evt: Event,
- interrupt_resample_evt: Event,
- msix_config: Option<Arc<Mutex<MsixConfig>>>,
- config_msix_vector: u16,
- ) -> Interrupt {
- Interrupt {
- interrupt_status,
- interrupt_evt,
- interrupt_resample_evt,
- msix_config,
- config_msix_vector,
- }
- }
-
+impl SignalableInterrupt for Interrupt {
/// Virtqueue Interrupts From The Device
///
/// If MSI-X is enabled in this device, MSI-X interrupt is preferred.
@@ -62,33 +66,46 @@ impl Interrupt {
}
}
- /// Notify the driver that buffers have been placed in the used queue.
- pub fn signal_used_queue(&self, vector: u16) {
- self.signal(vector, INTERRUPT_STATUS_USED_RING)
- }
-
- /// Notify the driver that the device configuration has changed.
- pub fn signal_config_changed(&self) {
+ fn signal_config_changed(&self) {
self.signal(self.config_msix_vector, INTERRUPT_STATUS_CONFIG_CHANGED)
}
- /// Handle interrupt resampling event, reading the value from the event and doing the resample.
- pub fn interrupt_resample(&self) {
- let _ = self.interrupt_resample_evt.read();
- self.do_interrupt_resample();
+ fn get_resample_evt(&self) -> Option<&Event> {
+ Some(&self.interrupt_resample_evt)
}
- /// Read the status and write to the interrupt event. Don't read the resample event, assume the
- /// resample has been requested.
- pub fn do_interrupt_resample(&self) {
+ fn do_interrupt_resample(&self) {
if self.interrupt_status.load(Ordering::SeqCst) != 0 {
self.interrupt_evt.write(1).unwrap();
}
}
+}
+
+impl Interrupt {
+ pub fn new(
+ interrupt_status: Arc<AtomicUsize>,
+ interrupt_evt: Event,
+ interrupt_resample_evt: Event,
+ msix_config: Option<Arc<Mutex<MsixConfig>>>,
+ config_msix_vector: u16,
+ ) -> Interrupt {
+ Interrupt {
+ interrupt_status,
+ interrupt_evt,
+ interrupt_resample_evt,
+ msix_config,
+ config_msix_vector,
+ }
+ }
+
+ /// Handle interrupt resampling event, reading the value from the event and doing the resample.
+ pub fn interrupt_resample(&self) {
+ let _ = self.interrupt_resample_evt.read();
+ self.do_interrupt_resample();
+ }
- /// Return the reference of interrupt_resample_evt
- /// To keep the interface clean, this member is private.
- pub fn get_resample_evt(&self) -> &Event {
- &self.interrupt_resample_evt
+ /// Get a reference to the MSI-X configuration.
+ pub fn get_msix_config(&self) -> &Option<Arc<Mutex<MsixConfig>>> {
+ &self.msix_config
}
}
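
To illustrate the new trait boundary, a hedged sketch of an alternative SignalableInterrupt implementer (for example a test double; this is not part of the patch, and it assumes the module's Event import and INTERRUPT_STATUS_* constants are in scope as they are in interrupt.rs). Returning None from get_resample_evt is what lets the device worker loops above register the resample token conditionally.

use std::sync::atomic::{AtomicU32, Ordering};

struct RecordingInterrupt {
    // Accumulates the interrupt status bits that would have been delivered.
    status: AtomicU32,
}

impl SignalableInterrupt for RecordingInterrupt {
    fn signal(&self, _vector: u16, interrupt_status_mask: u32) {
        self.status.fetch_or(interrupt_status_mask, Ordering::SeqCst);
    }

    fn signal_config_changed(&self) {
        self.signal(0, INTERRUPT_STATUS_CONFIG_CHANGED);
    }

    // No resample event for a pure software double.
    fn get_resample_evt(&self) -> Option<&Event> {
        None
    }

    fn do_interrupt_resample(&self) {}
}
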
diff --git a/devices/src/virtio/mod.rs b/devices/src/virtio/mod.rs
index a859e13bb..26966815b 100644
--- a/devices/src/virtio/mod.rs
+++ b/devices/src/virtio/mod.rs
@@ -92,7 +92,7 @@ const MAX_VIRTIO_DEVICE_ID: u32 = 63;
const TYPE_WL: u32 = MAX_VIRTIO_DEVICE_ID;
const TYPE_TPM: u32 = MAX_VIRTIO_DEVICE_ID - 1;
-const VIRTIO_F_VERSION_1: u32 = 32;
+pub const VIRTIO_F_VERSION_1: u32 = 32;
const VIRTIO_F_ACCESS_PLATFORM: u32 = 33;
const INTERRUPT_STATUS_USED_RING: u32 = 0x1;
diff --git a/devices/src/virtio/net.rs b/devices/src/virtio/net.rs
index 1c1c225cc..d1f35dca5 100644
--- a/devices/src/virtio/net.rs
+++ b/devices/src/virtio/net.rs
@@ -23,7 +23,8 @@ use virtio_sys::virtio_net::{
use vm_memory::GuestMemory;
use super::{
- copy_config, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_NET,
+ copy_config, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice,
+ Writer, TYPE_NET,
};
const QUEUE_SIZE: u16 = 256;
@@ -134,7 +135,7 @@ fn virtio_features_to_tap_offload(features: u64) -> c_uint {
#[derive(Debug, Clone, Copy, Default)]
#[repr(C)]
-struct VirtioNetConfig {
+pub(crate) struct VirtioNetConfig {
mac: [u8; 6],
status: Le16,
max_vq_pairs: Le16,
@@ -351,9 +352,11 @@ where
.add(ctrl_evt, Token::CtrlQueue)
.map_err(NetError::CreateWaitContext)?;
// Let CtrlQueue's thread handle InterruptResample also.
- wait_ctx
- .add(self.interrupt.get_resample_evt(), Token::InterruptResample)
- .map_err(NetError::CreateWaitContext)?;
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .map_err(NetError::CreateWaitContext)?;
+ }
}
let mut tap_polling_enabled = true;
diff --git a/devices/src/virtio/p9.rs b/devices/src/virtio/p9.rs
index a082afa06..670645d88 100644
--- a/devices/src/virtio/p9.rs
+++ b/devices/src/virtio/p9.rs
@@ -12,7 +12,8 @@ use base::{error, warn, Error as SysError, Event, PollToken, RawDescriptor, Wait
use vm_memory::GuestMemory;
use super::{
- copy_config, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_9P,
+ copy_config, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice,
+ Writer, TYPE_9P,
};
const QUEUE_SIZE: u16 = 128;
@@ -115,12 +116,14 @@ impl Worker {
Kill,
}
- let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
- (&queue_evt, Token::QueueReady),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
- (&kill_evt, Token::Kill),
- ])
- .map_err(P9Error::CreateWaitContext)?;
+ let wait_ctx: WaitContext<Token> =
+ WaitContext::build_with(&[(&queue_evt, Token::QueueReady), (&kill_evt, Token::Kill)])
+ .map_err(P9Error::CreateWaitContext)?;
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .map_err(P9Error::CreateWaitContext)?;
+ }
loop {
let events = wait_ctx.wait().map_err(P9Error::WaitError)?;
diff --git a/devices/src/virtio/pmem.rs b/devices/src/virtio/pmem.rs
index 6412268d7..9a3ca39cb 100644
--- a/devices/src/virtio/pmem.rs
+++ b/devices/src/virtio/pmem.rs
@@ -7,19 +7,15 @@ use std::fs::File;
use std::io;
use std::thread;
-use base::{error, AsRawDescriptor, Event, PollToken, RawDescriptor, WaitContext};
+use base::{error, AsRawDescriptor, Event, PollToken, RawDescriptor, Tube, WaitContext};
use base::{Error as SysError, Result as SysResult};
-use vm_memory::{GuestAddress, GuestMemory};
-
use data_model::{DataInit, Le32, Le64};
-
-use msg_socket::{MsgReceiver, MsgSender};
-
-use vm_control::{MemSlot, VmMsyncRequest, VmMsyncRequestSocket, VmMsyncResponse};
+use vm_control::{MemSlot, VmMsyncRequest, VmMsyncResponse};
+use vm_memory::{GuestAddress, GuestMemory};
use super::{
- copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer,
- TYPE_PMEM,
+ copy_config, DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt,
+ VirtioDevice, Writer, TYPE_PMEM,
};
const QUEUE_SIZE: u16 = 256;
@@ -87,7 +83,7 @@ struct Worker {
interrupt: Interrupt,
queue: Queue,
memory: GuestMemory,
- pmem_device_socket: VmMsyncRequestSocket,
+ pmem_device_tube: Tube,
mapping_arena_slot: MemSlot,
mapping_size: usize,
}
@@ -102,12 +98,12 @@ impl Worker {
size: self.mapping_size,
};
- if let Err(e) = self.pmem_device_socket.send(&request) {
+ if let Err(e) = self.pmem_device_tube.send(&request) {
error!("failed to send request: {}", e);
return VIRTIO_PMEM_RESP_TYPE_EIO;
}
- match self.pmem_device_socket.recv() {
+ match self.pmem_device_tube.recv() {
Ok(response) => match response {
VmMsyncResponse::Ok => VIRTIO_PMEM_RESP_TYPE_OK,
VmMsyncResponse::Err(e) => {
@@ -177,7 +173,6 @@ impl Worker {
let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[
(&queue_evt, Token::QueueAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&kill_evt, Token::Kill),
]) {
Ok(pc) => pc,
@@ -186,6 +181,15 @@ impl Worker {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+ error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
'wait: loop {
let events = match wait_ctx.wait() {
@@ -227,7 +231,7 @@ pub struct Pmem {
mapping_address: GuestAddress,
mapping_arena_slot: MemSlot,
mapping_size: u64,
- pmem_device_socket: Option<VmMsyncRequestSocket>,
+ pmem_device_tube: Option<Tube>,
}
impl Pmem {
@@ -237,7 +241,7 @@ impl Pmem {
mapping_address: GuestAddress,
mapping_arena_slot: MemSlot,
mapping_size: u64,
- pmem_device_socket: Option<VmMsyncRequestSocket>,
+ pmem_device_tube: Option<Tube>,
) -> SysResult<Pmem> {
if mapping_size > usize::max_value() as u64 {
return Err(SysError::new(libc::EOVERFLOW));
@@ -251,7 +255,7 @@ impl Pmem {
mapping_address,
mapping_arena_slot,
mapping_size,
- pmem_device_socket,
+ pmem_device_tube,
})
}
}
@@ -276,8 +280,8 @@ impl VirtioDevice for Pmem {
keep_rds.push(disk_image.as_raw_descriptor());
}
- if let Some(ref pmem_device_socket) = self.pmem_device_socket {
- keep_rds.push(pmem_device_socket.as_raw_descriptor());
+ if let Some(ref pmem_device_tube) = self.pmem_device_tube {
+ keep_rds.push(pmem_device_tube.as_raw_descriptor());
}
keep_rds
}
@@ -320,7 +324,7 @@ impl VirtioDevice for Pmem {
// We checked that this fits in a usize in `Pmem::new`.
let mapping_size = self.mapping_size as usize;
- if let Some(pmem_device_socket) = self.pmem_device_socket.take() {
+ if let Some(pmem_device_tube) = self.pmem_device_tube.take() {
let (self_kill_event, kill_event) =
match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
Ok(v) => v,
@@ -338,7 +342,7 @@ impl VirtioDevice for Pmem {
interrupt,
memory,
queue,
- pmem_device_socket,
+ pmem_device_tube,
mapping_arena_slot,
mapping_size,
};
diff --git a/devices/src/virtio/queue.rs b/devices/src/virtio/queue.rs
index 3f48ad64a..537abaeee 100644
--- a/devices/src/virtio/queue.rs
+++ b/devices/src/virtio/queue.rs
@@ -11,7 +11,7 @@ use cros_async::{AsyncError, EventAsync};
use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use vm_memory::{GuestAddress, GuestMemory};
-use super::{Interrupt, VIRTIO_MSI_NO_VECTOR};
+use super::{SignalableInterrupt, VIRTIO_MSI_NO_VECTOR};
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
@@ -536,7 +536,11 @@ impl Queue {
/// inject interrupt into guest on this queue
/// return true: interrupt is injected into guest for this queue
/// false: interrupt isn't injected
- pub fn trigger_interrupt(&mut self, mem: &GuestMemory, interrupt: &Interrupt) -> bool {
+ pub fn trigger_interrupt(
+ &mut self,
+ mem: &GuestMemory,
+ interrupt: &dyn SignalableInterrupt,
+ ) -> bool {
if self.available_interrupt_enabled(mem) {
self.last_used = self.next_used;
interrupt.signal_used_queue(self.vector);
@@ -554,6 +558,7 @@ impl Queue {
#[cfg(test)]
mod tests {
+ use super::super::Interrupt;
use super::*;
use base::Event;
use data_model::{DataInit, Le16, Le32, Le64};
diff --git a/devices/src/virtio/resource_bridge.rs b/devices/src/virtio/resource_bridge.rs
index ec2c2b590..a89843764 100644
--- a/devices/src/virtio/resource_bridge.rs
+++ b/devices/src/virtio/resource_bridge.rs
@@ -8,53 +8,51 @@
use std::fmt;
use std::fs::File;
-use base::RawDescriptor;
-use msg_on_socket_derive::MsgOnSocket;
-use msg_socket::{MsgError, MsgReceiver, MsgSender, MsgSocket};
+use serde::{Deserialize, Serialize};
-#[derive(MsgOnSocket, Debug)]
+use base::{with_as_descriptor, Tube, TubeError};
+
+#[derive(Debug, Serialize, Deserialize)]
pub enum ResourceRequest {
GetBuffer { id: u32 },
GetFence { seqno: u64 },
}
-#[derive(MsgOnSocket, Clone, Copy, Default)]
+#[derive(Serialize, Deserialize, Clone, Copy, Default)]
pub struct PlaneInfo {
pub offset: u32,
pub stride: u32,
}
-#[derive(MsgOnSocket)]
+#[derive(Serialize, Deserialize)]
pub struct BufferInfo {
+ #[serde(with = "with_as_descriptor")]
pub file: File,
pub planes: [PlaneInfo; RESOURE_PLANE_NUM],
+ pub modifier: u64,
}
pub const RESOURE_PLANE_NUM: usize = 4;
-#[derive(MsgOnSocket)]
+#[derive(Serialize, Deserialize)]
pub enum ResourceInfo {
Buffer(BufferInfo),
- Fence { file: File },
+ Fence {
+ #[serde(with = "with_as_descriptor")]
+ file: File,
+ },
}
-#[derive(MsgOnSocket)]
+#[derive(Serialize, Deserialize)]
pub enum ResourceResponse {
Resource(ResourceInfo),
Invalid,
}
-pub type ResourceRequestSocket = MsgSocket<ResourceRequest, ResourceResponse>;
-pub type ResourceResponseSocket = MsgSocket<ResourceResponse, ResourceRequest>;
-
-pub fn pair() -> std::io::Result<(ResourceRequestSocket, ResourceResponseSocket)> {
- msg_socket::pair()
-}
-
#[derive(Debug)]
pub enum ResourceBridgeError {
InvalidResource(ResourceRequest),
- SendFailure(ResourceRequest, MsgError),
- RecieveFailure(ResourceRequest, MsgError),
+ SendFailure(ResourceRequest, TubeError),
+ RecieveFailure(ResourceRequest, TubeError),
}
impl fmt::Display for ResourceRequest {
@@ -89,14 +87,14 @@ impl fmt::Display for ResourceBridgeError {
impl std::error::Error for ResourceBridgeError {}
pub fn get_resource_info(
- sock: &ResourceRequestSocket,
+ tube: &Tube,
request: ResourceRequest,
) -> std::result::Result<ResourceInfo, ResourceBridgeError> {
- if let Err(e) = sock.send(&request) {
+ if let Err(e) = tube.send(&request) {
return Err(ResourceBridgeError::SendFailure(request, e));
}
- match sock.recv() {
+ match tube.recv() {
Ok(ResourceResponse::Resource(info)) => Ok(info),
Ok(ResourceResponse::Invalid) => Err(ResourceBridgeError::InvalidResource(request)),
Err(e) => Err(ResourceBridgeError::RecieveFailure(request, e)),
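
With the MsgSocket pair replaced by a Tube, the requesting side of the bridge is a plain send/recv round trip through get_resource_info, and descriptor-carrying fields only need the with_as_descriptor attribute shown above. A hedged sketch of a caller that expects a fence descriptor back (illustrative only, not part of the patch).

fn request_fence_file(bridge: &Tube, seqno: u64) -> Result<File, ResourceBridgeError> {
    match get_resource_info(bridge, ResourceRequest::GetFence { seqno })? {
        ResourceInfo::Fence { file } => Ok(file),
        // A buffer reply to a fence request is reported as an invalid resource.
        ResourceInfo::Buffer(_) => Err(ResourceBridgeError::InvalidResource(
            ResourceRequest::GetFence { seqno },
        )),
    }
}
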
diff --git a/devices/src/virtio/rng.rs b/devices/src/virtio/rng.rs
index 466ffd247..65d671e6a 100644
--- a/devices/src/virtio/rng.rs
+++ b/devices/src/virtio/rng.rs
@@ -10,7 +10,7 @@ use std::thread;
use base::{error, warn, AsRawDescriptor, Event, PollToken, RawDescriptor, WaitContext};
use vm_memory::GuestMemory;
-use super::{Interrupt, Queue, VirtioDevice, Writer, TYPE_RNG};
+use super::{Interrupt, Queue, SignalableInterrupt, VirtioDevice, Writer, TYPE_RNG};
const QUEUE_SIZE: u16 = 256;
const QUEUE_SIZES: &[u16] = &[QUEUE_SIZE];
@@ -75,7 +75,6 @@ impl Worker {
let wait_ctx: WaitContext<Token> = match WaitContext::build_with(&[
(&queue_evt, Token::QueueAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&kill_evt, Token::Kill),
]) {
Ok(pc) => pc,
@@ -84,6 +83,15 @@ impl Worker {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+ error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
'wait: loop {
let events = match wait_ctx.wait() {
diff --git a/devices/src/virtio/snd/constants.rs b/devices/src/virtio/snd/constants.rs
index 38b06c496..28a73a083 100644
--- a/devices/src/virtio/snd/constants.rs
+++ b/devices/src/virtio/snd/constants.rs
@@ -13,14 +13,17 @@ pub const STREAM_STOP: u32 = 0x0100 + 5;
pub const CHANNEL_MAP_INFO: u32 = 0x0200;
-pub const VIRTIO_SND_S_OK: u32 = 0x8000;
-pub const VIRTIO_SND_S_BAD_MSG: u32 = 0x8001;
-pub const VIRTIO_SND_S_NOT_SUPP: u32 = 0x8002;
-pub const VIRTIO_SND_S_IO_ERR: u32 = 0x8003;
-
pub const VIRTIO_SND_D_OUTPUT: u8 = 0;
pub const VIRTIO_SND_D_INPUT: u8 = 1;
+/* supported PCM stream features */
+pub const VIRTIO_SND_PCM_F_SHMEM_HOST: u8 = 0;
+pub const VIRTIO_SND_PCM_F_SHMEM_GUEST: u8 = 1;
+pub const VIRTIO_SND_PCM_F_MSG_POLLING: u8 = 2;
+pub const VIRTIO_SND_PCM_F_EVT_SHMEM_PERIODS: u8 = 3;
+pub const VIRTIO_SND_PCM_F_EVT_XRUNS: u8 = 4;
+
+/* supported PCM sample formats */
pub const VIRTIO_SND_PCM_FMT_IMA_ADPCM: u8 = 0;
pub const VIRTIO_SND_PCM_FMT_MU_LAW: u8 = 1;
pub const VIRTIO_SND_PCM_FMT_A_LAW: u8 = 2;
@@ -61,3 +64,76 @@ pub const VIRTIO_SND_PCM_RATE_96000: u8 = 10;
pub const VIRTIO_SND_PCM_RATE_176400: u8 = 11;
pub const VIRTIO_SND_PCM_RATE_192000: u8 = 12;
pub const VIRTIO_SND_PCM_RATE_384000: u8 = 13;
+
+// From https://github.com/oasis-tcs/virtio-spec/blob/master/virtio-sound.tex
+/* jack control request types */
+pub const VIRTIO_SND_R_JACK_INFO: u32 = 1;
+pub const VIRTIO_SND_R_JACK_REMAP: u32 = 2;
+
+/* PCM control request types */
+pub const VIRTIO_SND_R_PCM_INFO: u32 = 0x0100;
+pub const VIRTIO_SND_R_PCM_SET_PARAMS: u32 = 0x0101;
+pub const VIRTIO_SND_R_PCM_PREPARE: u32 = 0x0102;
+pub const VIRTIO_SND_R_PCM_RELEASE: u32 = 0x0103;
+pub const VIRTIO_SND_R_PCM_START: u32 = 0x0104;
+pub const VIRTIO_SND_R_PCM_STOP: u32 = 0x0105;
+
+/* channel map control request types */
+pub const VIRTIO_SND_R_CHMAP_INFO: u32 = 0x0200;
+
+/* jack event types */
+pub const VIRTIO_SND_EVT_JACK_CONNECTED: u32 = 0x1000;
+pub const VIRTIO_SND_EVT_JACK_DISCONNECTED: u32 = 0x1001;
+
+/* PCM event types */
+pub const VIRTIO_SND_EVT_PCM_PERIOD_ELAPSED: u32 = 0x1100;
+pub const VIRTIO_SND_EVT_PCM_XRUN: u32 = 0x1101;
+
+/* common status codes */
+pub const VIRTIO_SND_S_OK: u32 = 0x8000;
+pub const VIRTIO_SND_S_BAD_MSG: u32 = 0x8001;
+pub const VIRTIO_SND_S_NOT_SUPP: u32 = 0x8002;
+pub const VIRTIO_SND_S_IO_ERR: u32 = 0x8003;
+
+pub const VIRTIO_SND_JACK_F_REMAP: u32 = 0;
+
+/* standard channel position definition */
+pub const VIRTIO_SND_CHMAP_NONE: u32 = 0; /* undefined */
+pub const VIRTIO_SND_CHMAP_NA: u32 = 1; /* silent */
+pub const VIRTIO_SND_CHMAP_MONO: u32 = 2; /* mono stream */
+pub const VIRTIO_SND_CHMAP_FL: u32 = 3; /* front left */
+pub const VIRTIO_SND_CHMAP_FR: u32 = 4; /* front right */
+pub const VIRTIO_SND_CHMAP_RL: u32 = 5; /* rear left */
+pub const VIRTIO_SND_CHMAP_RR: u32 = 6; /* rear right */
+pub const VIRTIO_SND_CHMAP_FC: u32 = 7; /* front center */
+pub const VIRTIO_SND_CHMAP_LFE: u32 = 8; /* low frequency (LFE) */
+pub const VIRTIO_SND_CHMAP_SL: u32 = 9; /* side left */
+pub const VIRTIO_SND_CHMAP_SR: u32 = 10; /* side right */
+pub const VIRTIO_SND_CHMAP_RC: u32 = 11; /* rear center */
+pub const VIRTIO_SND_CHMAP_FLC: u32 = 12; /* front left center */
+pub const VIRTIO_SND_CHMAP_FRC: u32 = 13; /* front right center */
+pub const VIRTIO_SND_CHMAP_RLC: u32 = 14; /* rear left center */
+pub const VIRTIO_SND_CHMAP_RRC: u32 = 15; /* rear right center */
+pub const VIRTIO_SND_CHMAP_FLW: u32 = 16; /* front left wide */
+pub const VIRTIO_SND_CHMAP_FRW: u32 = 17; /* front right wide */
+pub const VIRTIO_SND_CHMAP_FLH: u32 = 18; /* front left high */
+pub const VIRTIO_SND_CHMAP_FCH: u32 = 19; /* front center high */
+pub const VIRTIO_SND_CHMAP_FRH: u32 = 20; /* front right high */
+pub const VIRTIO_SND_CHMAP_TC: u32 = 21; /* top center */
+pub const VIRTIO_SND_CHMAP_TFL: u32 = 22; /* top front left */
+pub const VIRTIO_SND_CHMAP_TFR: u32 = 23; /* top front right */
+pub const VIRTIO_SND_CHMAP_TFC: u32 = 24; /* top front center */
+pub const VIRTIO_SND_CHMAP_TRL: u32 = 25; /* top rear left */
+pub const VIRTIO_SND_CHMAP_TRR: u32 = 26; /* top rear right */
+pub const VIRTIO_SND_CHMAP_TRC: u32 = 27; /* top rear center */
+pub const VIRTIO_SND_CHMAP_TFLC: u32 = 28; /* top front left center */
+pub const VIRTIO_SND_CHMAP_TFRC: u32 = 29; /* top front right center */
+pub const VIRTIO_SND_CHMAP_TSL: u32 = 34; /* top side left */
+pub const VIRTIO_SND_CHMAP_TSR: u32 = 35; /* top side right */
+pub const VIRTIO_SND_CHMAP_LLFE: u32 = 36; /* left LFE */
+pub const VIRTIO_SND_CHMAP_RLFE: u32 = 37; /* right LFE */
+pub const VIRTIO_SND_CHMAP_BC: u32 = 38; /* bottom center */
+pub const VIRTIO_SND_CHMAP_BLC: u32 = 39; /* bottom left center */
+pub const VIRTIO_SND_CHMAP_BRC: u32 = 40; /* bottom right center */
+
+pub const VIRTIO_SND_CHMAP_MAX_SIZE: usize = 18;
diff --git a/devices/src/virtio/snd/layout.rs b/devices/src/virtio/snd/layout.rs
index c78a9d7a4..8881af0b4 100644
--- a/devices/src/virtio/snd/layout.rs
+++ b/devices/src/virtio/snd/layout.rs
@@ -2,9 +2,20 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+use crate::virtio::snd::constants::VIRTIO_SND_CHMAP_MAX_SIZE;
use data_model::{DataInit, Le32, Le64};
#[derive(Copy, Clone, Default)]
+#[repr(C, packed)]
+pub struct virtio_snd_config {
+ pub jacks: Le32,
+ pub streams: Le32,
+ pub chmaps: Le32,
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_config {}
+
+#[derive(Copy, Clone, Default)]
#[repr(C)]
pub struct virtio_snd_hdr {
pub code: Le32,
@@ -12,6 +23,24 @@ pub struct virtio_snd_hdr {
// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for virtio_snd_hdr {}
+#[derive(Copy, Clone, Default)]
+#[repr(C)]
+pub struct virtio_snd_jack_hdr {
+ pub hdr: virtio_snd_hdr,
+ pub jack_id: Le32,
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_jack_hdr {}
+
+#[derive(Copy, Clone, Default)]
+#[repr(C)]
+pub struct virtio_snd_event {
+ pub hdr: virtio_snd_hdr,
+ pub data: Le32,
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_event {}
+
#[derive(Copy, Clone)]
#[repr(C)]
pub struct virtio_snd_query_info {
@@ -35,9 +64,9 @@ unsafe impl DataInit for virtio_snd_info {}
#[repr(C)]
pub struct virtio_snd_pcm_info {
pub hdr: virtio_snd_info,
- pub features: Le32,
- pub formats: Le64,
- pub rates: Le64,
+ pub features: Le32, /* 1 << VIRTIO_SND_PCM_F_XXX */
+ pub formats: Le64, /* 1 << VIRTIO_SND_PCM_FMT_XXX */
+ pub rates: Le64, /* 1 << VIRTIO_SND_PCM_RATE_XXX */
pub direction: u8,
pub channels_min: u8,
pub channels_max: u8,
@@ -62,7 +91,7 @@ pub struct virtio_snd_pcm_set_params {
pub hdr: virtio_snd_pcm_hdr,
pub buffer_bytes: Le32,
pub period_bytes: Le32,
- pub features: Le32,
+ pub features: Le32, /* 1 << VIRTIO_SND_PCM_F_XXX */
pub channels: u8,
pub format: u8,
pub rate: u8,
@@ -79,7 +108,7 @@ pub struct virtio_snd_pcm_xfer {
// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for virtio_snd_pcm_xfer {}
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, Default)]
#[repr(C)]
pub struct virtio_snd_pcm_status {
pub status: Le32,
@@ -87,3 +116,37 @@ pub struct virtio_snd_pcm_status {
}
// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for virtio_snd_pcm_status {}
+
+#[derive(Copy, Clone)]
+#[repr(C)]
+pub struct virtio_snd_jack_info {
+ pub hdr: virtio_snd_info,
+ pub features: Le32, /* 1 << VIRTIO_SND_JACK_F_XXX */
+ pub hda_reg_defconf: Le32,
+ pub hda_reg_caps: Le32,
+ pub connected: u8,
+ pub padding: [u8; 7],
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_jack_info {}
+
+#[derive(Copy, Clone)]
+#[repr(C)]
+pub struct virtio_snd_jack_remap {
+ pub hdr: virtio_snd_jack_hdr, /* .code = VIRTIO_SND_R_JACK_REMAP */
+ pub association: Le32,
+ pub sequence: Le32,
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_jack_remap {}
+
+#[derive(Copy, Clone)]
+#[repr(C)]
+pub struct virtio_snd_chmap_info {
+ pub hdr: virtio_snd_info,
+ pub direction: u8,
+ pub channels: u8,
+ pub positions: [u8; VIRTIO_SND_CHMAP_MAX_SIZE],
+}
+// Safe because it only has data and has no implicit padding.
+unsafe impl DataInit for virtio_snd_chmap_info {}
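
The new /* 1 << VIRTIO_SND_PCM_F_XXX */ comments mean these Le32/Le64 fields are bitmasks indexed by the constants in constants.rs. A hedged example of building the masks a device might advertise for a signed-16-bit output stream (VIRTIO_SND_PCM_FMT_S16 and the 44.1/48 kHz rate constants are assumed to be defined earlier in constants.rs, alongside the ones visible in this diff).

use crate::virtio::snd::constants::*;
use data_model::{Le32, Le64};

fn example_pcm_masks() -> (Le32, Le64, Le64) {
    // One bit per supported feature, sample format, and frame rate.
    let features: Le32 = (1u32 << VIRTIO_SND_PCM_F_EVT_XRUNS).into();
    let formats: Le64 = (1u64 << VIRTIO_SND_PCM_FMT_S16).into();
    let rates: Le64 =
        ((1u64 << VIRTIO_SND_PCM_RATE_44100) | (1u64 << VIRTIO_SND_PCM_RATE_48000)).into();
    (features, formats, rates)
}
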
diff --git a/devices/src/virtio/snd/vios_backend/mod.rs b/devices/src/virtio/snd/vios_backend/mod.rs
index 416fe7725..602ec4dcd 100644
--- a/devices/src/virtio/snd/vios_backend/mod.rs
+++ b/devices/src/virtio/snd/vios_backend/mod.rs
@@ -5,7 +5,7 @@
mod shm_streams;
mod shm_vios;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "linux", target_os = "android"))]
pub use self::shm_streams::*;
pub use self::shm_vios::*;
diff --git a/devices/src/virtio/snd/vios_backend/shm_streams.rs b/devices/src/virtio/snd/vios_backend/shm_streams.rs
index c83ae86ef..5b5e0f641 100644
--- a/devices/src/virtio/snd/vios_backend/shm_streams.rs
+++ b/devices/src/virtio/snd/vios_backend/shm_streams.rs
@@ -22,7 +22,6 @@ use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
-use sync::Mutex;
use sys_util::{Error as SysError, SharedMemory as SysSharedMemory};
use super::shm_vios::{Error, Result};
@@ -33,17 +32,14 @@ type GenericResult<T> = std::result::Result<T, BoxError>;
/// Adapter that provides the ShmStreamSource trait around the VioS backend.
pub struct VioSShmStreamSource {
- // Reference counting is needed because the streams also need a reference to the client to push
- // buffers and release the stream on drop and it has to implement Send because ShmStreamSource
- // requires it, so the only possibility is Arc.
- vios_client: Arc<Mutex<VioSClient>>,
+ vios_client: Arc<VioSClient>,
}
impl VioSShmStreamSource {
/// Creates a new stream source given the path to the audio server's socket.
pub fn new<P: AsRef<Path>>(server: P) -> Result<VioSShmStreamSource> {
Ok(Self {
- vios_client: Arc::new(Mutex::new(VioSClient::try_new(server)?)),
+ vios_client: Arc::new(VioSClient::try_new(server)?),
})
}
}
@@ -93,39 +89,41 @@ impl ShmStreamSource for VioSShmStreamSource {
client_shm: &SysSharedMemory,
_buffer_offsets: [u64; 2],
) -> GenericResult<Box<dyn ShmStream>> {
- let mut vios_client = self.vios_client.lock();
- match direction {
- StreamDirection::Playback => {
- match vios_client.get_unused_stream_id(VIRTIO_SND_D_OUTPUT) {
- Some(stream_id) => {
- vios_client.prepare_stream(stream_id)?;
- let frame_size = num_channels * format.sample_bytes();
- let period_bytes = (frame_size * buffer_size) as u32;
- let params = VioSStreamParams {
- buffer_bytes: 2 * period_bytes,
- period_bytes,
- features: 0u32,
- channels: num_channels as u8,
- format: from_sample_format(format),
- rate: virtio_frame_rate(frame_rate)?,
- };
- vios_client.set_stream_parameters(stream_id, params)?;
- vios_client.start_stream(stream_id)?;
- VioSndShmStream::new(
- buffer_size,
- num_channels,
- format,
- frame_rate,
- stream_id,
- self.vios_client.clone(),
- client_shm,
- )
- }
- None => Err(Box::new(Error::NoStreamsAvailable)),
- }
- }
- StreamDirection::Capture => panic!("Capture not yet supported"),
- }
+ self.vios_client.ensure_bg_thread_started()?;
+ let virtio_dir = match direction {
+ StreamDirection::Playback => VIRTIO_SND_D_OUTPUT,
+ StreamDirection::Capture => VIRTIO_SND_D_INPUT,
+ };
+ let frame_size = num_channels * format.sample_bytes();
+ let period_bytes = (frame_size * buffer_size) as u32;
+ let stream_id = self
+ .vios_client
+ .get_unused_stream_id(virtio_dir)
+ .ok_or(Box::new(Error::NoStreamsAvailable))?;
+ // Create the stream object before any errors can be returned to guarantee the stream will
+ // be released in all cases.
+ let stream_box = VioSndShmStream::new(
+ buffer_size,
+ num_channels,
+ format,
+ frame_rate,
+ stream_id,
+ direction,
+ self.vios_client.clone(),
+ client_shm,
+ );
+ self.vios_client.prepare_stream(stream_id)?;
+ let params = VioSStreamParams {
+ buffer_bytes: 2 * period_bytes,
+ period_bytes,
+ features: 0u32,
+ channels: num_channels as u8,
+ format: from_sample_format(format),
+ rate: virtio_frame_rate(frame_rate)?,
+ };
+ self.vios_client.set_stream_parameters(stream_id, params)?;
+ self.vios_client.start_stream(stream_id)?;
+ stream_box
}
/// Get a list of file descriptors used by the implementation.
@@ -134,7 +132,7 @@ impl ShmStreamSource for VioSShmStreamSource {
/// This list helps users of the ShmStreamSource enter Linux jails without
/// closing needed file descriptors.
fn keep_fds(&self) -> Vec<RawFd> {
- self.vios_client.lock().keep_fds()
+ self.vios_client.keep_fds()
}
}
@@ -148,7 +146,8 @@ pub struct VioSndShmStream {
next_frame: Duration,
start_time: Instant,
stream_id: u32,
- vios_client: Arc<Mutex<VioSClient>>,
+ direction: StreamDirection,
+ vios_client: Arc<VioSClient>,
client_shm: SharedMemory,
}
@@ -160,21 +159,22 @@ impl VioSndShmStream {
format: SampleFormat,
frame_rate: u32,
stream_id: u32,
- vios_client: Arc<Mutex<VioSClient>>,
+ direction: StreamDirection,
+ vios_client: Arc<VioSClient>,
client_shm: &SysSharedMemory,
) -> GenericResult<Box<dyn ShmStream>> {
let interval = Duration::from_millis(buffer_size as u64 * 1000 / frame_rate as u64);
let dup_fd = unsafe {
- // Safe because dup doesn't affect memory and client_shm should wrap a known valid file
- // descriptor
- libc::dup(client_shm.as_raw_fd())
+ // Safe because fcntl doesn't affect memory and client_shm should wrap a known valid
+ // file descriptor.
+ libc::fcntl(client_shm.as_raw_fd(), libc::F_DUPFD_CLOEXEC, 0)
};
if dup_fd < 0 {
return Err(Box::new(Error::DupError(SysError::last())));
}
let file = unsafe {
- // safe because we checked the result of libc::dup()
+ // safe because we checked the result of libc::fcntl()
File::from_raw_fd(dup_fd)
};
let client_shm_clone =
@@ -189,6 +189,7 @@ impl VioSndShmStream {
next_frame: interval,
start_time: Instant::now(),
stream_id,
+ direction,
vios_client,
client_shm: client_shm_clone,
}))
@@ -211,10 +212,10 @@ impl ShmStream for VioSndShmStream {
/// Waits until the next time a frame should be sent to the server. The server may release the
/// previous buffer much sooner than it needs the next one, so this function may sleep to wait
/// for the right time.
- fn wait_for_next_action_with_timeout<'b>(
- &'b mut self,
+ fn wait_for_next_action_with_timeout(
+ &mut self,
timeout: Duration,
- ) -> GenericResult<Option<ServerRequest<'b>>> {
+ ) -> GenericResult<Option<ServerRequest>> {
let elapsed = self.start_time.elapsed();
if elapsed < self.next_frame {
if timeout < self.next_frame - elapsed {
@@ -231,12 +232,24 @@ impl ShmStream for VioSndShmStream {
impl BufferSet for VioSndShmStream {
fn callback(&mut self, offset: usize, frames: usize) -> GenericResult<()> {
- self.vios_client.lock().inject_audio_data(
- self.stream_id,
- &mut self.client_shm,
- offset,
- frames * self.frame_size,
- )?;
+ match self.direction {
+ StreamDirection::Playback => {
+ self.vios_client.inject_audio_data(
+ self.stream_id,
+ &mut self.client_shm,
+ offset,
+ frames * self.frame_size,
+ )?;
+ }
+ StreamDirection::Capture => {
+ self.vios_client.request_audio_data(
+ self.stream_id,
+ &mut self.client_shm,
+ offset,
+ frames * self.frame_size,
+ )?;
+ }
+ }
Ok(())
}
@@ -247,11 +260,11 @@ impl BufferSet for VioSndShmStream {
impl Drop for VioSndShmStream {
fn drop(&mut self) {
- let mut client = self.vios_client.lock();
let stream_id = self.stream_id;
- if let Err(e) = client
+ if let Err(e) = self
+ .vios_client
.stop_stream(stream_id)
- .and_then(|_| client.release_stream(stream_id))
+ .and_then(|_| self.vios_client.release_stream(stream_id))
{
error!("Failed to stop and release stream {}: {}", stream_id, e);
}
diff --git a/devices/src/virtio/snd/vios_backend/shm_vios.rs b/devices/src/virtio/snd/vios_backend/shm_vios.rs
index d6372e68c..02ffa1ce5 100644
--- a/devices/src/virtio/snd/vios_backend/shm_vios.rs
+++ b/devices/src/virtio/snd/vios_backend/shm_vios.rs
@@ -6,15 +6,22 @@ use crate::virtio::snd::constants::*;
use crate::virtio::snd::layout::*;
use base::{
- net::UnixSeqpacket, Error as BaseError, FromRawDescriptor, IntoRawDescriptor, MemoryMapping,
- MemoryMappingBuilder, MmapError, SafeDescriptor, ScmSocket, SharedMemory,
+ error, net::UnixSeqpacket, Error as BaseError, Event, FromRawDescriptor, IntoRawDescriptor,
+ MemoryMapping, MemoryMappingBuilder, MmapError, PollToken, SafeDescriptor, ScmSocket,
+ SharedMemory, WaitContext,
};
use data_model::{DataInit, VolatileMemory, VolatileMemoryError};
+use std::collections::HashMap;
use std::fs::File;
use std::io::{Error as IOError, ErrorKind as IOErrorKind, Seek, SeekFrom};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::path::Path;
+use std::sync::mpsc::{channel, Receiver, RecvError, Sender};
+use std::sync::Arc;
+use std::thread::JoinHandle;
+
+use sync::Mutex;
use thiserror::Error as ThisError;
@@ -58,6 +65,20 @@ pub enum Error {
PlatformNotSupported,
#[error("Command failed with status {0}")]
CommandFailed(u32),
+ #[error("IO buffer operation failed: status = {0}")]
+ IOBufferError(u32),
+ #[error("Failed to duplicate UnixSeqpacket: {0}")]
+ UnixSeqpacketDupError(IOError),
+ #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
+ BufferStatusSenderLost(RecvError),
+ #[error("Failed to create Recv event: {0}")]
+ EventCreateError(BaseError),
+ #[error("Failed to dup Recv event: {0}")]
+ EventDupError(BaseError),
+ #[error("Failed to create Recv thread's WaitContext: {0}")]
+ WaitContextCreateError(BaseError),
+ #[error("Error waiting for events")]
+ WaitError(BaseError),
}
#[derive(ThisError, Debug)]
@@ -75,14 +96,21 @@ pub enum ProtocolErrorKind {
/// The client for the VioS backend
///
/// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
-/// notifications.
+/// notifications. It is thread safe and can be encapsulated in an Arc smart pointer and shared
+/// between threads.
pub struct VioSClient {
config: VioSConfig,
- streams: Vec<VioSStreamInfo>,
- control_socket: UnixSeqpacket,
- event_socket: UnixSeqpacket,
- tx: IoBufferQueue,
- rx: IoBufferQueue,
+ // These mutexes should almost never be held simultaneously. If at some point they have to be,
+ // the locking order should match the order in which they are declared here.
+ streams: Mutex<Vec<VioSStreamInfo>>,
+ control_socket: Mutex<UnixSeqpacket>,
+ event_socket: Mutex<UnixSeqpacket>,
+ tx: Mutex<IoBufferQueue>,
+ rx: Mutex<IoBufferQueue>,
+ rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>>,
+ recv_running: Arc<Mutex<bool>>,
+ recv_event: Mutex<Event>,
+ recv_thread: Mutex<Option<JoinHandle<Result<()>>>>,
}
impl VioSClient {
@@ -93,14 +121,14 @@ impl VioSClient {
let mut config: VioSConfig = Default::default();
let mut fds: Vec<RawFd> = Vec::new();
const NUM_FDS: usize = 5;
- fds.resize(NUM_FDS, 0 as RawFd);
+ fds.resize(NUM_FDS, 0);
let (recv_size, fd_count) = client_socket
.recv_with_fds(config.as_mut_slice(), &mut fds)
.map_err(|e| Error::ServerError(e))?;
// Resize the vector to the actual number of file descriptors received and wrap them in
// SafeDescriptors to prevent leaks
- fds.resize(fd_count, -1 as RawFd);
+ fds.resize(fd_count, -1);
let mut safe_fds: Vec<SafeDescriptor> = fds
.into_iter()
.map(|fd| unsafe {
@@ -153,30 +181,61 @@ impl VioSClient {
));
}
- let control_socket = client_socket;
+ let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>> =
+ Arc::new(Mutex::new(HashMap::new()));
+ let recv_running = Arc::new(Mutex::new(true));
+ let recv_event = Event::new().map_err(|e| Error::EventCreateError(e))?;
let mut client = VioSClient {
config,
- streams: Vec::new(),
- control_socket,
- event_socket,
- tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
- rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
+ streams: Mutex::new(Vec::new()),
+ control_socket: Mutex::new(client_socket),
+ event_socket: Mutex::new(event_socket),
+ tx: Mutex::new(IoBufferQueue::new(tx_socket, tx_shm_file)?),
+ rx: Mutex::new(IoBufferQueue::new(rx_socket, rx_shm_file)?),
+ rx_subscribers,
+ recv_running,
+ recv_event: Mutex::new(recv_event),
+ recv_thread: Mutex::new(None),
};
client.request_and_cache_streams_info()?;
-
Ok(client)
}
- /// Get a description of the available sound streams.
- pub fn streams(&self) -> &Vec<VioSStreamInfo> {
- &self.streams
+ pub fn ensure_bg_thread_started(&self) -> Result<()> {
+ if self.recv_thread.lock().is_some() {
+ return Ok(());
+ }
+ let event_socket = self
+ .recv_event
+ .lock()
+ .try_clone()
+ .map_err(|e| Error::EventDupError(e))?;
+ let rx_socket = self
+ .rx
+ .lock()
+ .socket
+ .try_clone()
+ .map_err(|e| Error::UnixSeqpacketDupError(e))?;
+ let mut opt = self.recv_thread.lock();
+ // The lock on recv_thread was released above to avoid holding more than one lock at a time
+ // while duplicating the fds. So we have to check the condition again.
+ if opt.is_none() {
+ *opt = Some(spawn_recv_thread(
+ self.rx_subscribers.clone(),
+ event_socket,
+ self.recv_running.clone(),
+ rx_socket,
+ ));
+ }
+ Ok(())
}
/// Gets an unused stream id of the specified direction. `direction` must be one of
/// VIRTIO_SND_D_INPUT OR VIRTIO_SND_D_OUTPUT.
pub fn get_unused_stream_id(&self, direction: u8) -> Option<u32> {
self.streams
+ .lock()
.iter()
.filter(|s| s.state == StreamState::Available && s.direction == direction as u8)
.map(|s| s.id)
@@ -184,21 +243,20 @@ impl VioSClient {
}
/// Configures a stream with the given parameters.
- pub fn set_stream_parameters(
- &mut self,
- stream_id: u32,
- params: VioSStreamParams,
- ) -> Result<()> {
- self.validate_stream_id(stream_id, &[StreamState::Available, StreamState::Acquired])?;
+ pub fn set_stream_parameters(&self, stream_id: u32, params: VioSStreamParams) -> Result<()> {
+ self.validate_stream_id(
+ stream_id,
+ &[StreamState::Available, StreamState::Acquired],
+ None,
+ )?;
let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
- seq_socket_send(&self.control_socket, raw_params)?;
- self.recv_cmd_status()?;
- self.streams[stream_id as usize].state = StreamState::Acquired;
+ self.send_cmd(raw_params)?;
+ self.streams.lock()[stream_id as usize].state = StreamState::Acquired;
Ok(())
}
/// Send the PREPARE_STREAM command to the server.
- pub fn prepare_stream(&mut self, stream_id: u32) -> Result<()> {
+ pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
self.common_stream_op(
stream_id,
&[StreamState::Available, StreamState::Acquired],
@@ -208,7 +266,7 @@ impl VioSClient {
}
/// Send the RELEASE_STREAM command to the server.
- pub fn release_stream(&mut self, stream_id: u32) -> Result<()> {
+ pub fn release_stream(&self, stream_id: u32) -> Result<()> {
self.common_stream_op(
stream_id,
&[StreamState::Acquired],
@@ -218,7 +276,7 @@ impl VioSClient {
}
/// Send the START_STREAM command to the server.
- pub fn start_stream(&mut self, stream_id: u32) -> Result<()> {
+ pub fn start_stream(&self, stream_id: u32) -> Result<()> {
self.common_stream_op(
stream_id,
&[StreamState::Acquired],
@@ -228,7 +286,7 @@ impl VioSClient {
}
/// Send the STOP_STREAM command to the server.
- pub fn stop_stream(&mut self, stream_id: u32) -> Result<()> {
+ pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
self.common_stream_op(
stream_id,
&[StreamState::Active],
@@ -239,77 +297,121 @@ impl VioSClient {
/// Send audio frames to the server. The audio data is taken from a shared memory resource.
pub fn inject_audio_data(
- &mut self,
+ &self,
stream_id: u32,
buffer: &mut SharedMemory,
src_offset: usize,
size: usize,
) -> Result<()> {
- self.validate_stream_id(stream_id, &[StreamState::Active])?;
- if self.streams[stream_id as usize].direction != VIRTIO_SND_D_OUTPUT {
- return Err(Error::WrongDirection(
- self.streams[stream_id as usize].direction,
- ));
- }
- let dst_offset = self.tx.push_buffer(buffer, src_offset, size)?;
+ self.validate_stream_id(stream_id, &[StreamState::Active], Some(VIRTIO_SND_D_OUTPUT))?;
+ let mut tx_lock = self.tx.lock();
+ let tx = &mut *tx_lock;
+ let dst_offset = tx.push_buffer(buffer, src_offset, size)?;
let msg = IoTransferMsg::new(stream_id, dst_offset, size);
- seq_socket_send(&self.tx.socket, msg)
+ seq_socket_send(&tx.socket, msg)
+ }
+
+ pub fn request_audio_data(
+ &self,
+ stream_id: u32,
+ buffer: &mut SharedMemory,
+ dst_offset: usize,
+ size: usize,
+ ) -> Result<usize> {
+ self.validate_stream_id(stream_id, &[StreamState::Active], Some(VIRTIO_SND_D_INPUT))?;
+ let (src_offset, status_promise) = {
+ let mut rx_lock = self.rx.lock();
+ let rx = &mut *rx_lock;
+ let src_offset = rx.allocate_buffer(size)?;
+ // Register to receive the status before sending the buffer to the server
+ let (sender, receiver): (Sender<(u32, usize)>, Receiver<(u32, usize)>) = channel();
+ // It's OK to acquire rx_subscriber's lock after rx_lock
+ self.rx_subscribers.lock().insert(src_offset, sender);
+ let msg = IoTransferMsg::new(stream_id, src_offset, size);
+ seq_socket_send(&rx.socket, msg)?;
+ (src_offset, receiver)
+ };
+ // Make sure no mutexes are held while waiting for the buffer to be written to
+ let recv_size = await_status(status_promise)?;
+ {
+ let mut rx_lock = self.rx.lock();
+ rx_lock
+ .pop_buffer(buffer, dst_offset, recv_size, src_offset)
+ .map(|()| recv_size)
+ }
}
/// Get a list of file descriptors used by the implementation.
pub fn keep_fds(&self) -> Vec<RawFd> {
+ let control_fd = self.control_socket.lock().as_raw_fd();
+ let event_fd = self.event_socket.lock().as_raw_fd();
+ let (tx_socket_fd, tx_shm_fd) = {
+ let lock = self.tx.lock();
+ (lock.socket.as_raw_fd(), lock.file.as_raw_fd())
+ };
+ let (rx_socket_fd, rx_shm_fd) = {
+ let lock = self.rx.lock();
+ (lock.socket.as_raw_fd(), lock.file.as_raw_fd())
+ };
vec![
- self.control_socket.as_raw_fd(),
- self.event_socket.as_raw_fd(),
- self.tx.socket.as_raw_fd(),
- self.rx.socket.as_raw_fd(),
- self.tx.file.as_raw_fd(),
- self.rx.file.as_raw_fd(),
+ control_fd,
+ event_fd,
+ tx_socket_fd,
+ tx_shm_fd,
+ rx_socket_fd,
+ rx_shm_fd,
]
}
- fn validate_stream_id(&self, stream_id: u32, permitted_states: &[StreamState]) -> Result<()> {
- if stream_id >= self.streams.len() as u32 {
+ fn send_cmd<T: DataInit>(&self, data: T) -> Result<()> {
+ let mut control_socket_lock = self.control_socket.lock();
+ seq_socket_send(&mut *control_socket_lock, data)?;
+ recv_cmd_status(&mut *control_socket_lock)
+ }
+
+ fn validate_stream_id(
+ &self,
+ stream_id: u32,
+ permitted_states: &[StreamState],
+ direction: Option<u8>,
+ ) -> Result<()> {
+ let streams_lock = self.streams.lock();
+ let stream_idx = stream_id as usize;
+ if stream_idx >= streams_lock.len() {
return Err(Error::InvalidStreamId(stream_id));
}
- if !permitted_states.contains(&self.streams[stream_id as usize].state) {
- return Err(Error::UnexpectedState(
- self.streams[stream_id as usize].state,
- ));
+ if !permitted_states.contains(&streams_lock[stream_idx].state) {
+ return Err(Error::UnexpectedState(streams_lock[stream_idx].state));
+ }
+ match direction {
+ None => Ok(()),
+ Some(d) => {
+ if d == streams_lock[stream_idx].direction {
+ Ok(())
+ } else {
+ Err(Error::WrongDirection(streams_lock[stream_idx].direction))
+ }
+ }
}
- Ok(())
}
fn common_stream_op(
- &mut self,
+ &self,
stream_id: u32,
expected_states: &[StreamState],
new_state: StreamState,
op: u32,
) -> Result<()> {
- self.validate_stream_id(stream_id, expected_states)?;
+ self.validate_stream_id(stream_id, expected_states, None)?;
let msg = virtio_snd_pcm_hdr {
hdr: virtio_snd_hdr { code: op.into() },
stream_id: stream_id.into(),
};
- seq_socket_send(&self.control_socket, msg)?;
- self.recv_cmd_status()?;
- self.streams[stream_id as usize].state = new_state;
+ self.send_cmd(msg)?;
+ self.streams.lock()[stream_id as usize].state = new_state;
Ok(())
}
- fn recv_cmd_status(&mut self) -> Result<()> {
- let mut status: virtio_snd_hdr = Default::default();
- self.control_socket
- .recv(status.as_mut_slice())
- .map_err(|e| Error::ServerIOError(e))?;
- if status.code.to_native() == VIRTIO_SND_S_OK {
- Ok(())
- } else {
- Err(Error::CommandFailed(status.code.to_native()))
- }
- }
-
fn request_and_cache_streams_info(&mut self) -> Result<()> {
let num_streams = self.config.streams as usize;
let info_size = std::mem::size_of::<virtio_snd_pcm_info>();
@@ -321,10 +423,9 @@ impl VioSClient {
count: (num_streams as u32).into(),
size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
};
- seq_socket_send(&self.control_socket, req)?;
- self.recv_cmd_status()?;
- let info_vec = self
- .control_socket
+ self.send_cmd(req)?;
+ let control_socket_lock = self.control_socket.lock();
+ let info_vec = control_socket_lock
.recv_as_vec()
.map_err(|e| Error::ServerIOError(e))?;
if info_vec.len() != num_streams * info_size {
@@ -332,19 +433,123 @@ impl VioSClient {
ProtocolErrorKind::UnexpectedMessageSize(num_streams * info_size, info_vec.len()),
));
}
- self.streams = info_vec
- .chunks(info_size)
- .enumerate()
- .map(|(id, info_buffer)| {
- // unwrap is safe because we checked the size of the vector above
- let virtio_stream_info = virtio_snd_pcm_info::from_slice(&info_buffer).unwrap();
- VioSStreamInfo::new(id as u32, &virtio_stream_info)
- })
- .collect();
+ self.streams = Mutex::new(
+ info_vec
+ .chunks(info_size)
+ .enumerate()
+ .map(|(id, info_buffer)| {
+ // unwrap is safe because we checked the size of the vector
+ let virtio_stream_info = virtio_snd_pcm_info::from_slice(&info_buffer).unwrap();
+ VioSStreamInfo::new(id as u32, &virtio_stream_info)
+ })
+ .collect(),
+ );
Ok(())
}
}
+impl Drop for VioSClient {
+ fn drop(&mut self) {
+ // Stop the recv thread
+ *self.recv_running.lock() = false;
+ if let Err(e) = self.recv_event.lock().write(1u64) {
+ error!("Failed to notify recv thread: {:?}", e);
+ }
+ if let Some(handle) = self.recv_thread.lock().take() {
+ match handle.join() {
+ Ok(r) => {
+ if let Err(e) = r {
+ error!("Error detected on Recv Thread: {}", e);
+ }
+ }
+ Err(e) => error!("Recv thread panicked: {:?}", e),
+ };
+ }
+ }
+}
+
+#[derive(PollToken)]
+enum Token {
+ Notification,
+ RxBufferMsg,
+}
+
+fn spawn_recv_thread(
+ rx_subscribers: Arc<Mutex<HashMap<usize, Sender<(u32, usize)>>>>,
+ event: Event,
+ running: Arc<Mutex<bool>>,
+ rx_socket: UnixSeqpacket,
+) -> JoinHandle<Result<()>> {
+ std::thread::spawn(move || {
+ let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
+ (&rx_socket, Token::RxBufferMsg),
+ (&event, Token::Notification),
+ ])
+ .map_err(|e| Error::WaitContextCreateError(e))?;
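+ // Loop until the owning client clears `running` and signals the notification event.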
+ while *running.lock() {
+ let events = wait_ctx.wait().map_err(|e| Error::WaitError(e))?;
+ for evt in events {
+ match evt.token {
+ Token::RxBufferMsg => {
+ let mut msg: IoStatusMsg = Default::default();
+ let size = rx_socket
+ .recv(msg.as_mut_slice())
+ .map_err(|e| Error::ServerIOError(e))?;
+ if size != std::mem::size_of::<IoStatusMsg>() {
+ return Err(Error::ProtocolError(
+ ProtocolErrorKind::UnexpectedMessageSize(
+ std::mem::size_of::<IoStatusMsg>(),
+ size,
+ ),
+ ));
+ }
+ let mut status = msg.status.status.into();
+ if status == u32::MAX {
+ // Anyone waiting for this would keep waiting for as long as the status is
+ // u32::MAX, so clamp it just below that.
+ status -= 1;
+ }
+ let offset = msg.buffer_offset as usize;
+ let consumed_len = msg.consumed_len as usize;
+ // Acquire and immediately release the mutex protecting the hashmap
+ let promise_opt = rx_subscribers.lock().remove(&offset);
+ match promise_opt {
+ None => error!(
+ "Received an unexpected buffer status message: {}. This is a BUG!!",
+ offset
+ ),
+ Some(sender) => {
+ if let Err(e) = sender.send((status, consumed_len)) {
+ error!("Failed to notify waiting thread: {:?}", e);
+ }
+ }
+ }
+ }
+ Token::Notification => {
+ // Just consume the notification and check for termination on the next
+ // iteration
+ if let Err(e) = event.read() {
+ error!("Failed to consume notification from recv thread: {:?}", e);
+ }
+ }
+ }
+ }
+ }
+ Ok(())
+ })
+}
+
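+// Waits for the recv thread to deliver the status and consumed length of an I/O buffer.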
+fn await_status(promise: Receiver<(u32, usize)>) -> Result<usize> {
+ let (status, consumed_len) = promise
+ .recv()
+ .map_err(|e| Error::BufferStatusSenderLost(e))?;
+ if status == VIRTIO_SND_S_OK {
+ Ok(consumed_len)
+ } else {
+ Err(Error::IOBufferError(status))
+ }
+}
+
struct IoBufferQueue {
socket: UnixSeqpacket,
file: File,
@@ -360,7 +565,7 @@ impl IoBufferQueue {
.map_err(|e| Error::FileSizeError(e))? as usize;
let mmap = MemoryMappingBuilder::new(size)
- .from_descriptor(&file)
+ .from_file(&file)
.build()
.map_err(|e| Error::ServerMmapError(e))?;
@@ -373,17 +578,22 @@ impl IoBufferQueue {
})
}
- fn push_buffer(&mut self, src: &mut SharedMemory, offset: usize, size: usize) -> Result<usize> {
+ fn allocate_buffer(&mut self, size: usize) -> Result<usize> {
if size > self.size {
return Err(Error::OutOfSpace);
}
- let shm_offset = if size > self.size - self.next {
+ let offset = if size > self.size - self.next {
// Can't fit the new buffer at the end of the area, so put it at the beginning
0
} else {
self.next
};
+ self.next = offset + size;
+ Ok(offset)
+ }
+ fn push_buffer(&mut self, src: &mut SharedMemory, offset: usize, size: usize) -> Result<usize> {
+ let shm_offset = self.allocate_buffer(size)?;
let (src_mmap, mmap_offset) = mmap_buffer(src, offset, size)?;
let src_slice = src_mmap
.get_slice(mmap_offset, size)
@@ -393,9 +603,27 @@ impl IoBufferQueue {
.get_slice(shm_offset, size)
.map_err(|e| Error::VolatileMemoryError(e))?;
src_slice.copy_to_volatile_slice(dst_slice);
- self.next = shm_offset + size;
Ok(shm_offset)
}
+
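+ // Copies a buffer from this queue's shared memory at `src_offset` into `dst` at `dst_offset`.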
+ fn pop_buffer(
+ &mut self,
+ dst: &mut SharedMemory,
+ dst_offset: usize,
+ size: usize,
+ src_offset: usize,
+ ) -> Result<()> {
+ let (dst_mmap, mmap_offset) = mmap_buffer(dst, dst_offset, size)?;
+ let dst_slice = dst_mmap
+ .get_slice(mmap_offset, size)
+ .map_err(|e| Error::VolatileMemoryError(e))?;
+ let src_slice = self
+ .mmap
+ .get_slice(src_offset, size)
+ .map_err(|e| Error::VolatileMemoryError(e))?;
+ src_slice.copy_to_volatile_slice(dst_slice);
+ Ok(())
+ }
}
/// Description of a stream made available by the server.
@@ -479,13 +707,25 @@ fn mmap_buffer(
let mmap = MemoryMappingBuilder::new(extended_size)
.offset(aligned_offset as u64)
- .from_descriptor(src)
+ .from_shared_memory(src)
.build()
.map_err(|e| Error::GuestMmapError(e))?;
Ok((mmap, offset_from_mapping_start))
}
+fn recv_cmd_status(control_socket: &mut UnixSeqpacket) -> Result<()> {
+ let mut status: virtio_snd_hdr = Default::default();
+ control_socket
+ .recv(status.as_mut_slice())
+ .map_err(|e| Error::ServerIOError(e))?;
+ if status.code.to_native() == VIRTIO_SND_S_OK {
+ Ok(())
+ } else {
+ Err(Error::CommandFailed(status.code.to_native()))
+ }
+}
+
fn seq_socket_send<T: DataInit>(socket: &UnixSeqpacket, data: T) -> Result<()> {
loop {
let send_res = socket.send(data.as_slice());
@@ -538,7 +778,7 @@ impl IoTransferMsg {
}
#[repr(C)]
-#[derive(Copy, Clone)]
+#[derive(Copy, Clone, Default)]
struct IoStatusMsg {
status: virtio_snd_pcm_status,
buffer_offset: u32,
diff --git a/devices/src/virtio/tpm.rs b/devices/src/virtio/tpm.rs
index ee68ea1ea..e9e07c90a 100644
--- a/devices/src/virtio/tpm.rs
+++ b/devices/src/virtio/tpm.rs
@@ -14,7 +14,8 @@ use base::{error, Event, PollToken, RawDescriptor, WaitContext};
use vm_memory::GuestMemory;
use super::{
- DescriptorChain, DescriptorError, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_TPM,
+ DescriptorChain, DescriptorError, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice,
+ Writer, TYPE_TPM,
};
// A single queue of size 2. The guest kernel driver will enqueue a single
@@ -112,9 +113,14 @@ impl Worker {
let wait_ctx = match WaitContext::build_with(&[
(&self.queue_evt, Token::QueueAvailable),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
(&self.kill_evt, Token::Kill),
- ]) {
+ ])
+ .and_then(|wc| {
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wc.add(resample_evt, Token::InterruptResample)?;
+ }
+ Ok(wc)
+ }) {
Ok(pc) => pc,
Err(e) => {
error!("vtpm failed creating WaitContext: {}", e);
diff --git a/devices/src/virtio/vhost/control_socket.rs b/devices/src/virtio/vhost/control_socket.rs
index fc1d0f302..18553885d 100644
--- a/devices/src/virtio/vhost/control_socket.rs
+++ b/devices/src/virtio/vhost/control_socket.rs
@@ -2,10 +2,11 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use base::{Error as SysError, RawDescriptor};
-use msg_socket::{MsgOnSocket, MsgSocket};
+use serde::{Deserialize, Serialize};
-#[derive(MsgOnSocket, Debug)]
+use base::Error as SysError;
+
+#[derive(Serialize, Deserialize, Debug)]
pub enum VhostDevRequest {
/// Mask or unmask all the MSI entries for a Virtio Vhost device.
MsixChanged,
@@ -13,24 +14,8 @@ pub enum VhostDevRequest {
MsixEntryChanged(usize),
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VhostDevResponse {
Ok,
Err(SysError),
}
-
-pub type VhostDevRequestSocket = MsgSocket<VhostDevRequest, VhostDevResponse>;
-pub type VhostDevResponseSocket = MsgSocket<VhostDevResponse, VhostDevRequest>;
-
-/// Create control socket pair. This pair is used to communicate with the
-/// virtio device process.
-/// Mainly between the virtio and activate thread.
-pub fn create_control_sockets() -> (
- Option<VhostDevRequestSocket>,
- Option<VhostDevResponseSocket>,
-) {
- match msg_socket::pair::<VhostDevRequest, VhostDevResponse>() {
- Ok((request, response)) => (Some(request), Some(response)),
- _ => (None, None),
- }
-}
diff --git a/devices/src/virtio/vhost/mod.rs b/devices/src/virtio/vhost/mod.rs
index e1cd06766..14bd8cad8 100644
--- a/devices/src/virtio/vhost/mod.rs
+++ b/devices/src/virtio/vhost/mod.rs
@@ -6,13 +6,14 @@
use std::fmt::{self, Display};
-use base::Error as SysError;
+use base::{Error as SysError, TubeError};
use net_util::Error as TapError;
use remain::sorted;
use vhost::Error as VhostError;
mod control_socket;
mod net;
+pub mod user;
mod vsock;
mod worker;
@@ -27,6 +28,8 @@ pub enum Error {
CloneKillEvent(SysError),
/// Creating kill event failed.
CreateKillEvent(SysError),
+ /// Creating tube failed.
+ CreateTube(TubeError),
/// Creating wait context failed.
CreateWaitContext(SysError),
/// Enabling tap interface failed.
@@ -88,6 +91,7 @@ impl Display for Error {
match self {
CloneKillEvent(e) => write!(f, "failed to clone kill event: {}", e),
CreateKillEvent(e) => write!(f, "failed to create kill event: {}", e),
+ CreateTube(e) => write!(f, "failed to create tube: {}", e),
CreateWaitContext(e) => write!(f, "failed to create poll context: {}", e),
TapEnable(e) => write!(f, "failed to enable tap interface: {}", e),
TapOpen(e) => write!(f, "failed to open tap device: {}", e),
diff --git a/devices/src/virtio/vhost/net.rs b/devices/src/virtio/vhost/net.rs
index 41dc6da4b..d68febe19 100644
--- a/devices/src/virtio/vhost/net.rs
+++ b/devices/src/virtio/vhost/net.rs
@@ -4,11 +4,12 @@
use std::mem;
use std::net::Ipv4Addr;
+use std::path::PathBuf;
use std::thread;
use net_util::{MacAddress, TapT};
-use base::{error, warn, AsRawDescriptor, Event, RawDescriptor};
+use base::{error, warn, AsRawDescriptor, Event, RawDescriptor, Tube};
use vhost::NetT as VhostNetT;
use virtio_sys::virtio_net;
use vm_memory::GuestMemory;
@@ -18,7 +19,6 @@ use super::worker::Worker;
use super::{Error, Result};
use crate::pci::MsixStatus;
use crate::virtio::{Interrupt, Queue, VirtioDevice, TYPE_NET};
-use msg_socket::{MsgReceiver, MsgSender};
const QUEUE_SIZE: u16 = 256;
const NUM_QUEUES: usize = 2;
@@ -33,8 +33,8 @@ pub struct Net<T: TapT, U: VhostNetT<T>> {
vhost_interrupt: Option<Vec<Event>>,
avail_features: u64,
acked_features: u64,
- request_socket: Option<VhostDevRequestSocket>,
- response_socket: Option<VhostDevResponseSocket>,
+ request_tube: Tube,
+ response_tube: Option<Tube>,
}
impl<T, U> Net<T, U>
@@ -45,6 +45,7 @@ where
/// Create a new virtio network device with the given IP address and
/// netmask.
pub fn new(
+ vhost_net_device_path: &PathBuf,
base_features: u64,
ip_addr: Ipv4Addr,
netmask: Ipv4Addr,
@@ -71,7 +72,7 @@ where
.map_err(Error::TapSetVnetHdrSize)?;
tap.enable().map_err(Error::TapEnable)?;
- let vhost_net_handle = U::new(mem).map_err(Error::VhostOpen)?;
+ let vhost_net_handle = U::new(vhost_net_device_path, mem).map_err(Error::VhostOpen)?;
let avail_features = base_features
| 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM
@@ -90,7 +91,7 @@ where
vhost_interrupt.push(Event::new().map_err(Error::VhostIrqCreate)?);
}
- let (request_socket, response_socket) = create_control_sockets();
+ let (request_tube, response_tube) = Tube::pair().map_err(Error::CreateTube)?;
Ok(Net {
workers_kill_evt: Some(kill_evt.try_clone().map_err(Error::CloneKillEvent)?),
@@ -101,8 +102,8 @@ where
vhost_interrupt: Some(vhost_interrupt),
avail_features,
acked_features: 0u64,
- request_socket,
- response_socket,
+ request_tube,
+ response_tube: Some(response_tube),
})
}
}
@@ -152,12 +153,10 @@ where
}
keep_rds.push(self.kill_evt.as_raw_descriptor());
- if let Some(request_socket) = &self.request_socket {
- keep_rds.push(request_socket.as_raw_descriptor());
- }
+ keep_rds.push(self.request_tube.as_raw_descriptor());
- if let Some(response_socket) = &self.response_socket {
- keep_rds.push(response_socket.as_raw_descriptor());
+ if let Some(response_tube) = &self.response_tube {
+ keep_rds.push(response_tube.as_raw_descriptor());
}
keep_rds
@@ -206,8 +205,8 @@ where
if let Some(vhost_interrupt) = self.vhost_interrupt.take() {
if let Some(kill_evt) = self.workers_kill_evt.take() {
let acked_features = self.acked_features;
- let socket = if self.response_socket.is_some() {
- self.response_socket.take()
+ let socket = if self.response_tube.is_some() {
+ self.response_tube.take()
} else {
None
};
@@ -275,44 +274,50 @@ where
}
fn control_notify(&self, behavior: MsixStatus) {
- if self.worker_thread.is_none() || self.request_socket.is_none() {
+ if self.worker_thread.is_none() {
return;
}
- if let Some(socket) = &self.request_socket {
- match behavior {
- MsixStatus::EntryChanged(index) => {
- if let Err(e) = socket.send(&VhostDevRequest::MsixEntryChanged(index)) {
- error!(
- "{} failed to send VhostMsixEntryChanged request for entry {}: {:?}",
- self.debug_label(),
- index,
- e
- );
- return;
- }
- if let Err(e) = socket.recv() {
- error!("{} failed to receive VhostMsixEntryChanged response for entry {}: {:?}", self.debug_label(), index, e);
- }
+ match behavior {
+ MsixStatus::EntryChanged(index) => {
+ if let Err(e) = self
+ .request_tube
+ .send(&VhostDevRequest::MsixEntryChanged(index))
+ {
+ error!(
+ "{} failed to send VhostMsixEntryChanged request for entry {}: {:?}",
+ self.debug_label(),
+ index,
+ e
+ );
+ return;
}
- MsixStatus::Changed => {
- if let Err(e) = socket.send(&VhostDevRequest::MsixChanged) {
- error!(
- "{} failed to send VhostMsixChanged request: {:?}",
- self.debug_label(),
- e
- );
- return;
- }
- if let Err(e) = socket.recv() {
- error!(
- "{} failed to receive VhostMsixChanged response {:?}",
- self.debug_label(),
- e
- );
- }
+ if let Err(e) = self.request_tube.recv::<VhostDevResponse>() {
+ error!(
+ "{} failed to receive VhostMsixEntryChanged response for entry {}: {:?}",
+ self.debug_label(),
+ index,
+ e
+ );
+ }
+ }
+ MsixStatus::Changed => {
+ if let Err(e) = self.request_tube.send(&VhostDevRequest::MsixChanged) {
+ error!(
+ "{} failed to send VhostMsixChanged request: {:?}",
+ self.debug_label(),
+ e
+ );
+ return;
+ }
+ if let Err(e) = self.request_tube.recv::<VhostDevResponse>() {
+ error!(
+ "{} failed to receive VhostMsixChanged response {:?}",
+ self.debug_label(),
+ e
+ );
}
- _ => {}
}
+ _ => {}
}
}
@@ -334,7 +339,7 @@ where
self.tap = Some(tap);
self.vhost_interrupt = Some(worker.vhost_interrupt);
self.workers_kill_evt = Some(worker.kill_evt);
- self.response_socket = worker.response_socket;
+ self.response_tube = worker.response_tube;
return true;
}
}
@@ -366,6 +371,7 @@ pub mod tests {
let guest_memory = create_guest_memory().unwrap();
let features = base_features(ProtectionType::Unprotected);
Net::<FakeTap, FakeNet<FakeTap>>::new(
+ &PathBuf::from(""),
features,
Ipv4Addr::new(127, 0, 0, 1),
Ipv4Addr::new(255, 255, 255, 0),
diff --git a/devices/src/virtio/vhost/user/block.rs b/devices/src/virtio/vhost/user/block.rs
new file mode 100644
index 000000000..f3b6abfbe
--- /dev/null
+++ b/devices/src/virtio/vhost/user/block.rs
@@ -0,0 +1,180 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::RefCell;
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::thread;
+use std::u32;
+
+use base::{error, Event, RawDescriptor};
+use cros_async::Executor;
+use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
+use vm_memory::GuestMemory;
+use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
+use vmm_vhost::vhost_user::Master;
+
+use crate::virtio::vhost::user::handler::VhostUserHandler;
+use crate::virtio::vhost::user::worker::Worker;
+use crate::virtio::vhost::user::{Error, Result};
+use crate::virtio::{virtio_blk_config, Interrupt, Queue, VirtioDevice, TYPE_BLOCK};
+
+const VIRTIO_BLK_F_SEG_MAX: u32 = 2;
+const VIRTIO_BLK_F_RO: u32 = 5;
+const VIRTIO_BLK_F_BLK_SIZE: u32 = 6;
+const VIRTIO_BLK_F_FLUSH: u32 = 9;
+const VIRTIO_BLK_F_DISCARD: u32 = 13;
+const VIRTIO_BLK_F_WRITE_ZEROES: u32 = 14;
+
+const QUEUE_SIZE: u16 = 256;
+
+pub struct Block {
+ kill_evt: Option<Event>,
+ worker_thread: Option<thread::JoinHandle<Worker>>,
+ handler: RefCell<VhostUserHandler>,
+ queue_sizes: Vec<u16>,
+}
+
+impl Block {
+ pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P) -> Result<Block> {
+ let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?;
+ // TODO(b/181753022): Support multiple queues.
+ let vhost_user_blk = Master::from_stream(socket, 1 /* queues_num */);
+
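+ // Features this device is allowed to negotiate with the vhost-user backend.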
+ let allow_features = 1u64 << crate::virtio::VIRTIO_F_VERSION_1
+ | 1 << VIRTIO_BLK_F_SEG_MAX
+ | 1 << VIRTIO_BLK_F_RO
+ | 1 << VIRTIO_BLK_F_BLK_SIZE
+ | 1 << VIRTIO_BLK_F_FLUSH
+ | 1 << VIRTIO_BLK_F_DISCARD
+ | 1 << VIRTIO_BLK_F_WRITE_ZEROES
+ | 1 << VIRTIO_RING_F_EVENT_IDX
+ | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let allow_protocol_features = VhostUserProtocolFeatures::CONFIG;
+
+ let mut handler = VhostUserHandler::new(
+ vhost_user_blk,
+ allow_features,
+ init_features,
+ allow_protocol_features,
+ )?;
+ let queue_sizes = handler.queue_sizes(QUEUE_SIZE, 1)?;
+
+ Ok(Block {
+ kill_evt: None,
+ worker_thread: None,
+ handler: RefCell::new(handler),
+ queue_sizes,
+ })
+ }
+}
+
+impl Drop for Block {
+ fn drop(&mut self) {
+ if let Some(kill_evt) = self.kill_evt.take() {
+ // Ignore the result because there is nothing we can do about it.
+ let _ = kill_evt.write(1);
+ }
+
+ if let Some(worker_thread) = self.worker_thread.take() {
+ let _ = worker_thread.join();
+ }
+ }
+}
+
+impl VirtioDevice for Block {
+ fn keep_rds(&self) -> Vec<RawDescriptor> {
+ Vec::new()
+ }
+
+ fn features(&self) -> u64 {
+ self.handler.borrow().avail_features
+ }
+
+ fn ack_features(&mut self, features: u64) {
+ if let Err(e) = self.handler.borrow_mut().ack_features(features) {
+ error!("failed to enable features 0x{:x}: {}", features, e);
+ }
+ }
+
+ fn device_type(&self) -> u32 {
+ TYPE_BLOCK
+ }
+
+ fn queue_max_sizes(&self) -> &[u16] {
+ self.queue_sizes.as_slice()
+ }
+
+ fn read_config(&self, offset: u64, data: &mut [u8]) {
+ if let Err(e) = self
+ .handler
+ .borrow_mut()
+ .read_config::<virtio_blk_config>(offset, data)
+ {
+ error!("failed to read config: {}", e);
+ }
+ }
+
+ fn activate(
+ &mut self,
+ mem: GuestMemory,
+ interrupt: Interrupt,
+ queues: Vec<Queue>,
+ queue_evts: Vec<Event>,
+ ) {
+ if let Err(e) = self
+ .handler
+ .borrow_mut()
+ .activate(&mem, &interrupt, &queues, &queue_evts)
+ {
+ error!("failed to activate queues: {}", e);
+ return;
+ }
+
+ let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
+ Ok(v) => v,
+ Err(e) => {
+ error!("failed creating kill Event pair: {}", e);
+ return;
+ }
+ };
+ self.kill_evt = Some(self_kill_evt);
+
+ let worker_result = thread::Builder::new()
+ .name("vhost_user_virtio_blk".to_string())
+ .spawn(move || {
+ let ex = Executor::new().expect("failed to create an executor");
+ let mut worker = Worker {
+ queues,
+ mem,
+ kill_evt,
+ };
+
+ if let Err(e) = worker.run(&ex, interrupt) {
+ error!("failed to start a worker: {}", e);
+ }
+ worker
+ });
+
+ match worker_result {
+ Err(e) => {
+ error!("failed to spawn vhost-user virtio_blk worker: {}", e);
+ return;
+ }
+ Ok(join_handle) => {
+ self.worker_thread = Some(join_handle);
+ }
+ }
+ }
+
+ fn reset(&mut self) -> bool {
+ if let Err(e) = self.handler.borrow_mut().reset(self.queue_sizes.len()) {
+ error!("Failed to reset block device: {}", e);
+ false
+ } else {
+ true
+ }
+ }
+}
diff --git a/devices/src/virtio/vhost/user/fs.rs b/devices/src/virtio/vhost/user/fs.rs
new file mode 100644
index 000000000..2e212b97c
--- /dev/null
+++ b/devices/src/virtio/vhost/user/fs.rs
@@ -0,0 +1,192 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::RefCell;
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::thread;
+
+use base::{error, Event, RawDescriptor};
+use cros_async::Executor;
+use data_model::{DataInit, Le32};
+use vm_memory::GuestMemory;
+use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
+use vmm_vhost::vhost_user::{Error as VhostUserError, Master};
+use vmm_vhost::Error as VhostError;
+
+use crate::virtio::fs::{virtio_fs_config, FS_MAX_TAG_LEN, QUEUE_SIZE};
+use crate::virtio::vhost::user::handler::VhostUserHandler;
+use crate::virtio::vhost::user::worker::Worker;
+use crate::virtio::vhost::user::{Error, Result};
+use crate::virtio::{copy_config, TYPE_FS};
+use crate::virtio::{Interrupt, Queue, VirtioDevice};
+
+pub struct Fs {
+ cfg: virtio_fs_config,
+ kill_evt: Option<Event>,
+ worker_thread: Option<thread::JoinHandle<Worker>>,
+ handler: RefCell<VhostUserHandler>,
+ queue_sizes: Vec<u16>,
+}
+impl Fs {
+ pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P, tag: &str) -> Result<Fs> {
+ if tag.len() > FS_MAX_TAG_LEN {
+ return Err(Error::TagTooLong {
+ len: tag.len(),
+ max: FS_MAX_TAG_LEN,
+ });
+ }
+
+ // The spec requires a minimum of 2 queues: one worker queue and one high priority queue
+ let default_queue_size = 2;
+
+ let mut cfg_tag = [0u8; FS_MAX_TAG_LEN];
+ cfg_tag[..tag.len()].copy_from_slice(tag.as_bytes());
+
+ let cfg = virtio_fs_config {
+ tag: cfg_tag,
+ // Only count the worker queues, excluding the high priority queue
+ num_request_queues: Le32::from(default_queue_size - 1),
+ };
+
+ let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?;
+ let master = Master::from_stream(socket, default_queue_size as u64);
+
+ let allow_features = 1u64 << crate::virtio::VIRTIO_F_VERSION_1
+ | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let allow_protocol_features =
+ VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIG;
+
+ let mut handler = VhostUserHandler::new(
+ master,
+ allow_features,
+ init_features,
+ allow_protocol_features,
+ )?;
+ let queue_sizes = handler.queue_sizes(QUEUE_SIZE, default_queue_size as usize)?;
+
+ Ok(Fs {
+ cfg,
+ kill_evt: None,
+ worker_thread: None,
+ handler: RefCell::new(handler),
+ queue_sizes,
+ })
+ }
+}
+
+impl VirtioDevice for Fs {
+ fn keep_rds(&self) -> Vec<RawDescriptor> {
+ Vec::new()
+ }
+
+ fn device_type(&self) -> u32 {
+ TYPE_FS
+ }
+
+ fn queue_max_sizes(&self) -> &[u16] {
+ &self.queue_sizes
+ }
+
+ fn features(&self) -> u64 {
+ self.handler.borrow().avail_features
+ }
+
+ fn ack_features(&mut self, features: u64) {
+ if let Err(e) = self.handler.borrow_mut().ack_features(features) {
+ error!("failed to enable features 0x{:x}: {}", features, e);
+ }
+ }
+
+ fn read_config(&self, offset: u64, data: &mut [u8]) {
+ match self
+ .handler
+ .borrow_mut()
+ .read_config::<virtio_fs_config>(offset, data)
+ {
+ Ok(()) => {}
+ // copy local config when VhostUserProtocolFeatures::CONFIG is not supported by the
+ // device
+ Err(Error::GetConfig(VhostError::VhostUserProtocol(
+ VhostUserError::InvalidOperation,
+ ))) => copy_config(data, 0, self.cfg.as_slice(), offset),
+ Err(e) => error!("Failed to fetch device config: {}", e),
+ }
+ }
+
+ fn activate(
+ &mut self,
+ mem: GuestMemory,
+ interrupt: Interrupt,
+ queues: Vec<Queue>,
+ queue_evts: Vec<Event>,
+ ) {
+ if let Err(e) = self
+ .handler
+ .borrow_mut()
+ .activate(&mem, &interrupt, &queues, &queue_evts)
+ {
+ error!("failed to activate queues: {}", e);
+ return;
+ }
+
+ let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
+ Ok(v) => v,
+ Err(e) => {
+ error!("failed creating kill Event pair: {}", e);
+ return;
+ }
+ };
+ self.kill_evt = Some(self_kill_evt);
+
+ let worker_result = thread::Builder::new()
+ .name("vhost_user_virtio_fs".to_string())
+ .spawn(move || {
+ let ex = Executor::new().expect("failed to create an executor");
+ let mut worker = Worker {
+ queues,
+ mem,
+ kill_evt,
+ };
+
+ if let Err(e) = worker.run(&ex, interrupt) {
+ error!("failed to start a worker: {}", e);
+ }
+ worker
+ });
+
+ match worker_result {
+ Err(e) => {
+ error!("failed to spawn vhost-user virtio_fs worker: {}", e);
+ return;
+ }
+ Ok(join_handle) => {
+ self.worker_thread = Some(join_handle);
+ }
+ }
+ }
+
+ fn reset(&mut self) -> bool {
+ if let Err(e) = self.handler.borrow_mut().reset(self.queue_sizes.len()) {
+ error!("Failed to reset fs device: {}", e);
+ false
+ } else {
+ true
+ }
+ }
+}
+
+impl Drop for Fs {
+ fn drop(&mut self) {
+ if let Some(kill_evt) = self.kill_evt.take() {
+ // Ignore the result because there is nothing we can do about it.
+ let _ = kill_evt.write(1);
+ }
+
+ if let Some(worker_thread) = self.worker_thread.take() {
+ let _ = worker_thread.join();
+ }
+ }
+}
diff --git a/devices/src/virtio/vhost/user/handler.rs b/devices/src/virtio/vhost/user/handler.rs
new file mode 100644
index 000000000..f3797ae8e
--- /dev/null
+++ b/devices/src/virtio/vhost/user/handler.rs
@@ -0,0 +1,228 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::io::Write;
+
+use base::{AsRawDescriptor, Event};
+use vm_memory::GuestMemory;
+use vmm_vhost::vhost_user::message::{
+ VhostUserConfigFlags, VhostUserProtocolFeatures, VhostUserVirtioFeatures,
+ VHOST_USER_CONFIG_OFFSET,
+};
+use vmm_vhost::vhost_user::{Master, VhostUserMaster};
+use vmm_vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+
+use crate::virtio::vhost::user::{Error, Result};
+use crate::virtio::{Interrupt, Queue};
+
+fn set_features(vu: &mut Master, avail_features: u64, ack_features: u64) -> Result<u64> {
+ let features = avail_features & ack_features;
+ vu.set_features(features).map_err(Error::SetFeatures)?;
+ Ok(features)
+}
+
+pub struct VhostUserHandler {
+ vu: Master,
+ pub avail_features: u64,
+ acked_features: u64,
+ protocol_features: VhostUserProtocolFeatures,
+}
+
+impl VhostUserHandler {
+ /// Creates a `VhostUserHandler` instance with features and protocol features initialized.
+ pub fn new(
+ mut vu: Master,
+ allow_features: u64,
+ init_features: u64,
+ allow_protocol_features: VhostUserProtocolFeatures,
+ ) -> Result<Self> {
+ vu.set_owner().map_err(Error::SetOwner)?;
+
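+ // Offer only the features that both the backend advertises and the caller allows.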
+ let avail_features = allow_features & vu.get_features().map_err(Error::GetFeatures)?;
+ let acked_features = set_features(&mut vu, avail_features, init_features)?;
+
+ let mut protocol_features = VhostUserProtocolFeatures::empty();
+ if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 {
+ let avail_protocol_features = vu
+ .get_protocol_features()
+ .map_err(Error::GetProtocolFeatures)?;
+ protocol_features = allow_protocol_features & avail_protocol_features;
+ vu.set_protocol_features(protocol_features)
+ .map_err(Error::SetProtocolFeatures)?;
+ }
+
+ Ok(VhostUserHandler {
+ vu,
+ avail_features,
+ acked_features,
+ protocol_features,
+ })
+ }
+
+ /// Returns a vector of sizes of each queue.
+ pub fn queue_sizes(&mut self, queue_size: u16, default_queues_num: usize) -> Result<Vec<u16>> {
+ let queues_num = if self
+ .protocol_features
+ .contains(VhostUserProtocolFeatures::MQ)
+ {
+ self.vu.get_queue_num().map_err(Error::GetQueueNum)? as usize
+ } else {
+ default_queues_num
+ };
+ Ok(vec![queue_size; queues_num])
+ }
+
+ /// Enables a set of features.
+ pub fn ack_features(&mut self, ack_features: u64) -> Result<()> {
+ let features = set_features(
+ &mut self.vu,
+ self.avail_features,
+ self.acked_features | ack_features,
+ )?;
+ self.acked_features = features;
+ Ok(())
+ }
+
+ /// Gets the device configuration space at `offset` and writes it into `data`.
+ pub fn read_config<T>(&mut self, offset: u64, mut data: &mut [u8]) -> Result<()> {
+ let config_len = std::mem::size_of::<T>() as u64;
+ let data_len = data.len() as u64;
+ offset
+ .checked_add(data_len)
+ .and_then(|l| if l <= config_len { Some(()) } else { None })
+ .ok_or(Error::InvalidConfigOffset {
+ data_len,
+ offset,
+ config_len,
+ })?;
+
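+ // Fetch the entire config space from the backend, then copy the requested window into `data`.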
+ let buf = vec![0u8; config_len as usize];
+ let (_, config) = self
+ .vu
+ .get_config(
+ VHOST_USER_CONFIG_OFFSET,
+ config_len as u32,
+ VhostUserConfigFlags::WRITABLE,
+ &buf,
+ )
+ .map_err(Error::GetConfig)?;
+
+ data.write_all(
+ &config[offset as usize..std::cmp::min(data_len + offset, config_len) as usize],
+ )
+ .map_err(Error::CopyConfig)
+ }
+
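+ // Registers every guest memory region with the vhost-user backend so it can map guest memory.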
+ fn set_mem_table(&mut self, mem: &GuestMemory) -> Result<()> {
+ let mut regions: Vec<VhostUserMemoryRegionInfo> = Vec::new();
+ mem.with_regions::<_, ()>(
+ |_idx, guest_phys_addr, memory_size, userspace_addr, mmap, mmap_offset| {
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: guest_phys_addr.0,
+ memory_size: memory_size as u64,
+ userspace_addr: userspace_addr as u64,
+ mmap_offset,
+ mmap_handle: mmap.as_raw_descriptor(),
+ };
+ regions.push(region);
+ Ok(())
+ },
+ )
+ .unwrap(); // never fails
+
+ self.vu
+ .set_mem_table(regions.as_slice())
+ .map_err(Error::SetMemTable)?;
+
+ Ok(())
+ }
+
+ fn activate_vring(
+ &mut self,
+ mem: &GuestMemory,
+ queue_index: usize,
+ queue: &Queue,
+ queue_evt: &Event,
+ interrupt: &Interrupt,
+ ) -> Result<()> {
+ self.vu
+ .set_vring_num(queue_index, queue.actual_size())
+ .map_err(Error::SetVringNum)?;
+
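+ // Pass the guest ring addresses to the backend as host virtual addresses.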
+ let config_data = VringConfigData {
+ queue_max_size: queue.max_size,
+ queue_size: queue.actual_size(),
+ flags: 0u32,
+ desc_table_addr: mem
+ .get_host_address(queue.desc_table)
+ .map_err(Error::GetHostAddress)? as u64,
+ used_ring_addr: mem
+ .get_host_address(queue.used_ring)
+ .map_err(Error::GetHostAddress)? as u64,
+ avail_ring_addr: mem
+ .get_host_address(queue.avail_ring)
+ .map_err(Error::GetHostAddress)? as u64,
+ log_addr: None,
+ };
+ self.vu
+ .set_vring_addr(queue_index, &config_data)
+ .map_err(Error::SetVringAddr)?;
+
+ self.vu
+ .set_vring_base(queue_index, 0)
+ .map_err(Error::SetVringBase)?;
+
+ let msix_config_opt = interrupt
+ .get_msix_config()
+ .as_ref()
+ .ok_or(Error::MsixConfigUnavailable)?;
+ let msix_config = msix_config_opt.lock();
+ let irqfd = msix_config
+ .get_irqfd(queue.vector as usize)
+ .ok_or(Error::MsixIrqfdUnavailable)?;
+ self.vu
+ .set_vring_call(queue_index, &irqfd.0)
+ .map_err(Error::SetVringCall)?;
+
+ self.vu
+ .set_vring_kick(queue_index, &queue_evt.0)
+ .map_err(Error::SetVringKick)?;
+ self.vu
+ .set_vring_enable(queue_index, true)
+ .map_err(Error::SetVringEnable)?;
+
+ Ok(())
+ }
+
+ /// Activates vrings.
+ pub fn activate(
+ &mut self,
+ mem: &GuestMemory,
+ interrupt: &Interrupt,
+ queues: &[Queue],
+ queue_evts: &[Event],
+ ) -> Result<()> {
+ self.set_mem_table(&mem)?;
+
+ for (queue_index, queue) in queues.iter().enumerate() {
+ let queue_evt = &queue_evts[queue_index];
+ self.activate_vring(&mem, queue_index, queue, queue_evt, &interrupt)?;
+ }
+
+ Ok(())
+ }
+
+ /// Deactivates all vrings.
+ pub fn reset(&mut self, queues_num: usize) -> Result<()> {
+ for queue_index in 0..queues_num {
+ self.vu
+ .set_vring_enable(queue_index, false)
+ .map_err(Error::SetVringEnable)?;
+ self.vu
+ .get_vring_base(queue_index)
+ .map_err(Error::GetVringBase)?;
+ }
+ Ok(())
+ }
+}
diff --git a/devices/src/virtio/vhost/user/mod.rs b/devices/src/virtio/vhost/user/mod.rs
new file mode 100644
index 000000000..b408e1c18
--- /dev/null
+++ b/devices/src/virtio/vhost/user/mod.rs
@@ -0,0 +1,101 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+mod block;
+mod fs;
+mod handler;
+mod net;
+mod worker;
+
+pub use self::block::*;
+pub use self::fs::*;
+pub use self::net::*;
+
+use remain::sorted;
+use thiserror::Error as ThisError;
+use vm_memory::GuestMemoryError;
+use vmm_vhost::Error as VhostError;
+
+#[sorted]
+#[derive(ThisError, Debug)]
+pub enum Error {
+ /// Failed to copy config to a buffer.
+ #[error("failed to copy config to a buffer: {0}")]
+ CopyConfig(std::io::Error),
+ /// Failed to create `base::Event`.
+ #[error("failed to create Event: {0}")]
+ CreateEvent(base::Error),
+ /// Failed to get config.
+ #[error("failed to get config: {0}")]
+ GetConfig(VhostError),
+ /// Failed to get features.
+ #[error("failed to get features: {0}")]
+ GetFeatures(VhostError),
+ /// Failed to get host address.
+ #[error("failed to get host address: {0}")]
+ GetHostAddress(GuestMemoryError),
+ /// Failed to get protocol features.
+ #[error("failed to get protocol features: {0}")]
+ GetProtocolFeatures(VhostError),
+ /// Failed to get number of queues.
+ #[error("failed to get number of queues: {0}")]
+ GetQueueNum(VhostError),
+ /// Failed to get vring base offset.
+ #[error("failed to get vring base offset: {0}")]
+ GetVringBase(VhostError),
+ /// Invalid config offset is given.
+ #[error("invalid config offset is given: {data_len} + {offset} > {config_len}")]
+ InvalidConfigOffset {
+ data_len: u64,
+ offset: u64,
+ config_len: u64,
+ },
+ /// MSI-X config is unavailable.
+ #[error("MSI-X config is unavailable")]
+ MsixConfigUnavailable,
+ /// MSI-X irqfd is unavailable.
+ #[error("MSI-X irqfd is unavailable")]
+ MsixIrqfdUnavailable,
+ /// Failed to reset owner.
+ #[error("failed to reset owner: {0}")]
+ ResetOwner(VhostError),
+ /// Failed to set features.
+ #[error("failed to set features: {0}")]
+ SetFeatures(VhostError),
+ /// Failed to set memory map regions.
+ #[error("failed to set memory map regions: {0}")]
+ SetMemTable(VhostError),
+ /// Failed to set owner.
+ #[error("failed to set owner: {0}")]
+ SetOwner(VhostError),
+ /// Failed to set protocol features.
+ #[error("failed to set protocol features: {0}")]
+ SetProtocolFeatures(VhostError),
+ /// Failed to set vring address.
+ #[error("failed to set vring address: {0}")]
+ SetVringAddr(VhostError),
+ /// Failed to set vring base offset.
+ #[error("failed to set vring base offset: {0}")]
+ SetVringBase(VhostError),
+ /// Failed to set eventfd to signal used vring buffers.
+ #[error("failed to set eventfd to signal used vring buffers: {0}")]
+ SetVringCall(VhostError),
+ /// Failed to enable or disable vring.
+ #[error("failed to enable or disable vring: {0}")]
+ SetVringEnable(VhostError),
+ /// Failed to set eventfd for adding buffers to vring.
+ #[error("failed to set eventfd for adding buffers to vring: {0}")]
+ SetVringKick(VhostError),
+ /// Failed to set the size of the queue.
+ #[error("failed to set the size of the queue: {0}")]
+ SetVringNum(VhostError),
+ /// Failed to connect socket.
+ #[error("failed to connect socket: {0}")]
+ SocketConnect(std::io::Error),
+ /// The tag for the Fs device was too long to fit in the config space.
+ #[error("tag is too long: {len} > {max}")]
+ TagTooLong { len: usize, max: usize },
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
diff --git a/devices/src/virtio/vhost/user/net.rs b/devices/src/virtio/vhost/user/net.rs
new file mode 100644
index 000000000..fe3deccf4
--- /dev/null
+++ b/devices/src/virtio/vhost/user/net.rs
@@ -0,0 +1,185 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::cell::RefCell;
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::thread;
+use std::u32;
+
+use base::{error, Event, RawDescriptor};
+use cros_async::Executor;
+use virtio_sys::virtio_net;
+use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
+use vm_memory::GuestMemory;
+use vmm_vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
+use vmm_vhost::vhost_user::Master;
+
+use crate::virtio::vhost::user::handler::VhostUserHandler;
+use crate::virtio::vhost::user::worker::Worker;
+use crate::virtio::vhost::user::Error;
+use crate::virtio::{Interrupt, Queue, VirtioDevice, VirtioNetConfig, TYPE_NET};
+
+type Result<T> = std::result::Result<T, Error>;
+
+const QUEUE_SIZE: u16 = 256;
+
+pub struct Net {
+ kill_evt: Option<Event>,
+ worker_thread: Option<thread::JoinHandle<Worker>>,
+ handler: RefCell<VhostUserHandler>,
+ queue_sizes: Vec<u16>,
+}
+
+impl Net {
+ pub fn new<P: AsRef<Path>>(base_features: u64, socket_path: P) -> Result<Net> {
+ let socket = UnixStream::connect(&socket_path).map_err(Error::SocketConnect)?;
+ let vhost_user_net = Master::from_stream(socket, 16 /* # of queues */);
+
+ // TODO(b/182430355): Support VIRTIO_NET_F_CTRL_VQ and VIRTIO_NET_F_CTRL_GUEST_OFFLOADS.
+ let allow_features = 1 << crate::virtio::VIRTIO_F_VERSION_1
+ | 1 << virtio_net::VIRTIO_NET_F_CSUM
+ | 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM
+ | 1 << virtio_net::VIRTIO_NET_F_GUEST_TSO4
+ | 1 << virtio_net::VIRTIO_NET_F_GUEST_UFO
+ | 1 << virtio_net::VIRTIO_NET_F_HOST_TSO4
+ | 1 << virtio_net::VIRTIO_NET_F_HOST_UFO
+ | 1 << virtio_net::VIRTIO_NET_F_MQ
+ | 1 << VIRTIO_RING_F_EVENT_IDX
+ | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let init_features = base_features | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let allow_protocol_features =
+ VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIG;
+
+ let mut handler = VhostUserHandler::new(
+ vhost_user_net,
+ allow_features,
+ init_features,
+ allow_protocol_features,
+ )?;
+ let queue_sizes = handler.queue_sizes(QUEUE_SIZE, 2 /* 1 rx + 1 tx */)?;
+
+ Ok(Net {
+ kill_evt: None,
+ worker_thread: None,
+ handler: RefCell::new(handler),
+ queue_sizes,
+ })
+ }
+}
+
+impl Drop for Net {
+ fn drop(&mut self) {
+ if let Some(kill_evt) = self.kill_evt.take() {
+ // Ignore the result because there is nothing we can do about it.
+ let _ = kill_evt.write(1);
+ }
+
+ if let Some(worker_thread) = self.worker_thread.take() {
+ let _ = worker_thread.join();
+ }
+ }
+}
+
+impl VirtioDevice for Net {
+ fn keep_rds(&self) -> Vec<RawDescriptor> {
+ Vec::new()
+ }
+
+ fn features(&self) -> u64 {
+ self.handler.borrow().avail_features
+ }
+
+ fn ack_features(&mut self, features: u64) {
+ if let Err(e) = self.handler.borrow_mut().ack_features(features) {
+ error!("failed to enable features 0x{:x}: {}", features, e);
+ }
+ }
+
+ fn device_type(&self) -> u32 {
+ TYPE_NET
+ }
+
+ fn queue_max_sizes(&self) -> &[u16] {
+ self.queue_sizes.as_slice()
+ }
+
+ fn read_config(&self, offset: u64, data: &mut [u8]) {
+ if let Err(e) = self
+ .handler
+ .borrow_mut()
+ .read_config::<VirtioNetConfig>(offset, data)
+ {
+ error!("failed to read config: {}", e);
+ }
+ }
+
+ fn activate(
+ &mut self,
+ mem: GuestMemory,
+ interrupt: Interrupt,
+ queues: Vec<Queue>,
+ queue_evts: Vec<Event>,
+ ) {
+ // TODO(b/182430355): Remove this check once ctrlq is supported.
+ if queues.len() % 2 != 0 {
+ error!(
+ "The number of queues must be an even number but {}",
+ queues.len()
+ );
+ }
+
+ if let Err(e) = self
+ .handler
+ .borrow_mut()
+ .activate(&mem, &interrupt, &queues, &queue_evts)
+ {
+ error!("failed to activate queues: {}", e);
+ return;
+ }
+
+ let (self_kill_evt, kill_evt) = match Event::new().and_then(|e| Ok((e.try_clone()?, e))) {
+ Ok(v) => v,
+ Err(e) => {
+ error!("failed creating kill Event pair: {}", e);
+ return;
+ }
+ };
+ self.kill_evt = Some(self_kill_evt);
+
+ let worker_result = thread::Builder::new()
+ .name("vhost_user_virtio_net".to_string())
+ .spawn(move || {
+ let ex = Executor::new().expect("failed to create an executor");
+ let mut worker = Worker {
+ queues,
+ mem,
+ kill_evt,
+ };
+ if let Err(e) = worker.run(&ex, interrupt) {
+ error!("failed to start a worker: {}", e);
+ }
+ worker
+ });
+
+ match worker_result {
+ Err(e) => {
+ error!("failed to spawn virtio_net worker: {}", e);
+ return;
+ }
+ Ok(join_handle) => {
+ self.worker_thread = Some(join_handle);
+ }
+ }
+ }
+
+ fn reset(&mut self) -> bool {
+ if let Err(e) = self.handler.borrow_mut().reset(self.queue_sizes.len()) {
+ error!("Failed to reset net device: {}", e);
+ false
+ } else {
+ true
+ }
+ }
+}
diff --git a/devices/src/virtio/vhost/user/worker.rs b/devices/src/virtio/vhost/user/worker.rs
new file mode 100644
index 000000000..0015a158c
--- /dev/null
+++ b/devices/src/virtio/vhost/user/worker.rs
@@ -0,0 +1,82 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use base::{error, Event};
+use cros_async::{select2, AsyncError, EventAsync, Executor, SelectResult};
+use futures::pin_mut;
+use thiserror::Error as ThisError;
+use vm_memory::GuestMemory;
+
+use crate::virtio::interrupt::SignalableInterrupt;
+use crate::virtio::{Interrupt, Queue};
+
+#[derive(ThisError, Debug)]
+enum Error {
+ /// Failed to read the resample event.
+ #[error("failed to read the resample event: {0}")]
+ ReadResampleEvent(AsyncError),
+}
+
+pub struct Worker {
+ pub queues: Vec<Queue>,
+ pub mem: GuestMemory,
+ pub kill_evt: Event,
+}
+
+impl Worker {
+ // Processes any requests to resample the irq value.
+ async fn handle_irq_resample(
+ resample_evt: EventAsync,
+ interrupt: Interrupt,
+ ) -> Result<(), Error> {
+ loop {
+ let _ = resample_evt
+ .next_val()
+ .await
+ .map_err(Error::ReadResampleEvent)?;
+ interrupt.do_interrupt_resample();
+ }
+ }
+
+ // Waits until the kill event is triggered.
+ async fn wait_kill(kill_evt: EventAsync) {
+ // Once this event is readable, exit. Exiting this future will cause the main loop to
+ // break and the device process to exit.
+ let _ = kill_evt.next_val().await;
+ }
+
+ // Runs asynchronous tasks.
+ pub fn run(&mut self, ex: &Executor, interrupt: Interrupt) -> Result<(), String> {
+ let resample_evt = interrupt
+ .get_resample_evt()
+ .expect("resample event required")
+ .try_clone()
+ .expect("failed to clone resample event");
+ let async_resample_evt =
+ EventAsync::new(resample_evt.0, ex).expect("failed to create async resample event");
+ let resample = Self::handle_irq_resample(async_resample_evt, interrupt);
+ pin_mut!(resample);
+
+ let kill_evt = EventAsync::new(
+ self.kill_evt
+ .try_clone()
+ .expect("failed to clone kill_evt")
+ .0,
+ &ex,
+ )
+ .expect("failed to create async kill event fd");
+ let kill = Self::wait_kill(kill_evt);
+ pin_mut!(kill);
+
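+ // Run until either future completes: a resample error or the kill event ends the worker.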
+ match ex.run_until(select2(resample, kill)) {
+ Ok((resample_res, _)) => {
+ if let SelectResult::Finished(Err(e)) = resample_res {
+ return Err(format!("failed to resample a irq value: {:?}", e));
+ }
+ Ok(())
+ }
+ Err(e) => Err(e.to_string()),
+ }
+ }
+}
diff --git a/devices/src/virtio/vhost/vsock.rs b/devices/src/virtio/vhost/vsock.rs
index 1885131cb..719aea48b 100644
--- a/devices/src/virtio/vhost/vsock.rs
+++ b/devices/src/virtio/vhost/vsock.rs
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use std::thread;
+use std::{path::PathBuf, thread};
use data_model::{DataInit, Le64};
@@ -31,9 +31,15 @@ pub struct Vsock {
impl Vsock {
/// Create a new virtio-vsock device with the given VM cid.
- pub fn new(base_features: u64, cid: u64, mem: &GuestMemory) -> Result<Vsock> {
+ pub fn new(
+ vhost_vsock_device_path: &PathBuf,
+ base_features: u64,
+ cid: u64,
+ mem: &GuestMemory,
+ ) -> Result<Vsock> {
let kill_evt = Event::new().map_err(Error::CreateKillEvent)?;
- let handle = VhostVsockHandle::new(mem).map_err(Error::VhostOpen)?;
+ let handle =
+ VhostVsockHandle::new(vhost_vsock_device_path, mem).map_err(Error::VhostOpen)?;
let avail_features = base_features
| 1 << virtio_sys::vhost::VIRTIO_F_NOTIFY_ON_EMPTY
diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs
index 04c933a81..9d7822ebf 100644
--- a/devices/src/virtio/vhost/worker.rs
+++ b/devices/src/virtio/vhost/worker.rs
@@ -4,14 +4,13 @@
use std::os::raw::c_ulonglong;
-use base::{error, Error as SysError, Event, PollToken, WaitContext};
+use base::{error, Error as SysError, Event, PollToken, Tube, WaitContext};
use vhost::Vhost;
-use super::control_socket::{VhostDevRequest, VhostDevResponse, VhostDevResponseSocket};
+use super::control_socket::{VhostDevRequest, VhostDevResponse};
use super::{Error, Result};
-use crate::virtio::{Interrupt, Queue};
+use crate::virtio::{Interrupt, Queue, SignalableInterrupt};
use libc::EIO;
-use msg_socket::{MsgReceiver, MsgSender};
/// Worker that takes care of running the vhost device.
pub struct Worker<T: Vhost> {
@@ -21,7 +20,7 @@ pub struct Worker<T: Vhost> {
pub vhost_interrupt: Vec<Event>,
acked_features: u64,
pub kill_evt: Event,
- pub response_socket: Option<VhostDevResponseSocket>,
+ pub response_tube: Option<Tube>,
}
impl<T: Vhost> Worker<T> {
@@ -32,7 +31,7 @@ impl<T: Vhost> Worker<T> {
interrupt: Interrupt,
acked_features: u64,
kill_evt: Event,
- response_socket: Option<VhostDevResponseSocket>,
+ response_tube: Option<Tube>,
) -> Worker<T> {
Worker {
interrupt,
@@ -41,7 +40,7 @@ impl<T: Vhost> Worker<T> {
vhost_interrupt,
acked_features,
kill_evt,
- response_socket,
+ response_tube,
}
}
@@ -106,22 +105,25 @@ impl<T: Vhost> Worker<T> {
ControlNotify,
}
- let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
- (&self.kill_evt, Token::Kill),
- ])
- .map_err(Error::CreateWaitContext)?;
+ let wait_ctx: WaitContext<Token> =
+ WaitContext::build_with(&[(&self.kill_evt, Token::Kill)])
+ .map_err(Error::CreateWaitContext)?;
for (index, vhost_int) in self.vhost_interrupt.iter().enumerate() {
wait_ctx
.add(vhost_int, Token::VhostIrqi { index })
.map_err(Error::CreateWaitContext)?;
}
- if let Some(socket) = &self.response_socket {
+ if let Some(socket) = &self.response_tube {
wait_ctx
.add(socket, Token::ControlNotify)
.map_err(Error::CreateWaitContext)?;
}
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .map_err(Error::CreateWaitContext)?;
+ }
'wait: loop {
let events = wait_ctx.wait().map_err(Error::WaitError)?;
@@ -142,7 +144,7 @@ impl<T: Vhost> Worker<T> {
break 'wait;
}
Token::ControlNotify => {
- if let Some(socket) = &self.response_socket {
+ if let Some(socket) = &self.response_tube {
match socket.recv() {
Ok(VhostDevRequest::MsixEntryChanged(index)) => {
let mut qindex = 0;
@@ -199,8 +201,8 @@ impl<T: Vhost> Worker<T> {
// No response_socket means it doesn't have any control related
// with the msix. Due to this, cannot use the direct irq fd but
// should fall back to indirect irq fd.
- if self.response_socket.is_some() {
- if let Some(msix_config) = &self.interrupt.msix_config {
+ if self.response_tube.is_some() {
+ if let Some(msix_config) = self.interrupt.get_msix_config() {
let msix_config = msix_config.lock();
let msix_masked = msix_config.masked();
if msix_masked {
@@ -228,7 +230,7 @@ impl<T: Vhost> Worker<T> {
}
fn set_vring_calls(&self) -> Result<()> {
- if let Some(msix_config) = &self.interrupt.msix_config {
+ if let Some(msix_config) = self.interrupt.get_msix_config() {
let msix_config = msix_config.lock();
if msix_config.masked() {
for (queue_index, _) in self.queues.iter().enumerate() {
diff --git a/devices/src/virtio/video/decoder/backend/mod.rs b/devices/src/virtio/video/decoder/backend/mod.rs
index a9c1411a1..0bd6b6bf0 100644
--- a/devices/src/virtio/video/decoder/backend/mod.rs
+++ b/devices/src/virtio/video/decoder/backend/mod.rs
@@ -82,6 +82,7 @@ pub trait DecoderSession {
format: Format,
output_buffer: RawDescriptor,
planes: &[FramePlane],
+ modifier: u64,
) -> VideoResult<()>;
/// Ask the device to reuse an output buffer previously passed to
diff --git a/devices/src/virtio/video/decoder/backend/vda.rs b/devices/src/virtio/video/decoder/backend/vda.rs
index 716347cd6..99ce7cbdf 100644
--- a/devices/src/virtio/video/decoder/backend/vda.rs
+++ b/devices/src/virtio/video/decoder/backend/vda.rs
@@ -155,6 +155,7 @@ impl<'a> DecoderSession for LibvdaSession<'a> {
format: Format,
output_buffer: RawDescriptor,
planes: &[FramePlane],
+ modifier: u64,
) -> VideoResult<()> {
let vda_planes: Vec<libvda::FramePlane> = planes.into_iter().map(Into::into).collect();
Ok(self.session.use_output_buffer(
@@ -162,6 +163,7 @@ impl<'a> DecoderSession for LibvdaSession<'a> {
libvda::PixelFormat::try_from(format)?,
output_buffer,
&vda_planes,
+ modifier,
)?)
}
diff --git a/devices/src/virtio/video/decoder/mod.rs b/devices/src/virtio/video/decoder/mod.rs
index 7314583e9..9545bb9cb 100644
--- a/devices/src/virtio/video/decoder/mod.rs
+++ b/devices/src/virtio/video/decoder/mod.rs
@@ -9,11 +9,9 @@ use std::collections::{BTreeMap, BTreeSet, VecDeque};
use std::convert::TryInto;
use backend::*;
-use base::{error, IntoRawDescriptor, WaitContext};
+use base::{error, IntoRawDescriptor, Tube, WaitContext};
-use crate::virtio::resource_bridge::{
- self, BufferInfo, ResourceInfo, ResourceRequest, ResourceRequestSocket,
-};
+use crate::virtio::resource_bridge::{self, BufferInfo, ResourceInfo, ResourceRequest};
use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap;
use crate::virtio::video::command::{QueueType, VideoCmd};
use crate::virtio::video::control::{CtrlType, CtrlVal, QueryCtrlType};
@@ -179,9 +177,6 @@ struct Context<S: DecoderSession> {
in_res: InputResources,
out_res: OutputResources,
- // Set the flag if we need to clear output resource when the output queue is cleared next time.
- is_clear_out_res_needed: bool,
-
// Set the flag when we ask the decoder reset, and unset when the reset is done.
is_resetting: bool,
@@ -204,7 +199,6 @@ impl<S: DecoderSession> Context<S> {
out_params: Default::default(),
in_res: Default::default(),
out_res: Default::default(),
- is_clear_out_res_needed: false,
is_resetting: false,
pending_ready_pictures: Default::default(),
session: None,
@@ -262,7 +256,7 @@ impl<S: DecoderSession> Context<S> {
fn get_resource_info(
&self,
queue_type: QueueType,
- res_bridge: &ResourceRequestSocket,
+ res_bridge: &Tube,
resource_id: u32,
) -> VideoResult<BufferInfo> {
let res_id_to_res_handle = match queue_type {
@@ -340,12 +334,6 @@ impl<S: DecoderSession> Context<S> {
// No need to set `frame_rate`, as it's only for the encoder.
..Default::default()
};
-
- // That eos_resource_id has value means there are previous output resources.
- // Clear the output resources when the output queue is cleared next time.
- if self.out_res.eos_resource_id.is_some() {
- self.is_clear_out_res_needed = true;
- }
}
fn handle_notify_end_of_bitstream_buffer(&mut self, bitstream_id: i32) -> Option<ResourceId> {
@@ -540,7 +528,7 @@ impl<'a, D: DecoderBackend> Decoder<D> {
fn queue_input_resource(
&mut self,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
stream_id: StreamId,
resource_id: ResourceId,
timestamp: u64,
@@ -604,7 +592,7 @@ impl<'a, D: DecoderBackend> Decoder<D> {
fn queue_output_resource(
&mut self,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
stream_id: StreamId,
resource_id: ResourceId,
) -> VideoResult<VideoCmdResponseType> {
@@ -661,7 +649,13 @@ impl<'a, D: DecoderBackend> Decoder<D> {
// Take ownership of this file by `into_raw_descriptor()` as this
// file will be closed by libvda.
let fd = resource_info.file.into_raw_descriptor();
- session.use_output_buffer(buffer_id as i32, Format::NV12, fd, &planes)
+ session.use_output_buffer(
+ buffer_id as i32,
+ Format::NV12,
+ fd,
+ &planes,
+ resource_info.modifier,
+ )
}
}?;
Ok(VideoCmdResponseType::Async(AsyncCmdTag::Queue {
@@ -807,11 +801,7 @@ impl<'a, D: DecoderBackend> Decoder<D> {
}))
}
QueueType::Output => {
- if std::mem::replace(&mut ctx.is_clear_out_res_needed, false) {
- ctx.out_res = Default::default();
- } else {
- ctx.out_res.queued_res_ids.clear();
- }
+ ctx.out_res.queued_res_ids.clear();
Ok(VideoCmdResponseType::Sync(CmdResponse::NoData))
}
}
@@ -823,7 +813,7 @@ impl<D: DecoderBackend> Device for Decoder<D> {
&mut self,
cmd: VideoCmd,
wait_ctx: &WaitContext<Token>,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
) -> (
VideoCmdResponseType,
Option<(u32, Vec<VideoEvtResponseType>)>,
diff --git a/devices/src/virtio/video/device.rs b/devices/src/virtio/video/device.rs
index e5f0a0f6e..d56700cba 100644
--- a/devices/src/virtio/video/device.rs
+++ b/devices/src/virtio/video/device.rs
@@ -4,9 +4,8 @@
//! Definition of the trait `Device` that each backend video device must implement.
-use base::{PollToken, WaitContext};
+use base::{PollToken, Tube, WaitContext};
-use crate::virtio::resource_bridge::ResourceRequestSocket;
use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap;
use crate::virtio::video::command::{QueueType, VideoCmd};
use crate::virtio::video::error::*;
@@ -102,7 +101,7 @@ pub trait Device {
&mut self,
cmd: VideoCmd,
wait_ctx: &WaitContext<Token>,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
) -> (
VideoCmdResponseType,
Option<(u32, Vec<VideoEvtResponseType>)>,
diff --git a/devices/src/virtio/video/encoder/mod.rs b/devices/src/virtio/video/encoder/mod.rs
index 34eec5d4b..7d32d378c 100644
--- a/devices/src/virtio/video/encoder/mod.rs
+++ b/devices/src/virtio/video/encoder/mod.rs
@@ -11,12 +11,10 @@ mod libvda_encoder;
pub use encoder::EncoderError;
pub use libvda_encoder::LibvdaEncoder;
-use base::{error, warn, WaitContext};
+use base::{error, warn, Tube, WaitContext};
use std::collections::{BTreeMap, BTreeSet};
-use crate::virtio::resource_bridge::{
- self, BufferInfo, ResourceInfo, ResourceRequest, ResourceRequestSocket,
-};
+use crate::virtio::resource_bridge::{self, BufferInfo, ResourceInfo, ResourceRequest};
use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap;
use crate::virtio::video::command::{QueueType, VideoCmd};
use crate::virtio::video::control::*;
@@ -481,6 +479,7 @@ impl<T: EncoderSession> Stream<T> {
}
}
+ #[allow(clippy::unnecessary_wraps)]
fn notify_error(&self, error: EncoderError) -> Option<Vec<VideoEvtResponseType>> {
error!(
"Received encoder error event for stream {}: {}",
@@ -499,7 +498,7 @@ pub struct EncoderDevice<T: Encoder> {
streams: BTreeMap<u32, Stream<T::Session>>,
}
-fn get_resource_info(res_bridge: &ResourceRequestSocket, uuid: u128) -> VideoResult<BufferInfo> {
+fn get_resource_info(res_bridge: &Tube, uuid: u128) -> VideoResult<BufferInfo> {
match resource_bridge::get_resource_info(
res_bridge,
ResourceRequest::GetBuffer { id: uuid as u32 },
@@ -519,6 +518,7 @@ impl<T: Encoder> EncoderDevice<T> {
})
}
+ #[allow(clippy::unnecessary_wraps)]
fn query_capabilities(&self, queue_type: QueueType) -> VideoResult<VideoCmdResponseType> {
let descs = match queue_type {
QueueType::Input => self.cros_capabilities.input_format_descs.clone(),
@@ -597,7 +597,7 @@ impl<T: Encoder> EncoderDevice<T> {
fn resource_create(
&mut self,
wait_ctx: &WaitContext<Token>,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
stream_id: u32,
queue_type: QueueType,
resource_id: u32,
@@ -672,7 +672,7 @@ impl<T: Encoder> EncoderDevice<T> {
fn resource_queue(
&mut self,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
stream_id: u32,
queue_type: QueueType,
resource_id: u32,
@@ -1200,7 +1200,7 @@ impl<T: Encoder> Device for EncoderDevice<T> {
&mut self,
req: VideoCmd,
wait_ctx: &WaitContext<Token>,
- resource_bridge: &ResourceRequestSocket,
+ resource_bridge: &Tube,
) -> (
VideoCmdResponseType,
Option<(u32, Vec<VideoEvtResponseType>)>,
diff --git a/devices/src/virtio/video/format.rs b/devices/src/virtio/video/format.rs
index a7d61cebf..e52695b93 100644
--- a/devices/src/virtio/video/format.rs
+++ b/devices/src/virtio/video/format.rs
@@ -227,7 +227,7 @@ impl Response for FormatDesc {
plane_align: Le32::from(0),
num_frames: Le32::from(self.frame_formats.len() as u32),
})?;
- self.frame_formats.iter().map(|ff| ff.write(w)).collect()
+ self.frame_formats.iter().try_for_each(|ff| ff.write(w))
}
}
diff --git a/devices/src/virtio/video/mod.rs b/devices/src/virtio/video/mod.rs
index 4a64ec1c3..b67c7adf2 100644
--- a/devices/src/virtio/video/mod.rs
+++ b/devices/src/virtio/video/mod.rs
@@ -10,11 +10,10 @@
use std::fmt::{self, Display};
use std::thread;
-use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor};
+use base::{error, AsRawDescriptor, Error as SysError, Event, RawDescriptor, Tube};
use data_model::{DataInit, Le32};
use vm_memory::GuestMemory;
-use crate::virtio::resource_bridge::ResourceRequestSocket;
use crate::virtio::virtio_device::VirtioDevice;
use crate::virtio::{self, copy_config, DescriptorError, Interrupt};
@@ -92,7 +91,7 @@ pub enum VideoDeviceType {
pub struct VideoDevice {
device_type: VideoDeviceType,
kill_evt: Option<Event>,
- resource_bridge: Option<ResourceRequestSocket>,
+ resource_bridge: Option<Tube>,
base_features: u64,
}
@@ -100,7 +99,7 @@ impl VideoDevice {
pub fn new(
base_features: u64,
device_type: VideoDeviceType,
- resource_bridge: Option<ResourceRequestSocket>,
+ resource_bridge: Option<Tube>,
) -> VideoDevice {
VideoDevice {
device_type,
diff --git a/devices/src/virtio/video/protocol.rs b/devices/src/virtio/video/protocol.rs
index fb574a4f0..ec03f8c12 100644
--- a/devices/src/virtio/video/protocol.rs
+++ b/devices/src/virtio/video/protocol.rs
@@ -4,7 +4,7 @@
//! This file was generated by the following commands and modified manually.
//!
-//! ```
+//! ```shell
//! $ bindgen virtio_video.h \
//! --whitelist-type "virtio_video.*" \
//! --whitelist-var "VIRTIO_VIDEO_.*" \
diff --git a/devices/src/virtio/video/response.rs b/devices/src/virtio/video/response.rs
index bbc71ddf6..a32d3fc1e 100644
--- a/devices/src/virtio/video/response.rs
+++ b/devices/src/virtio/video/response.rs
@@ -104,7 +104,7 @@ impl Response for CmdResponse {
num_descs: Le32::from(descs.len() as u32),
..Default::default()
})?;
- descs.iter().map(|d| d.write(w)).collect()
+ descs.iter().try_for_each(|d| d.write(w))
}
ResourceQueue {
timestamp,
diff --git a/devices/src/virtio/video/worker.rs b/devices/src/virtio/video/worker.rs
index 15da5971e..f7a3a3214 100644
--- a/devices/src/virtio/video/worker.rs
+++ b/devices/src/virtio/video/worker.rs
@@ -6,11 +6,10 @@
use std::collections::VecDeque;
-use base::{error, info, Event, WaitContext};
+use base::{error, info, Event, Tube, WaitContext};
use vm_memory::GuestMemory;
use crate::virtio::queue::{DescriptorChain, Queue};
-use crate::virtio::resource_bridge::ResourceRequestSocket;
use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap;
use crate::virtio::video::command::{QueueType, VideoCmd};
use crate::virtio::video::device::{
@@ -19,7 +18,7 @@ use crate::virtio::video::device::{
use crate::virtio::video::event::{self, EvtType, VideoEvt};
use crate::virtio::video::response::{self, Response};
use crate::virtio::video::{Error, Result};
-use crate::virtio::{Interrupt, Reader, Writer};
+use crate::virtio::{Interrupt, Reader, SignalableInterrupt, Writer};
pub struct Worker {
pub interrupt: Interrupt,
@@ -27,7 +26,7 @@ pub struct Worker {
pub cmd_evt: Event,
pub event_evt: Event,
pub kill_evt: Event,
- pub resource_bridge: ResourceRequestSocket,
+ pub resource_bridge: Tube,
}
/// Pair of a descriptor chain and a response to be written.
@@ -276,8 +275,13 @@ impl Worker {
(&self.cmd_evt, Token::CmdQueue),
(&self.event_evt, Token::EventQueue),
(&self.kill_evt, Token::Kill),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
])
+ .and_then(|wc| {
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ wc.add(resample_evt, Token::InterruptResample)?;
+ }
+ Ok(wc)
+ })
.map_err(Error::WaitContextCreationFailed)?;
// Stores descriptors in which responses for asynchronous commands will be written.
diff --git a/devices/src/virtio/virtio_pci_device.rs b/devices/src/virtio/virtio_pci_device.rs
index e1d592991..b0f63d1d1 100644
--- a/devices/src/virtio/virtio_pci_device.rs
+++ b/devices/src/virtio/virtio_pci_device.rs
@@ -6,7 +6,7 @@ use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use sync::Mutex;
-use base::{warn, AsRawDescriptor, Event, RawDescriptor, Result};
+use base::{warn, AsRawDescriptor, Event, RawDescriptor, Result, Tube};
use data_model::{DataInit, Le32};
use hypervisor::Datamatch;
use libc::ERANGE;
@@ -19,7 +19,6 @@ use crate::pci::{
PciClassCode, PciConfiguration, PciDevice, PciDeviceError, PciDisplaySubclass, PciHeaderType,
PciInterruptPin, PciSubclass,
};
-use vm_control::VmIrqRequestSocket;
use self::virtio_pci_common_config::VirtioPciCommonConfig;
@@ -193,6 +192,7 @@ const NOTIFY_OFF_MULTIPLIER: u32 = 4; // A dword per notification address.
const VIRTIO_PCI_VENDOR_ID: u16 = 0x1af4;
const VIRTIO_PCI_DEVICE_ID_BASE: u16 = 0x1040; // Add to device type to get device ID.
+const VIRTIO_PCI_REVISION_ID: u8 = 1;
/// Implements the
/// [PCI](http://docs.oasis-open.org/virtio/virtio/v1.0/cs04/virtio-v1.0-cs04.html#x1-650001)
@@ -221,7 +221,7 @@ impl VirtioPciDevice {
pub fn new(
mem: GuestMemory,
device: Box<dyn VirtioDevice>,
- msi_device_socket: VmIrqRequestSocket,
+ msi_device_tube: Tube,
) -> Result<Self> {
let mut queue_evts = Vec::new();
for _ in device.queue_max_sizes() {
@@ -250,7 +250,7 @@ impl VirtioPciDevice {
// One MSI-X vector per queue plus one for configuration changes.
let msix_num = u16::try_from(num_queues + 1).map_err(|_| base::Error::new(ERANGE))?;
- let msix_config = Arc::new(Mutex::new(MsixConfig::new(msix_num, msi_device_socket)));
+ let msix_config = Arc::new(Mutex::new(MsixConfig::new(msix_num, msi_device_tube)));
let config_regs = PciConfiguration::new(
VIRTIO_PCI_VENDOR_ID,
@@ -261,6 +261,7 @@ impl VirtioPciDevice {
PciHeaderType::Device,
VIRTIO_PCI_VENDOR_ID,
pci_device_id,
+ VIRTIO_PCI_REVISION_ID,
);
Ok(VirtioPciDevice {
@@ -389,7 +390,7 @@ impl VirtioPciDevice {
impl PciDevice for VirtioPciDevice {
fn debug_label(&self) -> String {
- format!("virtio-pci ({})", self.device.debug_label())
+ format!("pci{}", self.device.debug_label())
}
fn allocate_address(
diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs
index 8359b442d..887ca2d19 100644
--- a/devices/src/virtio/wl.rs
+++ b/devices/src/virtio/wl.rs
@@ -28,7 +28,6 @@
//! the virtio queue, and routing messages in and out of `WlState`. Possible events include the kill
//! event, available descriptors on the `in` or `out` queue, and incoming data on any vfd's socket.
-use std::cell::RefCell;
use std::collections::btree_map::Entry;
use std::collections::{BTreeMap as Map, BTreeSet as Set, VecDeque};
use std::convert::From;
@@ -49,29 +48,29 @@ use std::time::Duration;
#[cfg(feature = "minigbm")]
use libc::{EBADF, EINVAL};
-use data_model::VolatileMemoryError;
use data_model::*;
use base::{
error, pipe, round_up_to_page_size, warn, AsRawDescriptor, Error, Event, FileFlags,
FromRawDescriptor, PollToken, RawDescriptor, Result, ScmSocket, SharedMemory, SharedMemoryUnix,
- WaitContext,
+ Tube, TubeError, WaitContext,
};
#[cfg(feature = "minigbm")]
use base::{ioctl_iow_nr, ioctl_with_ref};
#[cfg(feature = "gpu")]
use base::{IntoRawDescriptor, SafeDescriptor};
-use msg_socket::{MsgError, MsgReceiver, MsgSender};
use vm_memory::{GuestMemory, GuestMemoryError};
#[cfg(feature = "minigbm")]
use vm_control::GpuMemoryDesc;
-use super::resource_bridge::*;
-use super::{DescriptorChain, Interrupt, Queue, Reader, VirtioDevice, Writer, TYPE_WL};
-use vm_control::{
- MaybeOwnedDescriptor, MemSlot, VmMemoryControlRequestSocket, VmMemoryRequest, VmMemoryResponse,
+use super::resource_bridge::{
+ get_resource_info, BufferInfo, ResourceBridgeError, ResourceInfo, ResourceRequest,
};
+use super::{
+ DescriptorChain, Interrupt, Queue, Reader, SignalableInterrupt, VirtioDevice, Writer, TYPE_WL,
+};
+use vm_control::{MemSlot, VmMemoryRequest, VmMemoryResponse};
const VIRTWL_SEND_MAX_ALLOCS: usize = 28;
const VIRTIO_WL_CMD_VFD_NEW: u32 = 256;
@@ -266,7 +265,7 @@ enum WlError {
NewPipe(Error),
SocketConnect(io::Error),
SocketNonBlock(io::Error),
- VmControl(MsgError),
+ VmControl(TubeError),
VmBadResponse,
CheckedOffset,
ParseDesc(io::Error),
@@ -333,21 +332,28 @@ impl From<VolatileMemoryError> for WlError {
#[derive(Clone)]
struct VmRequester {
- inner: Rc<RefCell<VmMemoryControlRequestSocket>>,
+ inner: Rc<Tube>,
}
impl VmRequester {
- fn new(vm_socket: VmMemoryControlRequestSocket) -> VmRequester {
+ fn new(vm_socket: Tube) -> VmRequester {
VmRequester {
- inner: Rc::new(RefCell::new(vm_socket)),
+ inner: Rc::new(vm_socket),
}
}
- fn request(&self, request: VmMemoryRequest) -> WlResult<VmMemoryResponse> {
- let mut inner = self.inner.borrow_mut();
- let vm_socket = &mut *inner;
- vm_socket.send(&request).map_err(WlError::VmControl)?;
- vm_socket.recv().map_err(WlError::VmControl)
+ fn request(&self, request: &VmMemoryRequest) -> WlResult<VmMemoryResponse> {
+ self.inner.send(&request).map_err(WlError::VmControl)?;
+ self.inner.recv().map_err(WlError::VmControl)
+ }
+
+ fn register_memory(&self, shm: SharedMemory) -> WlResult<(SharedMemory, VmMemoryResponse)> {
+ let request = VmMemoryRequest::RegisterMemory(shm);
+ let response = self.request(&request)?;
+ match request {
+ VmMemoryRequest::RegisterMemory(shm) => Ok((shm, response)),
+ _ => unreachable!(),
+ }
}
}
@@ -554,7 +560,7 @@ impl<'a> WlResp<'a> {
#[derive(Default)]
struct WlVfd {
socket: Option<UnixStream>,
- guest_shared_memory: Option<(u64 /* size */, SharedMemory)>,
+ guest_shared_memory: Option<SharedMemory>,
remote_pipe: Option<File>,
local_pipe: Option<(u32 /* flags */, File)>,
slot: Option<(MemSlot, u64 /* pfn */, VmRequester)>,
@@ -594,14 +600,16 @@ impl WlVfd {
let vfd_shm =
SharedMemory::named("virtwl_alloc", size_page_aligned).map_err(WlError::NewAlloc)?;
- let register_response = vm.request(VmMemoryRequest::RegisterMemory(
- MaybeOwnedDescriptor::Borrowed(vfd_shm.as_raw_descriptor()),
- vfd_shm.size() as usize,
- ))?;
+ let register_request = VmMemoryRequest::RegisterMemory(vfd_shm);
+ let register_response = vm.request(&register_request)?;
match register_response {
VmMemoryResponse::RegisterMemory { pfn, slot } => {
let mut vfd = WlVfd::default();
- vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm));
+ let vfd_shm = match register_request {
+ VmMemoryRequest::RegisterMemory(shm) => shm,
+ _ => unreachable!(),
+ };
+ vfd.guest_shared_memory = Some(vfd_shm);
vfd.slot = Some((slot, pfn, vm));
Ok(vfd)
}
@@ -617,22 +625,22 @@ impl WlVfd {
format: u32,
) -> WlResult<(WlVfd, GpuMemoryDesc)> {
let allocate_and_register_gpu_memory_response =
- vm.request(VmMemoryRequest::AllocateAndRegisterGpuMemory {
+ vm.request(&VmMemoryRequest::AllocateAndRegisterGpuMemory {
width,
height,
format,
})?;
match allocate_and_register_gpu_memory_response {
VmMemoryResponse::AllocateAndRegisterGpuMemory {
- descriptor: MaybeOwnedDescriptor::Owned(file),
+ descriptor,
pfn,
slot,
desc,
} => {
let mut vfd = WlVfd::default();
let vfd_shm =
- SharedMemory::from_safe_descriptor(file).map_err(WlError::NewAlloc)?;
- vfd.guest_shared_memory = Some((vfd_shm.size(), vfd_shm));
+ SharedMemory::from_safe_descriptor(descriptor).map_err(WlError::NewAlloc)?;
+ vfd.guest_shared_memory = Some(vfd_shm);
vfd.slot = Some((slot, pfn, vm));
vfd.is_dmabuf = true;
Ok((vfd, desc))
@@ -648,7 +656,7 @@ impl WlVfd {
}
match &self.guest_shared_memory {
- Some((_, descriptor)) => {
+ Some(descriptor) => {
let sync = dma_buf_sync {
flags: flags as u64,
};
@@ -686,21 +694,14 @@ impl WlVfd {
// for how big the shared memory chunk to map into guest memory is. If seeking to the end
// fails, we assume it's a socket or pipe with read/write semantics.
match descriptor.seek(SeekFrom::End(0)) {
- Ok(fd_size) => {
- let size = round_up_to_page_size(fd_size as usize) as u64;
- let register_response = vm.request(VmMemoryRequest::RegisterMemory(
- MaybeOwnedDescriptor::Borrowed(descriptor.as_raw_descriptor()),
- size as usize,
- ))?;
+ Ok(_) => {
+ let shm = SharedMemory::from_file(descriptor).map_err(WlError::FromSharedMemory)?;
+ let (shm, register_response) = vm.register_memory(shm)?;
match register_response {
VmMemoryResponse::RegisterMemory { pfn, slot } => {
let mut vfd = WlVfd::default();
- vfd.guest_shared_memory = Some((
- size,
- SharedMemory::from_file(descriptor)
- .map_err(WlError::FromSharedMemory)?,
- ));
+ vfd.guest_shared_memory = Some(shm);
vfd.slot = Some((slot, pfn, vm));
Ok(vfd)
}
@@ -748,16 +749,16 @@ impl WlVfd {
// Size in bytes of the shared memory VFD.
fn size(&self) -> Option<u64> {
- self.guest_shared_memory.as_ref().map(|&(size, _)| size)
+ self.guest_shared_memory.as_ref().map(|shm| shm.size())
}
// The FD that gets sent if this VFD is sent over a socket.
fn send_descriptor(&self) -> Option<RawDescriptor> {
self.guest_shared_memory
.as_ref()
- .map(|(_, shm)| shm.as_raw_descriptor())
- .or_else(|| self.socket.as_ref().map(|s| s.as_raw_descriptor()))
- .or_else(|| self.remote_pipe.as_ref().map(|p| p.as_raw_descriptor()))
+ .map(|shm| shm.as_raw_descriptor())
+ .or(self.socket.as_ref().map(|s| s.as_raw_descriptor()))
+ .or(self.remote_pipe.as_ref().map(|p| p.as_raw_descriptor()))
}
// The FD that is used for polling for events on this VFD.
@@ -842,7 +843,7 @@ impl WlVfd {
fn close(&mut self) -> WlResult<()> {
if let Some((slot, _, vm)) = self.slot.take() {
- vm.request(VmMemoryRequest::UnregisterMemory(slot))?;
+ vm.request(&VmMemoryRequest::UnregisterMemory(slot))?;
}
self.socket = None;
self.remote_pipe = None;
@@ -867,7 +868,7 @@ enum WlRecv {
struct WlState {
wayland_paths: Map<String, PathBuf>,
vm: VmRequester,
- resource_bridge: Option<ResourceRequestSocket>,
+ resource_bridge: Option<Tube>,
use_transition_flags: bool,
wait_ctx: WaitContext<u32>,
vfds: Map<u32, WlVfd>,
@@ -884,14 +885,14 @@ struct WlState {
impl WlState {
fn new(
wayland_paths: Map<String, PathBuf>,
- vm_socket: VmMemoryControlRequestSocket,
+ vm_tube: Tube,
use_transition_flags: bool,
use_send_vfd_v2: bool,
- resource_bridge: Option<ResourceRequestSocket>,
+ resource_bridge: Option<Tube>,
) -> WlState {
WlState {
wayland_paths,
- vm: VmRequester::new(vm_socket),
+ vm: VmRequester::new(vm_tube),
resource_bridge,
wait_ctx: WaitContext::new().expect("failed to create WaitContext"),
use_transition_flags,
@@ -1105,7 +1106,7 @@ impl WlState {
fn get_info(&mut self, request: ResourceRequest) -> Option<File> {
let sock = self.resource_bridge.as_ref().unwrap();
match get_resource_info(sock, request) {
- Ok(ResourceInfo::Buffer(BufferInfo { file, planes: _ })) => Some(file),
+ Ok(ResourceInfo::Buffer(BufferInfo { file, .. })) => Some(file),
Ok(ResourceInfo::Fence { file }) => Some(file),
Err(ResourceBridgeError::InvalidResource(req)) => {
warn!("attempt to send non-existent gpu resource {}", req);
@@ -1476,10 +1477,10 @@ impl Worker {
in_queue: Queue,
out_queue: Queue,
wayland_paths: Map<String, PathBuf>,
- vm_socket: VmMemoryControlRequestSocket,
+ vm_tube: Tube,
use_transition_flags: bool,
use_send_vfd_v2: bool,
- resource_bridge: Option<ResourceRequestSocket>,
+ resource_bridge: Option<Tube>,
) -> Worker {
Worker {
interrupt,
@@ -1488,7 +1489,7 @@ impl Worker {
out_queue,
state: WlState::new(
wayland_paths,
- vm_socket,
+ vm_tube,
use_transition_flags,
use_send_vfd_v2,
resource_bridge,
@@ -1515,7 +1516,6 @@ impl Worker {
(&out_queue_evt, Token::OutQueue),
(&kill_evt, Token::Kill),
(&self.state.wait_ctx, Token::State),
- (self.interrupt.get_resample_evt(), Token::InterruptResample),
]) {
Ok(pc) => pc,
Err(e) => {
@@ -1523,6 +1523,15 @@ impl Worker {
return;
}
};
+ if let Some(resample_evt) = self.interrupt.get_resample_evt() {
+ if wait_ctx
+ .add(resample_evt, Token::InterruptResample)
+ .is_err()
+ {
+ error!("failed adding resample event to WaitContext.");
+ return;
+ }
+ }
'wait: loop {
let mut signal_used_in = false;
@@ -1660,8 +1669,8 @@ pub struct Wl {
kill_evt: Option<Event>,
worker_thread: Option<thread::JoinHandle<()>>,
wayland_paths: Map<String, PathBuf>,
- vm_socket: Option<VmMemoryControlRequestSocket>,
- resource_bridge: Option<ResourceRequestSocket>,
+ vm_socket: Option<Tube>,
+ resource_bridge: Option<Tube>,
use_transition_flags: bool,
use_send_vfd_v2: bool,
base_features: u64,
@@ -1671,14 +1680,14 @@ impl Wl {
pub fn new(
base_features: u64,
wayland_paths: Map<String, PathBuf>,
- vm_socket: VmMemoryControlRequestSocket,
- resource_bridge: Option<ResourceRequestSocket>,
+ vm_tube: Tube,
+ resource_bridge: Option<Tube>,
) -> Result<Wl> {
Ok(Wl {
kill_evt: None,
worker_thread: None,
wayland_paths,
- vm_socket: Some(vm_socket),
+ vm_socket: Some(vm_tube),
resource_bridge,
use_transition_flags: false,
use_send_vfd_v2: false,
diff --git a/disk/src/composite.rs b/disk/src/composite.rs
index a23eca0ca..efa5e1de8 100644
--- a/disk/src/composite.rs
+++ b/disk/src/composite.rs
@@ -194,7 +194,7 @@ impl CompositeDiskFile {
}
}
- fn disk_at_offset<'a>(&'a mut self, offset: u64) -> io::Result<&'a mut ComponentDiskPart> {
+ fn disk_at_offset(&mut self, offset: u64) -> io::Result<&mut ComponentDiskPart> {
self.component_disks
.iter_mut()
.find(|disk| disk.range().contains(&offset))
@@ -342,13 +342,14 @@ impl AsRawDescriptors for CompositeDiskFile {
#[cfg(test)]
mod tests {
use super::*;
- use base::{AsRawDescriptor, SharedMemory};
+ use base::AsRawDescriptor;
use data_model::VolatileMemory;
+ use tempfile::tempfile;
#[test]
fn block_duplicate_offset_disks() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
let disk_part1 = ComponentDiskPart {
file: Box::new(file1),
offset: 0,
@@ -364,8 +365,8 @@ mod tests {
#[test]
fn get_len() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
let disk_part1 = ComponentDiskPart {
file: Box::new(file1),
offset: 0,
@@ -383,7 +384,7 @@ mod tests {
#[test]
fn single_file_passthrough() {
- let file: File = SharedMemory::new(None).unwrap().into();
+ let file = tempfile().unwrap();
let disk_part = ComponentDiskPart {
file: Box::new(file),
offset: 0,
@@ -405,9 +406,9 @@ mod tests {
#[test]
fn triple_file_fds() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
- let file3: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
+ let file3 = tempfile().unwrap();
let mut in_fds = vec![
file1.as_raw_descriptor(),
file2.as_raw_descriptor(),
@@ -437,9 +438,9 @@ mod tests {
#[test]
fn triple_file_passthrough() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
- let file3: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
+ let file3 = tempfile().unwrap();
let disk_part1 = ComponentDiskPart {
file: Box::new(file1),
offset: 0,
@@ -467,14 +468,14 @@ mod tests {
composite
.read_exact_at_volatile(output_volatile_memory.get_slice(0, 200).unwrap(), 50)
.unwrap();
- assert!(input_memory.into_iter().eq(output_memory.into_iter()));
+ assert!(input_memory.iter().eq(output_memory.iter()));
}
#[test]
fn triple_file_punch_hole() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
- let file3: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
+ let file3 = tempfile().unwrap();
let disk_part1 = ComponentDiskPart {
file: Box::new(file1),
offset: 0,
@@ -507,14 +508,14 @@ mod tests {
for i in 50..250 {
input_memory[i] = 0;
}
- assert!(input_memory.into_iter().eq(output_memory.into_iter()));
+ assert!(input_memory.iter().eq(output_memory.iter()));
}
#[test]
fn triple_file_write_zeroes() {
- let file1: File = SharedMemory::new(None).unwrap().into();
- let file2: File = SharedMemory::new(None).unwrap().into();
- let file3: File = SharedMemory::new(None).unwrap().into();
+ let file1 = tempfile().unwrap();
+ let file2 = tempfile().unwrap();
+ let file3 = tempfile().unwrap();
let disk_part1 = ComponentDiskPart {
file: Box::new(file1),
offset: 0,
@@ -558,6 +559,6 @@ mod tests {
i, input_memory[i], output_memory[i]
);
}
- assert!(input_memory.into_iter().eq(output_memory.into_iter()));
+ assert!(input_memory.iter().eq(output_memory.iter()));
}
}
diff --git a/disk/src/disk.rs b/disk/src/disk.rs
index 9f934135a..bd1945d30 100644
--- a/disk/src/disk.rs
+++ b/disk/src/disk.rs
@@ -264,36 +264,45 @@ pub fn convert(src_file: File, dst_file: File, dst_type: ImageType) -> Result<()
}
}
-/// Detect the type of an image file by checking for a valid qcow2 header.
+/// Detect the type of an image file by checking for a valid header of the supported formats.
pub fn detect_image_type(file: &File) -> Result<ImageType> {
let mut f = file;
+ let disk_size = f.get_len().map_err(Error::SeekingFile)?;
let orig_seek = f.seek(SeekFrom::Current(0)).map_err(Error::SeekingFile)?;
f.seek(SeekFrom::Start(0)).map_err(Error::SeekingFile)?;
- let mut magic = [0u8; 4];
- f.read_exact(&mut magic).map_err(Error::ReadingHeader)?;
- let magic = u32::from_be_bytes(magic);
+
+ // Try to read the disk in a nicely-aligned block size unless the whole file is smaller.
+ const MAGIC_BLOCK_SIZE: usize = 4096;
+ let mut magic = [0u8; MAGIC_BLOCK_SIZE];
+ let magic_read_len = if disk_size > MAGIC_BLOCK_SIZE as u64 {
+ MAGIC_BLOCK_SIZE
+ } else {
+ // This cast is safe since we know disk_size is less than MAGIC_BLOCK_SIZE (4096) and
+ // therefore is representable in usize.
+ disk_size as usize
+ };
+
+ f.read_exact(&mut magic[0..magic_read_len])
+ .map_err(Error::ReadingHeader)?;
+ f.seek(SeekFrom::Start(orig_seek))
+ .map_err(Error::SeekingFile)?;
+
#[cfg(feature = "composite-disk")]
- {
- f.seek(SeekFrom::Start(0)).map_err(Error::SeekingFile)?;
- let mut cdisk_magic = [0u8; CDISK_MAGIC_LEN];
- f.read_exact(&mut cdisk_magic[..])
- .map_err(Error::ReadingHeader)?;
+ if let Some(cdisk_magic) = magic.get(0..CDISK_MAGIC_LEN) {
if cdisk_magic == CDISK_MAGIC.as_bytes() {
- f.seek(SeekFrom::Start(orig_seek))
- .map_err(Error::SeekingFile)?;
return Ok(ImageType::CompositeDisk);
}
}
- let image_type = if magic == QCOW_MAGIC {
- ImageType::Qcow2
- } else if magic == SPARSE_HEADER_MAGIC.to_be() {
- ImageType::AndroidSparse
- } else {
- ImageType::Raw
- };
- f.seek(SeekFrom::Start(orig_seek))
- .map_err(Error::SeekingFile)?;
- Ok(image_type)
+
+ if let Some(magic4) = magic.get(0..4) {
+ if magic4 == QCOW_MAGIC.to_be_bytes() {
+ return Ok(ImageType::Qcow2);
+ } else if magic4 == SPARSE_HEADER_MAGIC.to_le_bytes() {
+ return Ok(ImageType::AndroidSparse);
+ }
+ }
+
+ Ok(ImageType::Raw)
}
/// Check if the image file type can be used for async disk access.
@@ -530,4 +539,62 @@ mod tests {
let ex = Executor::new().unwrap();
ex.run_until(write_zeros_async(&ex)).unwrap();
}
+
+ #[test]
+ fn detect_image_type_raw() {
+ let mut t = tempfile::tempfile().unwrap();
+ // Fill the first block of the file with "random" data.
+ let buf = "ABCD".as_bytes().repeat(1024);
+ t.write_all(&buf).unwrap();
+ let image_type = detect_image_type(&t).expect("failed to detect image type");
+ assert_eq!(image_type, ImageType::Raw);
+ }
+
+ #[test]
+ fn detect_image_type_qcow2() {
+ let mut t = tempfile::tempfile().unwrap();
+ // Write the qcow2 magic signature. The rest of the header is not filled in, so if
+ // detect_image_type is ever updated to validate more of the header, this test would need
+ // to be updated.
+ let buf: &[u8] = &[0x51, 0x46, 0x49, 0xfb];
+ t.write_all(&buf).unwrap();
+ let image_type = detect_image_type(&t).expect("failed to detect image type");
+ assert_eq!(image_type, ImageType::Qcow2);
+ }
+
+ #[test]
+ fn detect_image_type_android_sparse() {
+ let mut t = tempfile::tempfile().unwrap();
+ // Write the Android sparse magic signature. The rest of the header is not filled in, so if
+ // detect_image_type is ever updated to validate more of the header, this test would need
+ // to be updated.
+ let buf: &[u8] = &[0x3a, 0xff, 0x26, 0xed];
+ t.write_all(&buf).unwrap();
+ let image_type = detect_image_type(&t).expect("failed to detect image type");
+ assert_eq!(image_type, ImageType::AndroidSparse);
+ }
+
+ #[test]
+ #[cfg(feature = "composite-disk")]
+ fn detect_image_type_composite() {
+ let mut t = tempfile::tempfile().unwrap();
+ // Write the composite disk magic signature. The rest of the header is not filled in, so if
+ // detect_image_type is ever updated to validate more of the header, this test would need
+ // to be updated.
+ let buf = "composite_disk\x1d".as_bytes();
+ t.write_all(&buf).unwrap();
+ let image_type = detect_image_type(&t).expect("failed to detect image type");
+ assert_eq!(image_type, ImageType::CompositeDisk);
+ }
+
+ #[test]
+ fn detect_image_type_small_file() {
+ let mut t = tempfile::tempfile().unwrap();
+ // Write a file smaller than the four-byte qcow2/sparse magic to ensure the small file logic
+ // works correctly and handles it as a raw file.
+ let buf: &[u8] = &[0xAA, 0xBB];
+ t.write_all(&buf).unwrap();
+ let image_type = detect_image_type(&t).expect("failed to detect image type");
+ assert_eq!(image_type, ImageType::Raw);
+ }
}
diff --git a/disk/src/qcow/mod.rs b/disk/src/qcow/mod.rs
index a4a52349c..3a3ef9764 100644
--- a/disk/src/qcow/mod.rs
+++ b/disk/src/qcow/mod.rs
@@ -433,7 +433,7 @@ impl QcowFile {
}
let cluster_bits: u32 = header.cluster_bits;
- if cluster_bits < MIN_CLUSTER_BITS || cluster_bits > MAX_CLUSTER_BITS {
+ if !(MIN_CLUSTER_BITS..=MAX_CLUSTER_BITS).contains(&cluster_bits) {
return Err(Error::InvalidClusterSize);
}
let cluster_size = 0x01u64 << cluster_bits;
diff --git a/docs/architecture.md b/docs/architecture.md
index 5ff0567b2..064826c68 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -75,8 +75,8 @@ Most threads in crosvm will have a wait loop using a `PollContext`, which is a w
Note that the limitations of `PollContext` are the same as the limitations of `epoll`. The same FD can not be inserted more than once, and the FD will be automatically removed if the process runs out of references to that FD. A `dup`/`fork` call will increment that reference count, so closing the original FD will not actually remove it from the `PollContext`. It is possible to receive tokens from `PollContext` for an FD that was closed because of a race condition in which an event was registered in the background before the `close` happened. Best practice is to remove an FD before closing it so that events associated with it can be reliably eliminated.
-### MsgSocket
+### `serde` with Descriptors
-Using raw sockets and pipes to communicate is very inconvenient for rich data types. To help make this easier and less error prone, crosvm has the `msg_socket` crate. Included is a trait for messages encodable on a Unix socket (`MsgOnSocket`), a set of traits for sending and receiving (`MsgSender`/`MsgReceiver`), and implementations of those traits over `UnixSeqpacket` (`MsgSocket`/`Sender`/`Receiver`). To make implementing `MsgOnSocket` very easy, a custom derive for that trait can be utilized with `#[derive(MsgOnSocket)]`. The custom derive will work for enums and structs with nested data, primitive types, and anything that implements `AsRawFd`. However, structures with no fixed upper limit in size, such as `Vec` or `BTreeMap`, are not supported.
+Using raw sockets and pipes to communicate is very inconvenient for rich data types. To help make this easier and less error prone, crosvm uses the `serde` crate. To allow transmitting types with embedded descriptors (FDs on Linux or HANDLEs on Windows), a module is provided for sending and receiving descriptors alongside the plain old bytes that serde consumes.
[minijail]: https://android.googlesource.com/platform/external/minijail
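
As a quick illustration of the pattern the new architecture paragraph describes, here is a minimal sketch of sending a serde-derived message over a `Tube`. It is a sketch only: `Tube::pair()` and the exact `send()`/`recv()` signatures are assumptions inferred from the `send(&request)`/`recv()` calls visible elsewhere in this patch, not a definitive API listing.

```rust
// Minimal sketch, assuming Tube::pair() exists and that send()/recv()
// accept and return serde-serializable values, as the calls in this patch suggest.
use base::Tube;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
enum ExampleRequest {
    Ping { payload: u32 },
}

fn tube_round_trip() {
    let (client, server) = Tube::pair().expect("failed to create tube pair");
    client
        .send(&ExampleRequest::Ping { payload: 42 })
        .expect("send failed");
    let req: ExampleRequest = server.recv().expect("recv failed");
    assert_eq!(req, ExampleRequest::Ping { payload: 42 });
}
```
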
diff --git a/fuse/src/server.rs b/fuse/src/server.rs
index 0c62896b6..f28e9488a 100644
--- a/fuse/src/server.rs
+++ b/fuse/src/server.rs
@@ -1298,6 +1298,7 @@ impl<F: FileSystem + Sync> Server<F> {
}
}
+ #[allow(clippy::unnecessary_wraps)]
fn interrupt(&self, _in_header: InHeader) -> Result<usize> {
Ok(0)
}
@@ -1310,6 +1311,7 @@ impl<F: FileSystem + Sync> Server<F> {
}
}
+ #[allow(clippy::unnecessary_wraps)]
fn destroy(&self) -> Result<usize> {
// No reply to this function.
self.fs.destroy();
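
Several hunks in this patch add `#[allow(clippy::unnecessary_wraps)]`. That lint fires when a function returns `Result` (or `Option`) but can never actually fail; the attribute is used here to keep the `Ok`-only handlers signature-compatible with their fallible siblings. A self-contained, hypothetical illustration follows — the function names below are not from crosvm.

```rust
// Hypothetical illustration of the clippy::unnecessary_wraps lint that the
// allow attributes in this patch silence. handle_noop() never returns Err,
// which is what the lint flags; the Result wrapper is kept so its signature
// matches fallible handlers such as handle_read().
use std::io::{Error, ErrorKind};

type Result<T> = std::result::Result<T, Error>;

#[allow(clippy::unnecessary_wraps)]
fn handle_noop() -> Result<usize> {
    Ok(0)
}

fn handle_read(buf: &[u8]) -> Result<usize> {
    if buf.is_empty() {
        Err(Error::new(ErrorKind::UnexpectedEof, "empty buffer"))
    } else {
        Ok(buf.len())
    }
}

fn dispatch(op: u32, buf: &[u8]) -> Result<usize> {
    // A uniform Result-returning signature lets callers treat every handler
    // the same way, which is why the wrapper is worth keeping.
    match op {
        0 => handle_noop(),
        _ => handle_read(buf),
    }
}
```
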
diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml
index ab2e06558..d4c7a1e26 100644
--- a/fuzz/Cargo.toml
+++ b/fuzz/Cargo.toml
@@ -18,10 +18,6 @@ tempfile = { path = "../tempfile" }
usb_util = { path = "../usb_util" }
vm_memory = { path = "../vm_memory" }
-# Prevent this from interfering with workspaces
-[workspace]
-members = ["."]
-
[[bin]]
name = "crosvm_block_fuzzer"
path = "block_fuzzer.rs"
@@ -46,5 +42,3 @@ path = "virtqueue_fuzzer.rs"
name = "crosvm_zimage_fuzzer"
path = "zimage_fuzzer.rs"
-[patch.crates-io]
-base = { path = "../base" }
diff --git a/gpu_display/build.rs b/gpu_display/build.rs
index 8f134f942..5cdba2f7e 100644
--- a/gpu_display/build.rs
+++ b/gpu_display/build.rs
@@ -62,13 +62,13 @@ fn compile_protocol<P: AsRef<Path>>(name: &str, out: P) -> PathBuf {
.arg(&in_protocol)
.arg(&out_code)
.output()
- .unwrap();
+ .expect("wayland-scanner code failed");
Command::new("wayland-scanner")
.arg("client-header")
.arg(&in_protocol)
.arg(&out_header)
.output()
- .unwrap();
+ .expect("wayland-scanner client-header failed");
out_code
}
diff --git a/gpu_display/src/gpu_display_stub.rs b/gpu_display/src/gpu_display_stub.rs
index 63cc81104..11b96c59c 100644
--- a/gpu_display/src/gpu_display_stub.rs
+++ b/gpu_display/src/gpu_display_stub.rs
@@ -45,12 +45,12 @@ struct Surface {
}
impl Surface {
- fn create(width: u32, height: u32) -> Result<Surface, GpuDisplayError> {
- Ok(Surface {
+ fn create(width: u32, height: u32) -> Surface {
+ Surface {
width,
height,
buffer: None,
- })
+ }
}
/// Gets the buffer at buffer_index, allocating it if necessary.
@@ -103,14 +103,14 @@ impl SurfacesHelper {
}
}
- fn create_surface(&mut self, width: u32, height: u32) -> Result<u32, GpuDisplayError> {
- let new_surface = Surface::create(width, height)?;
+ fn create_surface(&mut self, width: u32, height: u32) -> u32 {
+ let new_surface = Surface::create(width, height);
let new_surface_id = self.next_surface_id;
self.surfaces.insert(new_surface_id, new_surface);
self.next_surface_id = SurfaceId::new(self.next_surface_id.get() + 1).unwrap();
- Ok(new_surface_id.get())
+ new_surface_id.get()
}
fn get_surface(&mut self, surface_id: u32) -> Option<&mut Surface> {
@@ -157,7 +157,7 @@ impl DisplayT for DisplayStub {
if parent_surface_id.is_some() {
return Err(GpuDisplayError::Unsupported);
}
- self.surfaces.create_surface(width, height)
+ Ok(self.surfaces.create_surface(width, height))
}
fn release_surface(&mut self, surface_id: u32) {
diff --git a/gpu_display/src/gpu_display_wl.rs b/gpu_display/src/gpu_display_wl.rs
index d6365ac99..d242bbbd1 100644
--- a/gpu_display/src/gpu_display_wl.rs
+++ b/gpu_display/src/gpu_display_wl.rs
@@ -209,7 +209,7 @@ impl DisplayT for DisplayWl {
let buffer_shm = SharedMemory::named("GpuDisplaySurface", buffer_size as u64)
.map_err(GpuDisplayError::CreateShm)?;
let buffer_mem = MemoryMappingBuilder::new(buffer_size)
- .from_descriptor(&buffer_shm)
+ .from_shared_memory(&buffer_shm)
.build()
.unwrap();
diff --git a/gpu_display/src/gpu_display_x.rs b/gpu_display/src/gpu_display_x.rs
index 5d6a9b609..16bdd1096 100644
--- a/gpu_display/src/gpu_display_x.rs
+++ b/gpu_display/src/gpu_display_x.rs
@@ -243,7 +243,7 @@ impl Surface {
visual: *mut xlib::Visual,
width: u32,
height: u32,
- ) -> Result<Surface, GpuDisplayError> {
+ ) -> Surface {
let keycode_translator = KeycodeTranslator::new(KeycodeTypes::XkbScancode);
unsafe {
let depth = xlib::XDefaultDepthOfScreen(screen.as_ptr()) as u32;
@@ -305,7 +305,7 @@ impl Surface {
// Flush everything so that the window is visible immediately.
display.flush();
- Ok(Surface {
+ Surface {
display,
visual,
depth,
@@ -320,7 +320,7 @@ impl Surface {
buffer_completion_type,
delete_window_atom,
close_requested: false,
- })
+ }
}
}
@@ -367,8 +367,8 @@ impl Surface {
// The touch event *must* be first per the Linux input subsystem's guidance.
let events = &[
virtio_input_event::touch(pressed),
- virtio_input_event::absolute_x(max(0, button_event.x) as u32),
- virtio_input_event::absolute_y(max(0, button_event.y) as u32),
+ virtio_input_event::absolute_x(max(0, button_event.x)),
+ virtio_input_event::absolute_y(max(0, button_event.y)),
];
self.dispatch_to_event_devices(events, EventDeviceKind::Touchscreen);
}
@@ -377,8 +377,8 @@ impl Surface {
if motion.state & xlib::Button1Mask != 0 {
let events = &[
virtio_input_event::touch(true),
- virtio_input_event::absolute_x(max(0, motion.x) as u32),
- virtio_input_event::absolute_y(max(0, motion.y) as u32),
+ virtio_input_event::absolute_x(max(0, motion.x)),
+ virtio_input_event::absolute_y(max(0, motion.y)),
];
self.dispatch_to_event_devices(events, EventDeviceKind::Touchscreen);
}
@@ -730,7 +730,7 @@ impl DisplayT for DisplayX {
self.visual,
width,
height,
- )?;
+ );
let new_surface_id = self.next_id;
self.surfaces.insert(new_surface_id, new_surface);
self.next_id = ObjectId::new(self.next_id.get() + 1).unwrap();
diff --git a/gpu_display/src/lib.rs b/gpu_display/src/lib.rs
index 0ffae8792..1f8ea0f13 100644
--- a/gpu_display/src/lib.rs
+++ b/gpu_display/src/lib.rs
@@ -15,6 +15,7 @@ mod gpu_display_stub;
mod gpu_display_wl;
#[cfg(feature = "x")]
mod gpu_display_x;
+#[cfg(feature = "x")]
mod keycode_converter;
pub use event_device::{EventDevice, EventDeviceKind};
diff --git a/hypervisor/Cargo.toml b/hypervisor/Cargo.toml
index aef0ce4a8..70c0ec41c 100644
--- a/hypervisor/Cargo.toml
+++ b/hypervisor/Cargo.toml
@@ -12,7 +12,7 @@ enumn = { path = "../enumn" }
kvm = { path = "../kvm" }
kvm_sys = { path = "../kvm_sys" }
libc = "*"
-msg_socket = { path = "../msg_socket" }
+serde = { version = "1", features = [ "derive" ] }
sync = { path = "../sync" }
base = { path = "../base" }
vm_memory = { path = "../vm_memory" }
diff --git a/hypervisor/src/kvm/mod.rs b/hypervisor/src/kvm/mod.rs
index 24955c287..78708beca 100644
--- a/hypervisor/src/kvm/mod.rs
+++ b/hypervisor/src/kvm/mod.rs
@@ -16,9 +16,11 @@ use std::cell::RefCell;
use std::cmp::{min, Reverse};
use std::collections::{BTreeMap, BinaryHeap};
use std::convert::TryFrom;
+use std::ffi::CString;
use std::mem::{size_of, ManuallyDrop};
-use std::os::raw::{c_char, c_int, c_ulong, c_void};
-use std::os::unix::io::AsRawFd;
+use std::os::raw::{c_int, c_ulong, c_void};
+use std::os::unix::{io::AsRawFd, prelude::OsStrExt};
+use std::path::{Path, PathBuf};
use std::ptr::copy_nonoverlapping;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
@@ -30,8 +32,8 @@ use libc::{
use base::{
block_signal, errno_result, error, ioctl, ioctl_with_mut_ref, ioctl_with_ref, ioctl_with_val,
pagesize, signal, unblock_signal, AsRawDescriptor, Error, Event, FromRawDescriptor,
- MappedRegion, MemoryMapping, MemoryMappingBuilder, MmapError, Protection, RawDescriptor,
- Result, SafeDescriptor,
+ MappedRegion, MemoryMapping, MemoryMappingBuilder, MemoryMappingBuilderUnix, MmapError,
+ Protection, RawDescriptor, Result, SafeDescriptor,
};
use data_model::vec_with_array_field;
use kvm_sys::*;
@@ -94,11 +96,10 @@ pub struct Kvm {
type KvmCap = kvm::Cap;
impl Kvm {
- /// Opens `/dev/kvm/` and returns a Kvm object on success.
- pub fn new() -> Result<Kvm> {
- // Open calls are safe because we give a constant nul-terminated string and verify the
- // result.
- let ret = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, O_RDWR | O_CLOEXEC) };
+ pub fn new_with_path(device_path: &Path) -> Result<Kvm> {
+ // Open calls are safe because we give a nul-terminated string and verify the result.
+ let c_path = CString::new(device_path.as_os_str().as_bytes()).unwrap();
+ let ret = unsafe { open(c_path.as_ptr(), O_RDWR | O_CLOEXEC) };
if ret < 0 {
return errno_result();
}
@@ -108,6 +109,11 @@ impl Kvm {
})
}
+ /// Opens `/dev/kvm` and returns a Kvm object on success.
+ pub fn new() -> Result<Kvm> {
+ Kvm::new_with_path(&PathBuf::from("/dev/kvm"))
+ }
+
/// Gets the size of the mmap required to use vcpu's `kvm_run` structure.
pub fn get_vcpu_mmap_size(&self) -> Result<usize> {
// Safe because we know that our file is a KVM fd and we verify the return result.
@@ -166,7 +172,7 @@ impl KvmVm {
}
// Safe because we verify that ret is valid and we own the fd.
let vm_descriptor = unsafe { SafeDescriptor::from_raw_descriptor(ret) };
- guest_mem.with_regions(|index, guest_addr, size, host_addr, _| {
+ guest_mem.with_regions(|index, guest_addr, size, host_addr, _, _| {
unsafe {
// Safe because the guest regions are guaranteed not to overlap.
set_user_memory_region(
@@ -448,7 +454,11 @@ impl Vm for KvmVm {
read_only: bool,
log_dirty_pages: bool,
) -> Result<MemSlot> {
- let size = mem.size() as u64;
+ let pgsz = pagesize() as u64;
+ // KVM requires the user memory region size to be page-size aligned. It is safe to
+ // round mem.size() up to page alignment because mmap rounds the size up in the same
+ // way when it is not already aligned.
+ let size = (mem.size() as u64 + pgsz - 1) / pgsz * pgsz;
let end_addr = guest_addr
.checked_add(size)
.ok_or_else(|| Error::new(EOVERFLOW))?;
@@ -691,11 +701,15 @@ impl Vcpu for KvmVcpu {
// AcqRel ordering is sufficient to ensure only one thread gets to set its fingerprint to
// this Vcpu and subsequent `run` calls will see the fingerprint.
- if self.vcpu_run_handle_fingerprint.compare_and_swap(
- 0,
- vcpu_run_handle.fingerprint().as_u64(),
- std::sync::atomic::Ordering::AcqRel,
- ) != 0
+ if self
+ .vcpu_run_handle_fingerprint
+ .compare_exchange(
+ 0,
+ vcpu_run_handle.fingerprint().as_u64(),
+ std::sync::atomic::Ordering::AcqRel,
+ std::sync::atomic::Ordering::Acquire,
+ )
+ .is_err()
{
return Err(Error::new(EBUSY));
}
@@ -867,7 +881,7 @@ impl Vcpu for KvmVcpu {
// The pointer is page aligned so casting to a different type is well defined, hence the clippy
// allow attribute.
fn run(&self, run_handle: &VcpuRunHandle) -> Result<VcpuExit> {
- // Acquire is used to ensure this check is ordered after the `compare_and_swap` in `run`.
+ // Acquire is used to ensure this check is ordered after the `compare_exchange` in `run`.
if self
.vcpu_run_handle_fingerprint
.load(std::sync::atomic::Ordering::Acquire)
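
The hunks above replace the deprecated `compare_and_swap` with `compare_exchange` while preserving the intended semantics: only the thread that observes the fingerprint still at 0 gets to install its value. A standalone sketch of that pattern, using only the standard library, is shown below.

```rust
// Standalone sketch of the compare_exchange pattern used above: the first
// caller that observes 0 installs its value; later callers fail and the
// stored value is left untouched.
use std::sync::atomic::{AtomicU64, Ordering};

fn try_claim(fingerprint: &AtomicU64, new_value: u64) -> bool {
    fingerprint
        .compare_exchange(0, new_value, Ordering::AcqRel, Ordering::Acquire)
        .is_ok()
}

fn main() {
    let fingerprint = AtomicU64::new(0);
    assert!(try_claim(&fingerprint, 7)); // first claim wins
    assert!(!try_claim(&fingerprint, 9)); // second claim is rejected
    assert_eq!(fingerprint.load(Ordering::Acquire), 7);
}
```
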
diff --git a/hypervisor/src/kvm/x86_64.rs b/hypervisor/src/kvm/x86_64.rs
index b43a02d78..6cfd394a6 100644
--- a/hypervisor/src/kvm/x86_64.rs
+++ b/hypervisor/src/kvm/x86_64.rs
@@ -20,7 +20,6 @@ use crate::{
ClockState, CpuId, CpuIdEntry, DebugRegs, DescriptorTable, DeviceKind, Fpu, HypervisorX86_64,
IoapicRedirectionTableEntry, IoapicState, IrqSourceChip, LapicState, PicSelect, PicState,
PitChannelState, PitState, Register, Regs, Segment, Sregs, VcpuX86_64, VmCap, VmX86_64,
- NUM_IOAPIC_PINS,
};
type KvmCpuId = kvm::CpuId;
@@ -280,12 +279,12 @@ impl KvmVm {
}
/// Enable support for split-irqchip.
- pub fn enable_split_irqchip(&self) -> Result<()> {
+ pub fn enable_split_irqchip(&self, ioapic_pins: usize) -> Result<()> {
let mut cap = kvm_enable_cap {
cap: KVM_CAP_SPLIT_IRQCHIP,
..Default::default()
};
- cap.args[0] = NUM_IOAPIC_PINS as u64;
+ cap.args[0] = ioapic_pins as u64;
// safe because we allocated the struct and we know the kernel will read
// exactly the size of the struct
let ret = unsafe { ioctl_with_ref(self, KVM_ENABLE_CAP(), &cap) };
diff --git a/hypervisor/src/lib.rs b/hypervisor/src/lib.rs
index cb534114d..ac3642b79 100644
--- a/hypervisor/src/lib.rs
+++ b/hypervisor/src/lib.rs
@@ -13,8 +13,9 @@ pub mod x86_64;
use std::os::raw::c_int;
use std::os::unix::io::AsRawFd;
-use base::{Event, MappedRegion, Protection, RawDescriptor, Result, SafeDescriptor};
-use msg_socket::MsgOnSocket;
+use serde::{Deserialize, Serialize};
+
+use base::{Event, MappedRegion, Protection, Result, SafeDescriptor};
use vm_memory::{GuestAddress, GuestMemory};
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
@@ -281,7 +282,7 @@ pub trait Vcpu: downcast_rs::DowncastSync {
downcast_rs::impl_downcast!(sync Vcpu);
/// An address either in programmable I/O space or in memory mapped I/O space.
-#[derive(Copy, Clone, Debug, MsgOnSocket, PartialEq, Eq, std::hash::Hash)]
+#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq, std::hash::Hash)]
pub enum IoEventAddress {
Pio(u64),
Mmio(u64),
diff --git a/hypervisor/src/x86_64.rs b/hypervisor/src/x86_64.rs
index ae4cb4433..b880689b2 100644
--- a/hypervisor/src/x86_64.rs
+++ b/hypervisor/src/x86_64.rs
@@ -2,10 +2,12 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use base::{error, RawDescriptor, Result};
+use serde::{Deserialize, Serialize};
+
+use base::{error, Result};
use bit_field::*;
use downcast_rs::impl_downcast;
-use msg_socket::MsgOnSocket;
+
use vm_memory::GuestAddress;
use crate::{Hypervisor, IrqRoute, IrqSource, IrqSourceChip, Vcpu, Vm};
@@ -566,7 +568,7 @@ pub struct DebugRegs {
}
/// State of one VCPU register. Currently used for MSRs and XCRs.
-#[derive(Debug, Default, Copy, Clone, MsgOnSocket)]
+#[derive(Debug, Default, Copy, Clone, Serialize, Deserialize)]
pub struct Register {
pub id: u32,
pub value: u64,
diff --git a/integration_tests/guest_under_test/Dockerfile b/integration_tests/guest_under_test/Dockerfile
index aa0523df5..d0813b257 100644
--- a/integration_tests/guest_under_test/Dockerfile
+++ b/integration_tests/guest_under_test/Dockerfile
@@ -1,8 +1,8 @@
# Copyright 2020 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
-
-FROM amd64/alpine:3.12
+ARG ARCH
+FROM ${ARCH}/alpine:3.12
RUN apk add --no-cache pciutils
diff --git a/integration_tests/guest_under_test/Makefile b/integration_tests/guest_under_test/Makefile
index fa34edb4f..50ca767bd 100644
--- a/integration_tests/guest_under_test/Makefile
+++ b/integration_tests/guest_under_test/Makefile
@@ -8,24 +8,40 @@
# target/guest_under_test/bzImage
# target/guest_under_test/rootfs
+ARCH ?= $(shell arch)
+ifeq ($(ARCH), x86_64)
+ KERNEL_ARCH=x86
+ KERNEL_CONFIG=arch/x86/configs/chromiumos-container-vm-x86_64_defconfig
+ KERNEL_BINARY=bzImage
+ DOCKER_ARCH=amd64
+ CROSS_COMPILE=
+ RUSTFLAGS=
+else ifeq ($(ARCH), aarch64)
+ KERNEL_ARCH=arm64
+ KERNEL_CONFIG=arch/arm64/configs/chromiumos-container-vm-arm64_defconfig
+ KERNEL_BINARY=Image
+ DOCKER_ARCH=arm64v8
+ CROSS_COMPILE=aarch64-linux-gnu-
+ RUSTFLAGS="-Clinker=aarch64-linux-gnu-ld"
+else
+ $(error Only x86_64 or aarch64 are supported)
+endif
+
+# Build against the musl toolchain, which will produce a statically linked,
+# portable binary that we can run on the alpine linux guest without needing
+# libc at runtime
+RUST_TARGET ?= $(ARCH)-unknown-linux-musl
+
# We are building everything in target/guest_under_test
CARGO_TARGET ?= $(shell cargo metadata --no-deps --format-version 1 | \
jq -r ".target_directory")
-TARGET ?= $(CARGO_TARGET)/guest_under_test
+TARGET ?= $(CARGO_TARGET)/guest_under_test/$(ARCH)
$(shell mkdir -p $(TARGET))
-# Currently only x86_64 is tested and supported.
-ARCH = $(shell arch)
-
# Parameters for building the kernel locally
KERNEL_REPO ?= https://chromium.googlesource.com/chromiumos/third_party/kernel
KERNEL_BRANCH ?= chromeos-4.19
-# Build against the musl toolchain, which will produce a statically linked,
-# portable binary that we can run on the alpine linux guest without needing
-# libc at runtime
-RUST_TARGET ?= $(ARCH)-unknown-linux-musl
-
################################################################################
# Main targets
@@ -42,8 +58,8 @@ dockerfile := $(shell pwd)/Dockerfile
# Build rootfs from Dockerfile and export into squashfs
$(TARGET)/rootfs: $(TARGET)/rootfs-build/delegate
# Build image from Dockerfile
- cd $(TARGET)/rootfs-build && docker build -t crosvm_integration_test . \
- -f $(dockerfile)
+ docker build -t crosvm_integration_test $(TARGET)/rootfs-build \
+ -f $(dockerfile) --build-arg ARCH=$(DOCKER_ARCH)
# Create container and export into squashfs, and don't forget to clean up
# the container afterwards.
@@ -55,21 +71,19 @@ $(TARGET)/rootfs: $(TARGET)/rootfs-build/delegate
# Build and copy delegate binary into rootfs build directory
$(TARGET)/rootfs-build/delegate: delegate.rs
rustup target add $(RUST_TARGET)
- rustc --edition=2018 delegate.rs --out-dir $(@D) --target $(RUST_TARGET)
+ rustc --edition=2018 delegate.rs --out-dir $(@D) \
+ $(RUSTFLAGS) --target $(RUST_TARGET)
################################################################################
# Build kernel
-ifeq ($(ARCH), x86_64)
- KERNEL_CONFIG ?= arch/x86/configs/chromiumos-container-vm-x86_64_defconfig
-else
- $(error Only x86_64 is supported)
-endif
-
$(TARGET)/bzImage: $(TARGET)/kernel-source $(TARGET)/kernel-build
cd $(TARGET)/kernel-source && \
- yes "" | make O=$(TARGET)/kernel-build -j$(shell nproc) bzImage
- cp $(TARGET)/kernel-build/arch/x86/boot/bzImage $@
+ yes "" | make \
+ O=$(TARGET)/kernel-build \
+ ARCH=$(KERNEL_ARCH) CROSS_COMPILE=$(CROSS_COMPILE) \
+ -j$(shell nproc) $(KERNEL_BINARY)
+ cp $(TARGET)/kernel-build/arch/${KERNEL_ARCH}/boot/$(KERNEL_BINARY) $@
$(TARGET)/kernel-build: $(TARGET)/kernel-source
mkdir -p $@
diff --git a/integration_tests/guest_under_test/PREBUILT_VERSION b/integration_tests/guest_under_test/PREBUILT_VERSION
index 26dc4f9ed..75d30fb53 100644
--- a/integration_tests/guest_under_test/PREBUILT_VERSION
+++ b/integration_tests/guest_under_test/PREBUILT_VERSION
@@ -1 +1 @@
-r0000
+r0001
diff --git a/integration_tests/guest_under_test/upload_prebuilts.sh b/integration_tests/guest_under_test/upload_prebuilts.sh
index 4b362e2d1..1383badc5 100755
--- a/integration_tests/guest_under_test/upload_prebuilts.sh
+++ b/integration_tests/guest_under_test/upload_prebuilts.sh
@@ -14,31 +14,36 @@ set -e
cd "${0%/*}"
readonly PREBUILT_VERSION="$(cat ./PREBUILT_VERSION)"
-
-# Cloud storage files
readonly GS_BUCKET="gs://chromeos-localmirror/distfiles"
readonly GS_PREFIX="${GS_BUCKET}/crosvm-testing"
-readonly REMOTE_BZIMAGE="${GS_PREFIX}-bzimage-$(arch)-${PREBUILT_VERSION}"
-readonly REMOTE_ROOTFS="${GS_PREFIX}-rootfs-$(arch)-${PREBUILT_VERSION}"
-
-# Local files
-CARGO_TARGET=$(cargo metadata --no-deps --format-version 1 |
- jq -r ".target_directory")
-LOCAL_BZIMAGE=${CARGO_TARGET}/guest_under_test/bzImage
-LOCAL_ROOTFS=${CARGO_TARGET}/guest_under_test/rootfs
function prebuilts_exist_error() {
echo "Prebuilts of version ${PREBUILT_VERSION} already exist. See README.md"
exit 1
}
-echo "Checking if prebuilts already exist."
-gsutil stat "${REMOTE_BZIMAGE}" && prebuilts_exist_error
-gsutil stat "${REMOTE_ROOTFS}" && prebuilts_exist_error
+function upload() {
+ local arch=$1
+ local remote_bzimage="${GS_PREFIX}-bzimage-${arch}-${PREBUILT_VERSION}"
+ local remote_rootfs="${GS_PREFIX}-rootfs-${arch}-${PREBUILT_VERSION}"
+
+ # Local files
+ local cargo_target=$(cargo metadata --no-deps --format-version 1 |
+ jq -r ".target_directory")
+ local local_bzimage=${cargo_target}/guest_under_test/${arch}/bzImage
+ local local_rootfs=${cargo_target}/guest_under_test/${arch}/rootfs
-echo "Building rootfs and kernel."
-make "${LOCAL_BZIMAGE}" "${LOCAL_ROOTFS}"
+ echo "Checking if prebuilts already exist."
+ gsutil stat "${remote_bzimage}" && prebuilts_exist_error
+ gsutil stat "${remote_rootfs}" && prebuilts_exist_error
+
+ echo "Building rootfs and kernel."
+ make ARCH=${arch} "${local_bzimage}" "${local_rootfs}"
+
+ echo "Uploading files."
+ gsutil cp -n -a public-read "${local_bzimage}" "${remote_bzimage}"
+ gsutil cp -n -a public-read "${local_rootfs}" "${remote_rootfs}"
+}
-echo "Uploading files."
-gsutil cp -n -a public-read "${LOCAL_BZIMAGE}" "${REMOTE_BZIMAGE}"
-gsutil cp -n -a public-read "${LOCAL_ROOTFS}" "${REMOTE_ROOTFS}"
+upload x86_64
+upload aarch64
diff --git a/integration_tests/run b/integration_tests/run
new file mode 100755
index 000000000..1d2cd15ce
--- /dev/null
+++ b/integration_tests/run
@@ -0,0 +1,11 @@
+#!/bin/bash
+# Copyright 2021 The Chromium OS Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style license that can be
+# found in the LICENSE file
+
+# We require the crosvm binary to be built before running the integration tests.
+# There is an RFC for cargo to allow for this kind of dependency:
+# https://github.com/rust-lang/cargo/issues/9096
+cd $(dirname $0)
+(cd .. && cargo build $@)
+cargo test $@
diff --git a/integration_tests/tests/boot.rs b/integration_tests/tests/boot.rs
index bebd11935..877610c2c 100644
--- a/integration_tests/tests/boot.rs
+++ b/integration_tests/tests/boot.rs
@@ -2,11 +2,20 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
mod fixture;
-use crosvm::Config;
-use fixture::{TestVm, TestVmOptions};
+use fixture::TestVm;
#[test]
fn boot_test_vm() {
- let mut vm = TestVm::new(Config::default(), TestVmOptions::default()).unwrap();
+ let mut vm = TestVm::new(&[], false).unwrap();
+ assert_eq!(vm.exec_in_guest("echo 42").unwrap().trim(), "42");
+}
+
+#[test]
+fn boot_test_suspend_resume() {
+ // There is no easy way for us to check whether the VM is actually suspended,
+ // but at least exercise the code path.
+ let mut vm = TestVm::new(&[], false).unwrap();
+ vm.suspend().unwrap();
+ vm.resume().unwrap();
assert_eq!(vm.exec_in_guest("echo 42").unwrap().trim(), "42");
}
diff --git a/integration_tests/tests/fixture.rs b/integration_tests/tests/fixture.rs
index 57867eb5a..42a1604f6 100644
--- a/integration_tests/tests/fixture.rs
+++ b/integration_tests/tests/fixture.rs
@@ -2,9 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use std::env;
use std::ffi::CString;
-use std::fs::File;
use std::io::{self, BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::process::Command;
@@ -12,11 +10,11 @@ use std::sync::mpsc::sync_channel;
use std::sync::Once;
use std::thread;
use std::time::Duration;
+use std::{env, process::Child};
+use std::{fs::File, process::Stdio};
use anyhow::{anyhow, Result};
-use arch::{set_default_serial_parameters, SerialHardware, SerialParameters, SerialType};
use base::syslog;
-use crosvm::{platform, Config, DiskOption, Executable};
use tempfile::TempDir;
const PREBUILT_URL: &str = "https://storage.googleapis.com/chromeos-localmirror/distfiles";
@@ -30,7 +28,7 @@ const ARCH: &str = "aarch64";
/// Timeout for communicating with the VM. If we do not hear back, panic so we
/// do not block the tests.
-const VM_COMMUNICATION_TIMEOUT: Duration = Duration::from_millis(1000);
+const VM_COMMUNICATION_TIMEOUT: Duration = Duration::from_secs(10);
fn prebuilt_version() -> &'static str {
include_str!("../guest_under_test/PREBUILT_VERSION").trim()
@@ -76,6 +74,22 @@ fn rootfs_path() -> PathBuf {
}
}
+/// The crosvm binary is expected to be alongside the integration test binary,
+/// or in the parent directory (cargo puts the test binary in
+/// target/debug/deps/ but the crosvm binary in target/debug).
+fn find_crosvm_binary() -> PathBuf {
+ let exe_dir = env::current_exe().unwrap().parent().unwrap().to_path_buf();
+ let first = exe_dir.join("crosvm");
+ if first.exists() {
+ return first;
+ }
+ let second = exe_dir.parent().unwrap().join("crosvm");
+ if second.exists() {
+ return second;
+ }
+ panic!("Cannot find ./crosvm or ../crosvm alongside test binary.");
+}
+
/// Safe wrapper for libc::mkfifo
fn mkfifo(path: &Path) -> io::Result<()> {
let cpath = CString::new(path.to_str().unwrap()).unwrap();
@@ -124,9 +138,18 @@ fn download_file(url: &str, destination: &Path) -> Result<()> {
}
}
-#[derive(Default)]
-pub struct TestVmOptions {
- pub debug: bool,
+fn crosvm_command(command: &str, args: &[&str]) -> Result<()> {
+ println!("$ crosvm {} {:?}", command, &args.join(" "));
+ let status = Command::new(find_crosvm_binary())
+ .arg(command)
+ .args(args)
+ .status()?;
+
+ if !status.success() {
+ Err(anyhow!("Command failed with exit code {}", status))
+ } else {
+ Ok(())
+ }
}
/// Test fixture to spin up a VM running a guest that can be communicated with.
@@ -139,8 +162,9 @@ pub struct TestVm {
test_dir: TempDir,
from_guest_reader: BufReader<File>,
to_guest: File,
- vm_thread: Option<thread::JoinHandle<()>>,
- options: TestVmOptions,
+ control_socket_path: PathBuf,
+ process: Child,
+ debug: bool,
}
impl TestVm {
@@ -188,85 +212,35 @@ impl TestVm {
// delegate binary.
// - ttyS1: Serial device attached to the named pipes.
fn configure_serial_devices(
- config: &mut Config,
+ command: &mut Command,
from_guest_pipe: &Path,
to_guest_pipe: &Path,
- debug: bool,
- ) -> Result<()> {
- for ((_, index), _) in &config.serial_parameters {
- if *index == 1 || *index == 2 {
- return Err(anyhow!("Do not specify serial device 1 or 2."));
- }
- }
-
- config.serial_parameters.insert(
- (SerialHardware::Serial, 1),
- SerialParameters {
- type_: if debug {
- SerialType::Stdout
- } else {
- SerialType::Sink
- },
- hardware: SerialHardware::Serial,
- path: None,
- input: None,
- num: 1,
- console: true,
- earlycon: false,
- stdin: false,
- },
+ ) {
+ command.args(&["--serial", "type=syslog"]);
+
+ // Set up the channel for communication with the delegate.
+ let serial_params = format!(
+ "type=file,path={},input={},num=2",
+ from_guest_pipe.display(),
+ to_guest_pipe.display()
);
- config.serial_parameters.insert(
- (SerialHardware::Serial, 2),
- SerialParameters {
- type_: SerialType::File,
- hardware: SerialHardware::Serial,
- path: Some(PathBuf::from(from_guest_pipe)),
- input: Some(PathBuf::from(to_guest_pipe.clone())),
- num: 2,
- console: false,
- earlycon: false,
- stdin: false,
- },
- );
- set_default_serial_parameters(&mut config.serial_parameters);
- return Ok(());
+ command.args(&["--serial", &serial_params]);
}
/// Configures the VM kernel and rootfs to load from the guest_under_test assets.
- fn configure_kernel(config: &mut Config) -> Result<()> {
- for param in &config.params {
- if param.starts_with("root") || param.starts_with("init") {
- return Err(anyhow!("Do not set the root or init parameters."));
- }
- }
- config.executable_path = Some(Executable::Kernel(kernel_path()));
- config.params.push("root=/dev/vda ro".to_string());
- config.params.push("init=/bin/delegate".to_string());
- config.disks.insert(
- 0,
- DiskOption {
- id: None,
- path: rootfs_path(),
- read_only: true,
- sparse: true,
- block_size: 512,
- },
- );
-
- return Ok(());
+ fn configure_kernel(command: &mut Command) {
+ command
+ .args(&["--root", rootfs_path().to_str().unwrap()])
+ .args(&["--params", "init=/bin/delegate"])
+ .arg(kernel_path());
}
/// Instantiate a new crosvm instance. The first call will trigger the download of prebuilt
/// files if necessary.
- pub fn new(mut config: Config, options: TestVmOptions) -> Result<TestVm> {
+ pub fn new(additional_arguments: &[&str], debug: bool) -> Result<TestVm> {
static PREP_ONCE: Once = Once::new();
PREP_ONCE.call_once(|| TestVm::initialize_once());
- // TODO(b/173233134): Running sandboxed tests is going to require a lot of configuration
- // on the host.
- config.sandbox = false;
-
// Create two named pipes to communicate with the guest.
let test_dir = TempDir::new()?;
let from_guest_pipe = test_dir.path().join("from_guest");
@@ -274,18 +248,22 @@ impl TestVm {
mkfifo(&from_guest_pipe)?;
mkfifo(&to_guest_pipe)?;
- TestVm::configure_serial_devices(
- &mut config,
- &from_guest_pipe,
- &to_guest_pipe,
- options.debug,
- )?;
- TestVm::configure_kernel(&mut config)?;
+ let control_socket_path = test_dir.path().join("control");
+
+ let mut command = Command::new(find_crosvm_binary());
+ command.args(&["run", "--disable-sandbox"]);
+ TestVm::configure_serial_devices(&mut command, &from_guest_pipe, &to_guest_pipe);
+ command.args(&["--socket", &control_socket_path.to_str().unwrap()]);
+ command.args(additional_arguments);
- // Run VM in a separate thread.
- let vm_thread = thread::spawn(move || {
- platform::run_config(config).expect("Cannot run VM.");
- });
+ TestVm::configure_kernel(&mut command);
+
+ println!("$ {:?}", command);
+ if !debug {
+ command.stdout(Stdio::null());
+ command.stderr(Stdio::null());
+ }
+ let process = command.spawn()?;
// Open pipes. Panic if we cannot connect after a timeout.
let (to_guest, from_guest) = panic_on_timeout(
@@ -303,8 +281,9 @@ impl TestVm {
test_dir,
from_guest_reader,
to_guest: to_guest?,
- vm_thread: Some(vm_thread),
- options,
+ control_socket_path,
+ process,
+ debug,
})
}
@@ -329,25 +308,28 @@ impl TestVm {
output.push_str(&line);
}
let trimmed = output.trim();
- if self.options.debug {
+ if self.debug {
println!("<- {:?}", trimmed);
}
Ok(trimmed.to_string())
}
+
+ pub fn stop(&self) -> Result<()> {
+ crosvm_command("stop", &[self.control_socket_path.to_str().unwrap()])
+ }
+
+ pub fn suspend(&self) -> Result<()> {
+ crosvm_command("suspend", &[self.control_socket_path.to_str().unwrap()])
+ }
+
+ pub fn resume(&self) -> Result<()> {
+ crosvm_command("resume", &[self.control_socket_path.to_str().unwrap()])
+ }
}
impl Drop for TestVm {
fn drop(&mut self) {
- if let Some(handle) = self.vm_thread.take() {
- // Run exit command to shut down the VM.
- writeln!(&mut self.to_guest, "exit").expect("Cannot send exit command.");
- // Wait for the VM to exit, but don't wait forever.
- panic_on_timeout(
- move || {
- handle.join().expect("Cannot join VM thread.");
- },
- VM_COMMUNICATION_TIMEOUT,
- );
- }
+ self.stop().unwrap();
+ self.process.wait().unwrap();
}
}
diff --git a/io_uring/Cargo.toml b/io_uring/Cargo.toml
index cbe5aa1e6..3b7cc7c36 100644
--- a/io_uring/Cargo.toml
+++ b/io_uring/Cargo.toml
@@ -5,13 +5,12 @@ authors = ["The Chromium OS Authors"]
edition = "2018"
[dependencies]
-data_model = { path = "../data_model" }
+data_model = { path = "../data_model" } # provided by ebuild
libc = "*"
-syscall_defines = { path = "../syscall_defines" }
-sync = { path = "../sync" }
-sys_util = { path = "../sys_util" }
+sync = { path = "../sync" } # provided by ebuild
+sys_util = { path = "../sys_util" } # provided by ebuild
[dev-dependencies]
-tempfile = { path = "../tempfile" }
+tempfile = { path = "../tempfile" } # provided by ebuild
[workspace]
diff --git a/io_uring/src/syscalls.rs b/io_uring/src/syscalls.rs
index e3c9f8630..f3a3b41a0 100644
--- a/io_uring/src/syscalls.rs
+++ b/io_uring/src/syscalls.rs
@@ -6,8 +6,7 @@ use std::io::Error;
use std::os::unix::io::RawFd;
use std::ptr::null_mut;
-use libc::{c_int, c_long, c_void};
-use syscall_defines::linux::LinuxSyscall::*;
+use libc::{c_int, c_long, c_void, syscall, SYS_io_uring_enter, SYS_io_uring_setup};
use crate::bindings::*;
@@ -15,7 +14,7 @@ use crate::bindings::*;
pub type Result<T> = std::result::Result<T, c_int>;
pub unsafe fn io_uring_setup(num_entries: usize, params: &io_uring_params) -> Result<RawFd> {
- let ret = libc::syscall(
+ let ret = syscall(
SYS_io_uring_setup as c_long,
num_entries as c_int,
params as *const _,
@@ -27,7 +26,7 @@ pub unsafe fn io_uring_setup(num_entries: usize, params: &io_uring_params) -> Re
}
pub unsafe fn io_uring_enter(fd: RawFd, to_submit: u64, to_wait: u64, flags: u32) -> Result<()> {
- let ret = libc::syscall(
+ let ret = syscall(
SYS_io_uring_enter as c_long,
fd,
to_submit as c_int,
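For illustration, a sketch (not part of this change) of how the rewritten wrappers are called now that the syscall numbers come from libc; `io_uring_params` is the crate's generated binding, and the zeroed struct and the ring size of 256 are assumptions:

    use std::os::unix::io::RawFd;

    fn setup_and_enter() -> Result<RawFd> {
        // A zeroed params struct lets the kernel choose defaults; the kernel fills
        // in the ring offsets on return. Zero is a valid bit pattern for this plain C struct.
        let params: io_uring_params = unsafe { std::mem::zeroed() };
        // 256 submission queue entries; the kernel rounds up to a power of two.
        let fd = unsafe { io_uring_setup(256, &params)? };
        // Submit nothing and wait for nothing, just to show the call shape.
        unsafe { io_uring_enter(fd, 0, 0, 0)? };
        Ok(fd)
    }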
diff --git a/io_uring/src/uring.rs b/io_uring/src/uring.rs
index 6586a71b4..d67447379 100644
--- a/io_uring/src/uring.rs
+++ b/io_uring/src/uring.rs
@@ -1338,15 +1338,7 @@ mod tests {
}
mem::drop(c);
- // Now add NOPs to wake up any threads blocked on the syscall.
- for i in 0..NUM_THREADS {
- uring.add_nop((num_entries * 3 + i) as UserData).unwrap();
- }
- uring.submit().unwrap();
-
- for t in threads {
- t.join().unwrap();
- }
+ // Let the OS clean up the still-waiting threads after the test run.
}
#[test]
@@ -1431,7 +1423,9 @@ mod tests {
);
}
+ // TODO(b/183722981): Fix and re-enable test
#[test]
+ #[ignore]
fn multi_thread_submit_and_complete() {
const NUM_SUBMITTERS: usize = 7;
const NUM_COMPLETERS: usize = 3;
diff --git a/kvm/Cargo.toml b/kvm/Cargo.toml
index d2e94cad7..82cb31b6e 100644
--- a/kvm/Cargo.toml
+++ b/kvm/Cargo.toml
@@ -8,7 +8,6 @@ edition = "2018"
data_model = { path = "../data_model" }
kvm_sys = { path = "../kvm_sys" }
libc = "*"
-msg_socket = { path = "../msg_socket" }
base = { path = "../base" }
sync = { path = "../sync" }
vm_memory = { path = "../vm_memory" }
diff --git a/kvm/src/lib.rs b/kvm/src/lib.rs
index 357020af2..8945fb275 100644
--- a/kvm/src/lib.rs
+++ b/kvm/src/lib.rs
@@ -9,10 +9,13 @@ mod cap;
use std::cell::RefCell;
use std::cmp::{min, Ordering};
use std::collections::{BTreeMap, BinaryHeap};
+use std::ffi::CString;
use std::fs::File;
use std::mem::size_of;
use std::ops::{Deref, DerefMut};
use std::os::raw::*;
+use std::os::unix::prelude::OsStrExt;
+use std::path::{Path, PathBuf};
use std::ptr::copy_nonoverlapping;
use std::sync::Arc;
use sync::Mutex;
@@ -34,7 +37,6 @@ use base::{
ioctl_with_val, pagesize, signal, unblock_signal, warn, Error, Event, IoctlNr, MappedRegion,
MemoryMapping, MemoryMappingBuilder, MmapError, Result, SIGRTMIN,
};
-use msg_socket::MsgOnSocket;
use vm_memory::{GuestAddress, GuestMemory};
pub use crate::cap::*;
@@ -94,9 +96,14 @@ pub struct Kvm {
impl Kvm {
/// Opens `/dev/kvm/` and returns a Kvm object on success.
pub fn new() -> Result<Kvm> {
- // Open calls are safe because we give a constant nul-terminated string and verify the
- // result.
- let ret = unsafe { open("/dev/kvm\0".as_ptr() as *const c_char, O_RDWR | O_CLOEXEC) };
+ Kvm::new_with_path(&PathBuf::from("/dev/kvm"))
+ }
+
+ /// Opens a KVM device at `device_path` and returns a Kvm object on success.
+ pub fn new_with_path(device_path: &Path) -> Result<Kvm> {
+ // Open calls are safe because we give a nul-terminated string and verify the result.
+ let c_path = CString::new(device_path.as_os_str().as_bytes()).unwrap();
+ let ret = unsafe { open(c_path.as_ptr(), O_RDWR | O_CLOEXEC) };
if ret < 0 {
return errno_result();
}
@@ -200,7 +207,7 @@ impl AsRawDescriptor for Kvm {
}
/// An address either in programmable I/O space or in memory mapped I/O space.
-#[derive(Copy, Clone, Debug, MsgOnSocket)]
+#[derive(Copy, Clone, Debug)]
pub enum IoeventAddress {
Pio(u64),
Mmio(u64),
@@ -271,7 +278,7 @@ impl Vm {
if ret >= 0 {
// Safe because we verify the value of ret and we are the owners of the fd.
let vm_file = unsafe { File::from_raw_descriptor(ret) };
- guest_mem.with_regions(|index, guest_addr, size, host_addr, _| {
+ guest_mem.with_regions(|index, guest_addr, size, host_addr, _, _| {
unsafe {
// Safe because the guest regions are guaranteed not to overlap.
set_user_memory_region(
@@ -962,7 +969,7 @@ impl Vcpu {
let vcpu = unsafe { File::from_raw_descriptor(vcpu_fd) };
let run_mmap = MemoryMappingBuilder::new(run_mmap_size)
- .from_descriptor(&vcpu)
+ .from_file(&vcpu)
.build()
.map_err(|_| Error::new(ENOSPC))?;
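For illustration, a sketch (not part of this change) of how the new constructor composes with the existing one; the idea of an optional override path is an assumption for the example:

    use std::path::Path;

    fn open_kvm(device_override: Option<&Path>) -> Result<Kvm> {
        match device_override {
            // e.g. a renamed or namespaced KVM node in a test environment.
            Some(path) => Kvm::new_with_path(path),
            // Preserves the previous behaviour of opening /dev/kvm.
            None => Kvm::new(),
        }
    }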
diff --git a/kvm/tests/dirty_log.rs b/kvm/tests/dirty_log.rs
index 9fb5e5919..fb848df31 100644
--- a/kvm/tests/dirty_log.rs
+++ b/kvm/tests/dirty_log.rs
@@ -21,7 +21,7 @@ fn test_run() {
let guest_mem = GuestMemory::new(&[]).unwrap();
let mem = SharedMemory::anon(mem_size).expect("failed to create shared memory");
let mmap = MemoryMappingBuilder::new(mem_size as usize)
- .from_descriptor(&mem)
+ .from_shared_memory(&mem)
.build()
.expect("failed to create memory mapping");
@@ -48,7 +48,7 @@ fn test_run() {
GuestAddress(0),
Box::new(
MemoryMappingBuilder::new(mem_size as usize)
- .from_descriptor(&mem)
+ .from_shared_memory(&mem)
.build()
.expect("failed to create memory mapping"),
),
diff --git a/kvm/tests/read_only_memory.rs b/kvm/tests/read_only_memory.rs
index a2e5105df..2a2111063 100644
--- a/kvm/tests/read_only_memory.rs
+++ b/kvm/tests/read_only_memory.rs
@@ -23,7 +23,7 @@ fn test_run() {
let guest_mem = GuestMemory::new(&[]).unwrap();
let mem = SharedMemory::anon(mem_size).expect("failed to create shared memory");
let mmap = MemoryMappingBuilder::new(mem_size as usize)
- .from_descriptor(&mem)
+ .from_shared_memory(&mem)
.build()
.expect("failed to create memory mapping");
@@ -50,7 +50,7 @@ fn test_run() {
GuestAddress(0),
Box::new(
MemoryMappingBuilder::new(mem_size as usize)
- .from_descriptor(&mem)
+ .from_shared_memory(&mem)
.build()
.expect("failed to create memory mapping"),
),
@@ -63,7 +63,7 @@ fn test_run() {
// from it.
let mem_ro = SharedMemory::anon(0x1000).expect("failed to create shared memory");
let mmap_ro = MemoryMappingBuilder::new(0x1000)
- .from_descriptor(&mem_ro)
+ .from_shared_memory(&mem_ro)
.build()
.expect("failed to create memory mapping");
mmap_ro
@@ -73,7 +73,7 @@ fn test_run() {
GuestAddress(vcpu_sregs.es.base),
Box::new(
MemoryMappingBuilder::new(0x1000)
- .from_descriptor(&mem_ro)
+ .from_shared_memory(&mem_ro)
.build()
.expect("failed to create memory mapping"),
),
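The builder calls above now name the backing object explicitly instead of using the generic from_descriptor(); a minimal sketch of the pattern, with the same expect-based error handling as the tests:

    fn map_shared(mem: &SharedMemory, len: usize) -> MemoryMapping {
        MemoryMappingBuilder::new(len)
            .from_shared_memory(mem)
            .build()
            .expect("failed to create memory mapping")
    }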
diff --git a/libcrosvm_control/Cargo.toml b/libcrosvm_control/Cargo.toml
new file mode 100644
index 000000000..2ed68b8a9
--- /dev/null
+++ b/libcrosvm_control/Cargo.toml
@@ -0,0 +1,13 @@
+[package]
+name = "libcrosvm_control"
+version = "0.1.0"
+authors = ["The Chromium OS Authors"]
+edition = "2018"
+
+[lib]
+crate-type = ["cdylib"]
+
+[dependencies]
+base = { path = "../base" }
+vm_control = { path = "../vm_control" }
+libc = "0.2.65"
diff --git a/libcrosvm_control/src/lib.rs b/libcrosvm_control/src/lib.rs
new file mode 100644
index 000000000..973d64e0a
--- /dev/null
+++ b/libcrosvm_control/src/lib.rs
@@ -0,0 +1,358 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+//
+// Provides parts of crosvm as a library to communicate with running crosvm instances.
+// Without it, you would need to invoke crosvm with subcommands and read the result from
+// stdout.
+use std::convert::{TryFrom, TryInto};
+use std::ffi::CStr;
+use std::panic::catch_unwind;
+use std::path::{Path, PathBuf};
+
+use libc::{c_char, ssize_t};
+
+use vm_control::{
+ client::*, BalloonControlCommand, BalloonStats, DiskControlCommand, UsbControlAttachedDevice,
+ UsbControlResult, VmRequest, VmResponse,
+};
+
+fn validate_socket_path(socket_path: *const c_char) -> Option<PathBuf> {
+ if !socket_path.is_null() {
+ let socket_path = unsafe { CStr::from_ptr(socket_path) };
+ Some(PathBuf::from(socket_path.to_str().ok()?))
+ } else {
+ None
+ }
+}
+
+/// Stops the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_stop_vm(socket_path: *const c_char) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ vms_request(&VmRequest::Exit, &socket_path).is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Suspends the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_suspend_vm(socket_path: *const c_char) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ vms_request(&VmRequest::Suspend, &socket_path).is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Resumes the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_resume_vm(socket_path: *const c_char) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ vms_request(&VmRequest::Resume, &socket_path).is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Adjusts the balloon size of the crosvm instance whose control socket is
+/// listening on `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_balloon_vms(socket_path: *const c_char, num_bytes: u64) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ let command = BalloonControlCommand::Adjust { num_bytes };
+ vms_request(&VmRequest::BalloonCommand(command), &socket_path).is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Represents an individual attached USB device.
+#[repr(C)]
+pub struct UsbDeviceEntry {
+ /// Internal port index used for identifying this individual device.
+ port: u8,
+ /// USB vendor ID
+ vendor_id: u16,
+ /// USB product ID
+ product_id: u16,
+}
+
+impl From<&UsbControlAttachedDevice> for UsbDeviceEntry {
+ fn from(other: &UsbControlAttachedDevice) -> Self {
+ Self {
+ port: other.port,
+ vendor_id: other.vendor_id,
+ product_id: other.product_id,
+ }
+ }
+}
+
+/// Returns all USB devices passed through the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The function returns the number of entries written.
+/// # Arguments
+///
+/// * `socket_path` - Path to the crosvm control socket
+/// * `entries` - Pointer to an array of `UsbDeviceEntry` where the details about the attached
+/// devices will be written to
+/// * `entries_length` - Number of entries in the array specified by `entries`
+///
+/// Crosvm supports passing through up to 255 devices, so passing an array with 255 entries is
+/// guaranteed to return all entries.
+#[no_mangle]
+pub extern "C" fn crosvm_client_usb_list(
+ socket_path: *const c_char,
+ entries: *mut UsbDeviceEntry,
+ entries_length: ssize_t,
+) -> ssize_t {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ if let Ok(UsbControlResult::Devices(res)) = do_usb_list(&socket_path) {
+ let mut i = 0;
+ for entry in res.iter().filter(|x| x.valid()) {
+ if i >= entries_length {
+ break;
+ }
+ unsafe {
+ *entries.offset(i) = entry.into();
+ i += 1;
+ }
+ }
+ i
+ } else {
+ -1
+ }
+ } else {
+ -1
+ }
+ })
+ .unwrap_or(-1)
+}
+
+/// Attaches a USB device to the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// # Arguments
+///
+/// * `socket_path` - Path to the crosvm control socket
+/// * `bus` - USB device bus ID
+/// * `addr` - USB device address
+/// * `vid` - USB device vendor ID
+/// * `pid` - USB device product ID
+/// * `dev_path` - Path to the USB device (Most likely `/dev/bus/usb/<bus>/<addr>`).
+/// * `out_port` - (optional) internal port will be written here if provided.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_usb_attach(
+ socket_path: *const c_char,
+ bus: u8,
+ addr: u8,
+ vid: u16,
+ pid: u16,
+ dev_path: *const c_char,
+ out_port: *mut u8,
+) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ if dev_path.is_null() {
+ return false;
+ }
+ let dev_path = Path::new(unsafe { CStr::from_ptr(dev_path) }.to_str().unwrap_or(""));
+
+ if let Ok(UsbControlResult::Ok { port }) =
+ do_usb_attach(&socket_path, bus, addr, vid, pid, dev_path)
+ {
+ if !out_port.is_null() {
+ unsafe { *out_port = port };
+ }
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Detaches a USB device from the crosvm instance whose control socket is listening on `socket_path`.
+/// `port` determines device to be detached.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_usb_detach(socket_path: *const c_char, port: u8) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ do_usb_detach(&socket_path, port).is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Modifies the battery status of the crosvm instance whose control socket is listening on
+/// `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_modify_battery(
+ socket_path: *const c_char,
+ battery_type: *const c_char,
+ property: *const c_char,
+ target: *const c_char,
+) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ if battery_type.is_null() || property.is_null() || target.is_null() {
+ return false;
+ }
+ let battery_type = unsafe { CStr::from_ptr(battery_type) };
+ let property = unsafe { CStr::from_ptr(property) };
+ let target = unsafe { CStr::from_ptr(target) };
+
+ do_modify_battery(
+ &socket_path,
+ &battery_type.to_str().unwrap(),
+ &property.to_str().unwrap(),
+ &target.to_str().unwrap(),
+ )
+ .is_ok()
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Resizes the disk of the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The function returns true on success or false if an error occurred.
+#[no_mangle]
+pub extern "C" fn crosvm_client_resize_disk(
+ socket_path: *const c_char,
+ disk_index: u64,
+ new_size: u64,
+) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ if let Ok(disk_index) = usize::try_from(disk_index) {
+ let request = VmRequest::DiskCommand {
+ disk_index,
+ command: DiskControlCommand::Resize { new_size },
+ };
+ vms_request(&request, &socket_path).is_ok()
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
+
+/// Similar to the internally used `BalloonStats`, but using i64 instead of
+/// Option<u64>. `None` (or values larger than i64::MAX) will be encoded as -1.
+#[repr(C)]
+pub struct BalloonStatsFfi {
+ swap_in: i64,
+ swap_out: i64,
+ major_faults: i64,
+ minor_faults: i64,
+ free_memory: i64,
+ total_memory: i64,
+ available_memory: i64,
+ disk_caches: i64,
+ hugetlb_allocations: i64,
+ hugetlb_failures: i64,
+}
+
+impl From<&BalloonStats> for BalloonStatsFfi {
+ fn from(other: &BalloonStats) -> Self {
+ let convert =
+ |x: Option<u64>| -> i64 { x.map(|y| y.try_into().ok()).flatten().unwrap_or(-1) };
+ Self {
+ swap_in: convert(other.swap_in),
+ swap_out: convert(other.swap_out),
+ major_faults: convert(other.major_faults),
+ minor_faults: convert(other.minor_faults),
+ free_memory: convert(other.free_memory),
+ total_memory: convert(other.total_memory),
+ available_memory: convert(other.available_memory),
+ disk_caches: convert(other.disk_caches),
+ hugetlb_allocations: convert(other.hugetlb_allocations),
+ hugetlb_failures: convert(other.hugetlb_failures),
+ }
+ }
+}
+
+/// Returns balloon stats of the crosvm instance whose control socket is listening on `socket_path`.
+///
+/// The parameters `stats` and `actual` are optional and will only be written to if they are
+/// non-null.
+///
+/// The function returns true on success or false if an error occurred.
+///
+/// # Note
+///
+/// Entries in `BalloonStatsFfi` that are not available will be set to `-1`.
+#[no_mangle]
+pub extern "C" fn crosvm_client_balloon_stats(
+ socket_path: *const c_char,
+ stats: *mut BalloonStatsFfi,
+ actual: *mut u64,
+) -> bool {
+ catch_unwind(|| {
+ if let Some(socket_path) = validate_socket_path(socket_path) {
+ let request = &VmRequest::BalloonCommand(BalloonControlCommand::Stats {});
+ if let Ok(VmResponse::BalloonStats {
+ stats: ref balloon_stats,
+ balloon_actual,
+ }) = handle_request(request, &socket_path)
+ {
+ if !stats.is_null() {
+ unsafe {
+ *stats = balloon_stats.into();
+ }
+ }
+
+ if !actual.is_null() {
+ unsafe {
+ *actual = balloon_actual;
+ }
+ }
+ true
+ } else {
+ false
+ }
+ } else {
+ false
+ }
+ })
+ .unwrap_or(false)
+}
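For illustration, a sketch (not part of this change) of a foreign consumer of the new cdylib; a real client would normally be C or C++ against a generated header, so the extern declarations below simply mirror the #[no_mangle] signatures above, and the socket path is a made-up example:

    use std::ffi::CString;
    use std::os::raw::c_char;

    // Mirror of the #[repr(C)] struct exported by the library.
    #[repr(C)]
    struct BalloonStatsFfi {
        swap_in: i64,
        swap_out: i64,
        major_faults: i64,
        minor_faults: i64,
        free_memory: i64,
        total_memory: i64,
        available_memory: i64,
        disk_caches: i64,
        hugetlb_allocations: i64,
        hugetlb_failures: i64,
    }

    extern "C" {
        fn crosvm_client_stop_vm(socket_path: *const c_char) -> bool;
        fn crosvm_client_balloon_stats(
            socket_path: *const c_char,
            stats: *mut BalloonStatsFfi,
            actual: *mut u64,
        ) -> bool;
    }

    fn main() {
        // Hypothetical control socket path, i.e. whatever was passed to `crosvm run --socket`.
        let socket = CString::new("/run/crosvm/control.sock").unwrap();

        // Entries that are unavailable come back as -1, as documented above.
        let mut stats: BalloonStatsFfi = unsafe { std::mem::zeroed() };
        let mut actual: u64 = 0;
        let ok = unsafe { crosvm_client_balloon_stats(socket.as_ptr(), &mut stats, &mut actual) };
        if ok {
            println!("balloon actual size: {} bytes", actual);
            println!("guest free memory: {} bytes", stats.free_memory);
        }

        // Ask the instance to shut down; false means the request failed.
        assert!(unsafe { crosvm_client_stop_vm(socket.as_ptr()) });
    }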
diff --git a/linux_input_sys/src/lib.rs b/linux_input_sys/src/lib.rs
index 1c9e801da..b48e9370e 100644
--- a/linux_input_sys/src/lib.rs
+++ b/linux_input_sys/src/lib.rs
@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use data_model::{DataInit, Le16, Le32};
+use data_model::{DataInit, Le16, SLe32};
use std::mem::size_of;
const EV_SYN: u16 = 0x00;
@@ -37,7 +37,7 @@ pub struct input_event {
pub timestamp_fields: [u64; 2],
pub type_: u16,
pub code: u16,
- pub value: u32,
+ pub value: i32,
}
// Safe because it only has data and has no implicit padding.
unsafe impl DataInit for input_event {}
@@ -64,7 +64,7 @@ impl InputEventDecoder for input_event {
virtio_input_event {
type_: Le16::from(e.type_),
code: Le16::from(e.code),
- value: Le32::from(e.value),
+ value: SLe32::from(e.value),
}
}
}
@@ -74,7 +74,7 @@ impl InputEventDecoder for input_event {
pub struct virtio_input_event {
pub type_: Le16,
pub code: Le16,
- pub value: Le32,
+ pub value: SLe32,
}
// Safe because it only has data and has no implicit padding.
@@ -97,46 +97,46 @@ impl virtio_input_event {
virtio_input_event {
type_: Le16::from(EV_SYN),
code: Le16::from(SYN_REPORT),
- value: Le32::from(0),
+ value: SLe32::from(0),
}
}
#[inline]
- pub fn absolute(code: u16, value: u32) -> virtio_input_event {
+ pub fn absolute(code: u16, value: i32) -> virtio_input_event {
virtio_input_event {
type_: Le16::from(EV_ABS),
code: Le16::from(code),
- value: Le32::from(value),
+ value: SLe32::from(value),
}
}
#[inline]
- pub fn multitouch_tracking_id(id: u32) -> virtio_input_event {
+ pub fn multitouch_tracking_id(id: i32) -> virtio_input_event {
Self::absolute(ABS_MT_TRACKING_ID, id)
}
#[inline]
- pub fn multitouch_slot(slot: u32) -> virtio_input_event {
+ pub fn multitouch_slot(slot: i32) -> virtio_input_event {
Self::absolute(ABS_MT_SLOT, slot)
}
#[inline]
- pub fn multitouch_absolute_x(x: u32) -> virtio_input_event {
+ pub fn multitouch_absolute_x(x: i32) -> virtio_input_event {
Self::absolute(ABS_MT_POSITION_X, x)
}
#[inline]
- pub fn multitouch_absolute_y(y: u32) -> virtio_input_event {
+ pub fn multitouch_absolute_y(y: i32) -> virtio_input_event {
Self::absolute(ABS_MT_POSITION_Y, y)
}
#[inline]
- pub fn absolute_x(x: u32) -> virtio_input_event {
+ pub fn absolute_x(x: i32) -> virtio_input_event {
Self::absolute(ABS_X, x)
}
#[inline]
- pub fn absolute_y(y: u32) -> virtio_input_event {
+ pub fn absolute_y(y: i32) -> virtio_input_event {
Self::absolute(ABS_Y, y)
}
@@ -155,7 +155,7 @@ impl virtio_input_event {
virtio_input_event {
type_: Le16::from(EV_KEY),
code: Le16::from(code),
- value: Le32::from(if pressed { 1 } else { 0 }),
+ value: SLe32::from(if pressed { 1 } else { 0 }),
}
}
}
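For illustration, a sketch (not part of this change) of why the value field moves from Le32 to the signed SLe32: multitouch tracking IDs are signed in the Linux input protocol, so releasing a contact with the conventional tracking id of -1 now round-trips without a lossy cast:

    fn release_touch_slot(slot: i32) -> [virtio_input_event; 2] {
        [
            virtio_input_event::multitouch_slot(slot),
            // -1 is the conventional "no contact" tracking id; a SYN_REPORT event
            // would normally follow this pair.
            virtio_input_event::multitouch_tracking_id(-1),
        ]
    }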
diff --git a/msg_socket/Cargo.toml b/msg_socket/Cargo.toml
deleted file mode 100644
index 418b31132..000000000
--- a/msg_socket/Cargo.toml
+++ /dev/null
@@ -1,14 +0,0 @@
-[package]
-name = "msg_socket"
-version = "0.1.0"
-authors = ["The Chromium OS Authors"]
-edition = "2018"
-
-[dependencies]
-cros_async = { path = "../cros_async" }
-data_model = { path = "../data_model" }
-futures = "*"
-libc = "*"
-msg_on_socket_derive = { path = "msg_on_socket_derive" }
-base = { path = "../base" }
-sync = { path = "../sync" }
diff --git a/msg_socket/msg_on_socket_derive/Cargo.toml b/msg_socket/msg_on_socket_derive/Cargo.toml
deleted file mode 100644
index 206c03a78..000000000
--- a/msg_socket/msg_on_socket_derive/Cargo.toml
+++ /dev/null
@@ -1,15 +0,0 @@
-[package]
-name = "msg_on_socket_derive"
-version = "0.1.0"
-authors = ["The Chromium OS Authors"]
-edition = "2018"
-
-[dependencies]
-base = "*"
-proc-macro2 = "^1"
-quote = "^1"
-syn = "^1"
-
-[lib]
-proc-macro = true
-path = "msg_on_socket_derive.rs"
diff --git a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs b/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
deleted file mode 100644
index 98b354900..000000000
--- a/msg_socket/msg_on_socket_derive/msg_on_socket_derive.rs
+++ /dev/null
@@ -1,911 +0,0 @@
-// Copyright 2018 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-#![recursion_limit = "256"]
-extern crate proc_macro;
-
-use std::vec::Vec;
-
-use proc_macro2::{Span, TokenStream};
-use quote::{format_ident, quote};
-use syn::{
- parse_macro_input, Data, DataEnum, DataStruct, DeriveInput, Fields, Ident, Index, Member, Meta,
- NestedMeta, Type,
-};
-
-/// The function that derives the recursive implementation for struct or enum.
-#[proc_macro_derive(MsgOnSocket, attributes(msg_on_socket))]
-pub fn msg_on_socket_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
- let input = parse_macro_input!(input as DeriveInput);
- let impl_for_input = socket_msg_impl(input);
- impl_for_input.into()
-}
-
-fn socket_msg_impl(input: DeriveInput) -> TokenStream {
- if !input.generics.params.is_empty() {
- return quote! {
- compile_error!("derive(SocketMsg) does not support generic parameters");
- };
- }
- match input.data {
- Data::Struct(ds) => {
- if is_named_struct(&ds) {
- impl_for_named_struct(input.ident, ds)
- } else {
- impl_for_tuple_struct(input.ident, ds)
- }
- }
- Data::Enum(de) => impl_for_enum(input.ident, de),
- _ => quote! {
- compile_error!("derive(SocketMsg) only support struct and enum");
- },
- }
-}
-
-fn is_named_struct(ds: &DataStruct) -> bool {
- matches!(&ds.fields, Fields::Named(_f))
-}
-
-/************************** Named Struct Impls ********************************************/
-
-struct StructField {
- member: Member,
- ty: Type,
- skipped: bool,
-}
-
-fn impl_for_named_struct(name: Ident, ds: DataStruct) -> TokenStream {
- let fields = get_struct_fields(ds);
- let uses_fd_impl = define_uses_fd_for_struct(&fields);
- let buffer_sizes_impls = define_buffer_size_for_struct(&fields);
-
- let read_buffer = define_read_buffer_for_struct(&name, &fields);
- let write_buffer = define_write_buffer_for_struct(&name, &fields);
- quote! {
- impl msg_socket::MsgOnSocket for #name {
- #uses_fd_impl
- #buffer_sizes_impls
- #read_buffer
- #write_buffer
- }
- }
-}
-
-// Flatten struct fields.
-fn get_struct_fields(ds: DataStruct) -> Vec<StructField> {
- let fields = match ds.fields {
- Fields::Named(fields_named) => fields_named.named,
- _ => {
- panic!("Struct must have named fields");
- }
- };
- let mut vec = Vec::new();
- for field in fields {
- let member = match field.ident {
- Some(ident) => Member::Named(ident),
- None => panic!("Unknown Error."),
- };
- let ty = field.ty;
- let mut skipped = false;
- for attr in field
- .attrs
- .iter()
- .filter(|attr| attr.path.is_ident("msg_on_socket"))
- {
- match attr.parse_meta().unwrap() {
- Meta::List(meta) => {
- for nested in meta.nested {
- match nested {
- NestedMeta::Meta(Meta::Path(ref meta_path))
- if meta_path.is_ident("skip") =>
- {
- skipped = true;
- }
- _ => panic!("unrecognized attribute meta `{}`", quote! { #nested }),
- }
- }
- }
- _ => panic!("unrecognized attribute `{}`", quote! { #attr }),
- }
- }
- vec.push(StructField {
- member,
- ty,
- skipped,
- });
- }
- vec
-}
-
-fn define_uses_fd_for_struct(fields: &[StructField]) -> TokenStream {
- let field_types: Vec<_> = fields
- .iter()
- .filter(|f| !f.skipped)
- .map(|f| &f.ty)
- .collect();
-
- if field_types.is_empty() {
- return quote!();
- }
-
- quote! {
- fn uses_descriptor() -> bool {
- #(<#field_types>::uses_descriptor())||*
- }
- }
-}
-
-fn define_buffer_size_for_struct(fields: &[StructField]) -> TokenStream {
- let (msg_size, fd_count) = get_fields_buffer_size_sum(fields);
- quote! {
- fn msg_size(&self) -> usize {
- #msg_size
- }
- fn descriptor_count(&self) -> usize {
- #fd_count
- }
- }
-}
-
-fn define_read_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
- let mut read_fields = Vec::new();
- let mut init_fields = Vec::new();
- for field in fields {
- let ident = match &field.member {
- Member::Named(ident) => ident,
- Member::Unnamed(_) => unreachable!(),
- };
- let name = ident.clone();
- if field.skipped {
- let ty = &field.ty;
- init_fields.push(quote! {
- #name: <#ty>::default()
- });
- continue;
- }
- let read_field = read_from_buffer_and_move_offset(&ident, &field.ty);
- read_fields.push(read_field);
- init_fields.push(quote!(#name));
- }
- quote! {
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- #(#read_fields)*
- Ok((
- Self {
- #(#init_fields),*
- },
- __fd_offset
- ))
- }
- }
-}
-
-fn define_write_buffer_for_struct(_name: &Ident, fields: &[StructField]) -> TokenStream {
- let mut write_fields = Vec::new();
- for field in fields {
- if field.skipped {
- continue;
- }
- let ident = match &field.member {
- Member::Named(ident) => ident,
- Member::Unnamed(_) => unreachable!(),
- };
- let write_field = write_to_buffer_and_move_offset(&ident);
- write_fields.push(write_field);
- }
- quote! {
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- #(#write_fields)*
- Ok(__fd_offset)
- }
- }
-}
-
-/************************** Enum Impls ********************************************/
-fn impl_for_enum(name: Ident, de: DataEnum) -> TokenStream {
- let uses_fd_impl = define_uses_fd_for_enum(&de);
- let buffer_sizes_impls = define_buffer_size_for_enum(&name, &de);
- let read_buffer = define_read_buffer_for_enum(&name, &de);
- let write_buffer = define_write_buffer_for_enum(&name, &de);
- quote! {
- impl msg_socket::MsgOnSocket for #name {
- #uses_fd_impl
- #buffer_sizes_impls
- #read_buffer
- #write_buffer
- }
- }
-}
-
-fn define_uses_fd_for_enum(de: &DataEnum) -> TokenStream {
- let mut variant_field_types = Vec::new();
- for variant in &de.variants {
- for variant_field_ty in variant.fields.iter().map(|f| &f.ty) {
- variant_field_types.push(variant_field_ty);
- }
- }
-
- if variant_field_types.is_empty() {
- return quote!();
- }
-
- quote! {
- fn uses_descriptor() -> bool {
- #(<#variant_field_types>::uses_descriptor())||*
- }
- }
-}
-
-fn define_buffer_size_for_enum(name: &Ident, de: &DataEnum) -> TokenStream {
- let mut msg_size_match_variants = Vec::new();
- let mut fd_count_match_variants = Vec::new();
-
- for variant in &de.variants {
- let variant_name = &variant.ident;
- match &variant.fields {
- Fields::Named(fields) => {
- let mut tmp_names = Vec::new();
- for field in &fields.named {
- tmp_names.push(field.ident.clone().unwrap());
- }
-
- let v = quote! {
- #name::#variant_name { #(#tmp_names),* } => #(#tmp_names.msg_size())+*,
- };
- msg_size_match_variants.push(v);
-
- let v = quote! {
- #name::#variant_name { #(#tmp_names),* } => #(#tmp_names.descriptor_count())+*,
- };
- fd_count_match_variants.push(v);
- }
- Fields::Unnamed(fields) => {
- let mut tmp_names = Vec::new();
- for idx in 0..fields.unnamed.len() {
- let tmp_name = format!("enum_field{}", idx);
- let tmp_name = Ident::new(&tmp_name, Span::call_site());
- tmp_names.push(tmp_name.clone());
- }
-
- let v = quote! {
- #name::#variant_name(#(#tmp_names),*) => #(#tmp_names.msg_size())+*,
- };
- msg_size_match_variants.push(v);
-
- let v = quote! {
- #name::#variant_name(#(#tmp_names),*) => #(#tmp_names.descriptor_count())+*,
- };
- fd_count_match_variants.push(v);
- }
- Fields::Unit => {
- let v = quote! {
- #name::#variant_name => 0,
- };
- msg_size_match_variants.push(v.clone());
- fd_count_match_variants.push(v);
- }
- }
- }
-
- quote! {
- fn msg_size(&self) -> usize {
- 1 + match self {
- #(#msg_size_match_variants)*
- }
- }
- fn descriptor_count(&self) -> usize {
- match self {
- #(#fd_count_match_variants)*
- }
- }
- }
-}
-
-fn define_read_buffer_for_enum(name: &Ident, de: &DataEnum) -> TokenStream {
- let mut match_variants = Vec::new();
- let de = de.clone();
- for (idx, variant) in de.variants.iter().enumerate() {
- let idx = idx as u8;
- let variant_name = &variant.ident;
- match &variant.fields {
- Fields::Named(fields) => {
- let mut tmp_names = Vec::new();
- let mut read_tmps = Vec::new();
- for f in &fields.named {
- tmp_names.push(f.ident.clone());
- let read_tmp =
- read_from_buffer_and_move_offset(f.ident.as_ref().unwrap(), &f.ty);
- read_tmps.push(read_tmp);
- }
- let v = quote! {
- #idx => {
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- #(#read_tmps)*
- Ok((#name::#variant_name { #(#tmp_names),* }, __fd_offset))
- }
- };
- match_variants.push(v);
- }
- Fields::Unnamed(fields) => {
- let mut tmp_names = Vec::new();
- let mut read_tmps = Vec::new();
- for (idx, field) in fields.unnamed.iter().enumerate() {
- let tmp_name = format_ident!("enum_field{}", idx);
- tmp_names.push(tmp_name.clone());
- let read_tmp = read_from_buffer_and_move_offset(&tmp_name, &field.ty);
- read_tmps.push(read_tmp);
- }
-
- let v = quote! {
- #idx => {
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- #(#read_tmps)*
- Ok((#name::#variant_name( #(#tmp_names),*), __fd_offset))
- }
- };
- match_variants.push(v);
- }
- Fields::Unit => {
- let v = quote! {
- #idx => Ok((#name::#variant_name, 0)),
- };
- match_variants.push(v);
- }
- }
- }
- quote! {
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let v = buffer.get(0).ok_or(msg_socket::MsgError::WrongMsgBufferSize)?;
- match v {
- #(#match_variants)*
- _ => Err(msg_socket::MsgError::InvalidType),
- }
- }
- }
-}
-
-fn define_write_buffer_for_enum(name: &Ident, de: &DataEnum) -> TokenStream {
- let mut match_variants = Vec::new();
- let de = de.clone();
- for (idx, variant) in de.variants.iter().enumerate() {
- let idx = idx as u8;
- let variant_name = &variant.ident;
- match &variant.fields {
- Fields::Named(fields) => {
- let mut tmp_names = Vec::new();
- let mut write_tmps = Vec::new();
- for f in &fields.named {
- tmp_names.push(f.ident.clone().unwrap());
- let write_tmp =
- enum_write_to_buffer_and_move_offset(&f.ident.as_ref().unwrap());
- write_tmps.push(write_tmp);
- }
-
- let v = quote! {
- #name::#variant_name { #(#tmp_names),* } => {
- buffer[0] = #idx;
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- #(#write_tmps)*
- Ok(__fd_offset)
- }
- };
- match_variants.push(v);
- }
- Fields::Unnamed(fields) => {
- let mut tmp_names = Vec::new();
- let mut write_tmps = Vec::new();
- for idx in 0..fields.unnamed.len() {
- let tmp_name = format_ident!("enum_field{}", idx);
- tmp_names.push(tmp_name.clone());
- let write_tmp = enum_write_to_buffer_and_move_offset(&tmp_name);
- write_tmps.push(write_tmp);
- }
-
- let v = quote! {
- #name::#variant_name(#(#tmp_names),*) => {
- buffer[0] = #idx;
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- #(#write_tmps)*
- Ok(__fd_offset)
- }
- };
- match_variants.push(v);
- }
- Fields::Unit => {
- let v = quote! {
- #name::#variant_name => {
- buffer[0] = #idx;
- Ok(0)
- }
- };
- match_variants.push(v);
- }
- }
- }
-
- quote! {
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- if buffer.is_empty() {
- return Err(msg_socket::MsgError::WrongMsgBufferSize)
- }
- match self {
- #(#match_variants)*
- }
- }
- }
-}
-
-fn enum_write_to_buffer_and_move_offset(name: &Ident) -> TokenStream {
- quote! {
- let o = #name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += #name.msg_size();
- __fd_offset += o;
- }
-}
-
-/************************** Tuple Impls ********************************************/
-fn impl_for_tuple_struct(name: Ident, ds: DataStruct) -> TokenStream {
- let fields = get_tuple_fields(ds);
-
- let uses_fd_impl = define_uses_fd_for_tuples(&fields);
- let buffer_sizes_impls = define_buffer_size_for_struct(&fields);
- let read_buffer = define_read_buffer_for_tuples(&name, &fields);
- let write_buffer = define_write_buffer_for_tuples(&name, &fields);
- quote! {
- impl msg_socket::MsgOnSocket for #name {
- #uses_fd_impl
- #buffer_sizes_impls
- #read_buffer
- #write_buffer
- }
- }
-}
-
-fn get_tuple_fields(ds: DataStruct) -> Vec<StructField> {
- let mut field_idents = Vec::new();
- let fields = match ds.fields {
- Fields::Unnamed(fields_unnamed) => fields_unnamed.unnamed,
- _ => {
- panic!("Tuple struct must have unnamed fields.");
- }
- };
- for (idx, field) in fields.iter().enumerate() {
- let member = Member::Unnamed(Index::from(idx));
- let ty = field.ty.clone();
- field_idents.push(StructField {
- member,
- ty,
- skipped: false,
- });
- }
- field_idents
-}
-
-fn define_uses_fd_for_tuples(fields: &[StructField]) -> TokenStream {
- if fields.is_empty() {
- return quote!();
- }
-
- let field_types = fields.iter().map(|f| &f.ty);
- quote! {
- fn uses_descriptor() -> bool {
- #(<#field_types>::uses_descriptor())||*
- }
- }
-}
-
-fn define_read_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
- let mut read_fields = Vec::new();
- let mut init_fields = Vec::new();
- for (idx, field) in fields.iter().enumerate() {
- let tmp_name = format!("tuple_tmp{}", idx);
- let tmp_name = Ident::new(&tmp_name, Span::call_site());
- let read_field = read_from_buffer_and_move_offset(&tmp_name, &field.ty);
- read_fields.push(read_field);
- init_fields.push(quote!(#tmp_name));
- }
-
- quote! {
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- #(#read_fields)*
- Ok((
- #name (
- #(#init_fields),*
- ),
- __fd_offset
- ))
- }
- }
-}
-
-fn define_write_buffer_for_tuples(name: &Ident, fields: &[StructField]) -> TokenStream {
- let mut write_fields = Vec::new();
- let mut tmp_names = Vec::new();
- for idx in 0..fields.len() {
- let tmp_name = format_ident!("tuple_tmp{}", idx);
- let write_field = enum_write_to_buffer_and_move_offset(&tmp_name);
- write_fields.push(write_field);
- tmp_names.push(tmp_name);
- }
- quote! {
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- let #name( #(#tmp_names),* ) = self;
- #(#write_fields)*
- Ok(__fd_offset)
- }
- }
-}
-/************************** Helpers ********************************************/
-fn get_fields_buffer_size_sum(fields: &[StructField]) -> (TokenStream, TokenStream) {
- let fields: Vec<_> = fields
- .iter()
- .filter(|f| !f.skipped)
- .map(|f| &f.member)
- .collect();
- if !fields.is_empty() {
- (
- quote! {
- #( self.#fields.msg_size() as usize )+*
- },
- quote! {
- #( self.#fields.descriptor_count() as usize )+*
- },
- )
- } else {
- (quote!(0), quote!(0))
- }
-}
-
-fn read_from_buffer_and_move_offset(name: &Ident, ty: &Type) -> TokenStream {
- quote! {
- let t = <#ty>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let #name = t.0;
- }
-}
-
-fn write_to_buffer_and_move_offset(name: &Ident) -> TokenStream {
- quote! {
- let o = self.#name.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += self.#name.msg_size();
- __fd_offset += o;
- }
-}
-
-#[cfg(test)]
-mod tests {
- use crate::socket_msg_impl;
- use quote::quote;
- use syn::{parse_quote, DeriveInput};
-
- #[test]
- fn end_to_end_struct_test() {
- let input: DeriveInput = parse_quote! {
- struct MyMsg {
- a: u8,
- b: RawDescriptor,
- c: u32,
- }
- };
-
- let expected = quote! {
- impl msg_socket::MsgOnSocket for MyMsg {
- fn uses_descriptor() -> bool {
- <u8>::uses_descriptor()
- || <RawDescriptor>::uses_descriptor()
- || <u32>::uses_descriptor()
- }
- fn msg_size(&self) -> usize {
- self.a.msg_size() as usize
- + self.b.msg_size() as usize
- + self.c.msg_size() as usize
- }
- fn descriptor_count(&self) -> usize {
- self.a.descriptor_count() as usize
- + self.b.descriptor_count() as usize
- + self.c.descriptor_count() as usize
- }
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let a = t.0;
- let t = <RawDescriptor>::read_from_buffer(
- &buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let b = t.0;
- let t = <u32>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let c = t.0;
- Ok((Self { a, b, c }, __fd_offset))
- }
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- let o = self
- .a
- .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += self.a.msg_size();
- __fd_offset += o;
- let o = self
- .b
- .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += self.b.msg_size();
- __fd_offset += o;
- let o = self
- .c
- .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += self.c.msg_size();
- __fd_offset += o;
- Ok(__fd_offset)
- }
- }
-
- };
-
- assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
- }
-
- #[test]
- fn end_to_end_tuple_struct_test() {
- let input: DeriveInput = parse_quote! {
- struct MyMsg(u8, u32, File);
- };
-
- let expected = quote! {
- impl msg_socket::MsgOnSocket for MyMsg {
- fn uses_descriptor() -> bool {
- <u8>::uses_descriptor() || <u32>::uses_descriptor() || <File>::uses_descriptor()
- }
- fn msg_size(&self) -> usize {
- self.0.msg_size() as usize
- + self.1.msg_size() as usize + self.2.msg_size() as usize
- }
- fn descriptor_count(&self) -> usize {
- self.0.descriptor_count() as usize
- + self.1.descriptor_count() as usize
- + self.2.descriptor_count() as usize
- }
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let tuple_tmp0 = t.0;
- let t = <u32>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let tuple_tmp1 = t.0;
- let t = <File>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let tuple_tmp2 = t.0;
- Ok((MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2), __fd_offset))
- }
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- let MyMsg(tuple_tmp0, tuple_tmp1, tuple_tmp2) = self;
- let o = tuple_tmp0.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += tuple_tmp0.msg_size();
- __fd_offset += o;
- let o = tuple_tmp1.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += tuple_tmp1.msg_size();
- __fd_offset += o;
- let o = tuple_tmp2.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += tuple_tmp2.msg_size();
- __fd_offset += o;
- Ok(__fd_offset)
- }
- }
- };
-
- assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
- }
-
- #[test]
- fn end_to_end_enum_test() {
- let input: DeriveInput = parse_quote! {
- enum MyMsg {
- A(u8),
- B,
- C {
- f0: u8,
- f1: RawDescriptor,
- },
- }
- };
-
- let expected = quote! {
- impl msg_socket::MsgOnSocket for MyMsg {
- fn uses_descriptor() -> bool {
- <u8>::uses_descriptor()
- || <u8>::uses_descriptor()
- || <RawDescriptor>::uses_descriptor()
- }
- fn msg_size(&self) -> usize {
- 1 + match self {
- MyMsg::A(enum_field0) => enum_field0.msg_size(),
- MyMsg::B => 0,
- MyMsg::C { f0, f1 } => f0.msg_size() + f1.msg_size(),
- }
- }
- fn descriptor_count(&self) -> usize {
- match self {
- MyMsg::A(enum_field0) => enum_field0.descriptor_count(),
- MyMsg::B => 0,
- MyMsg::C { f0, f1 } => f0.descriptor_count() + f1.descriptor_count(),
- }
- }
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let v = buffer
- .get(0)
- .ok_or(msg_socket::MsgError::WrongMsgBufferSize)?;
- match v {
- 0u8 => {
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let enum_field0 = t.0;
- Ok((MyMsg::A(enum_field0), __fd_offset))
- }
- 1u8 => Ok((MyMsg::B, 0)),
- 2u8 => {
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- let t = <u8>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let f0 = t.0;
- let t = <RawDescriptor>::read_from_buffer(&buffer[__offset..], &fds[__fd_offset..])?;
- __offset += t.0.msg_size();
- __fd_offset += t.1;
- let f1 = t.0;
- Ok((MyMsg::C { f0, f1 }, __fd_offset))
- }
- _ => Err(msg_socket::MsgError::InvalidType),
- }
- }
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- if buffer.is_empty() {
- return Err(msg_socket::MsgError::WrongMsgBufferSize)
- }
- match self {
- MyMsg::A(enum_field0) => {
- buffer[0] = 0u8;
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- let o = enum_field0
- .write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += enum_field0.msg_size();
- __fd_offset += o;
- Ok(__fd_offset)
- }
- MyMsg::B => {
- buffer[0] = 1u8;
- Ok(0)
- }
- MyMsg::C { f0, f1 } => {
- buffer[0] = 2u8;
- let mut __offset = 1usize;
- let mut __fd_offset = 0usize;
- let o = f0.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += f0.msg_size();
- __fd_offset += o;
- let o = f1.write_to_buffer(&mut buffer[__offset..], &mut fds[__fd_offset..])?;
- __offset += f1.msg_size();
- __fd_offset += o;
- Ok(__fd_offset)
- }
- }
- }
- }
- };
-
- assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
- }
-
- #[test]
- fn end_to_end_struct_skip_test() {
- let input: DeriveInput = parse_quote! {
- struct MyMsg {
- #[msg_on_socket(skip)]
- a: u8,
- }
- };
-
- let expected = quote! {
- impl msg_socket::MsgOnSocket for MyMsg {
- fn msg_size(&self) -> usize {
- 0
- }
- fn descriptor_count(&self) -> usize {
- 0
- }
- unsafe fn read_from_buffer(
- buffer: &[u8],
- fds: &[RawDescriptor],
- ) -> msg_socket::MsgResult<(Self, usize)> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- Ok((Self { a: <u8>::default() }, __fd_offset))
- }
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> msg_socket::MsgResult<usize> {
- let mut __offset = 0usize;
- let mut __fd_offset = 0usize;
- Ok(__fd_offset)
- }
- }
-
- };
-
- assert_eq!(socket_msg_impl(input).to_string(), expected.to_string());
- }
-}
diff --git a/msg_socket/src/lib.rs b/msg_socket/src/lib.rs
deleted file mode 100644
index d6f90773b..000000000
--- a/msg_socket/src/lib.rs
+++ /dev/null
@@ -1,232 +0,0 @@
-// Copyright 2018 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-mod msg_on_socket;
-mod serializable_descriptors;
-
-use std::io::{IoSlice, Result};
-use std::marker::PhantomData;
-
-use base::{
- handle_eintr, net::UnixSeqpacket, AsRawDescriptor, Error as SysError, RawDescriptor, ScmSocket,
- UnsyncMarker,
-};
-use cros_async::{Executor, IoSourceExt};
-
-pub use crate::msg_on_socket::*;
-pub use msg_on_socket_derive::*;
-
-/// Create a pair of socket. Request is send in one direction while response is in the other
-/// direction.
-pub fn pair<Request: MsgOnSocket, Response: MsgOnSocket>(
-) -> Result<(MsgSocket<Request, Response>, MsgSocket<Response, Request>)> {
- let (sock1, sock2) = UnixSeqpacket::pair()?;
- let requester = MsgSocket::new(sock1);
- let responder = MsgSocket::new(sock2);
- Ok((requester, responder))
-}
-
-/// Bidirection sock that support both send and recv.
-pub struct MsgSocket<I: MsgOnSocket, O: MsgOnSocket> {
- sock: UnixSeqpacket,
- _i: PhantomData<I>,
- _o: PhantomData<O>,
- _unsync_marker: UnsyncMarker,
-}
-
-impl<I: MsgOnSocket, O: MsgOnSocket> MsgSocket<I, O> {
- // Create a new MsgSocket.
- pub fn new(s: UnixSeqpacket) -> MsgSocket<I, O> {
- MsgSocket {
- sock: s,
- _i: PhantomData,
- _o: PhantomData,
- _unsync_marker: PhantomData,
- }
- }
-
- // Creates an async receiver that implements `futures::Stream`.
- pub fn async_receiver(&self, ex: &Executor) -> MsgResult<AsyncReceiver<I, O>> {
- AsyncReceiver::new(self, ex)
- }
-}
-
-/// One direction socket that only supports sending.
-pub struct Sender<M: MsgOnSocket> {
- sock: UnixSeqpacket,
- _m: PhantomData<M>,
-}
-
-impl<M: MsgOnSocket> Sender<M> {
- /// Create a new sender sock.
- pub fn new(s: UnixSeqpacket) -> Sender<M> {
- Sender {
- sock: s,
- _m: PhantomData,
- }
- }
-}
-
-/// One direction socket that only supports receiving.
-pub struct Receiver<M: MsgOnSocket> {
- sock: UnixSeqpacket,
- _m: PhantomData<M>,
-}
-
-impl<M: MsgOnSocket> Receiver<M> {
- /// Create a new receiver sock.
- pub fn new(s: UnixSeqpacket) -> Receiver<M> {
- Receiver {
- sock: s,
- _m: PhantomData,
- }
- }
-}
-
-impl<I: MsgOnSocket, O: MsgOnSocket> AsRef<UnixSeqpacket> for MsgSocket<I, O> {
- fn as_ref(&self) -> &UnixSeqpacket {
- &self.sock
- }
-}
-
-impl<I: MsgOnSocket, O: MsgOnSocket> AsRawDescriptor for MsgSocket<I, O> {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.sock.as_raw_descriptor()
- }
-}
-
-impl<I: MsgOnSocket, O: MsgOnSocket> AsRawDescriptor for &MsgSocket<I, O> {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.sock.as_raw_descriptor()
- }
-}
-
-impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Sender<M> {
- fn as_ref(&self) -> &UnixSeqpacket {
- &self.sock
- }
-}
-
-impl<M: MsgOnSocket> AsRawDescriptor for Sender<M> {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.sock.as_raw_descriptor()
- }
-}
-
-impl<M: MsgOnSocket> AsRef<UnixSeqpacket> for Receiver<M> {
- fn as_ref(&self) -> &UnixSeqpacket {
- &self.sock
- }
-}
-
-impl<M: MsgOnSocket> AsRawDescriptor for Receiver<M> {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.sock.as_raw_descriptor()
- }
-}
-
-/// Types that could send a message.
-pub trait MsgSender: AsRef<UnixSeqpacket> {
- type M: MsgOnSocket;
- fn send(&self, msg: &Self::M) -> MsgResult<()> {
- let msg_size = msg.msg_size();
- let descriptor_size = msg.descriptor_count();
- let mut msg_buffer: Vec<u8> = vec![0; msg_size];
- let mut descriptor_buffer: Vec<RawDescriptor> = vec![0; descriptor_size];
-
- let descriptor_size = msg.write_to_buffer(&mut msg_buffer, &mut descriptor_buffer)?;
- let sock: &UnixSeqpacket = self.as_ref();
- if descriptor_size == 0 {
- handle_eintr!(sock.send(&msg_buffer))
- .map_err(|e| MsgError::Send(SysError::new(e.raw_os_error().unwrap_or(0))))?;
- } else {
- let ioslice = IoSlice::new(&msg_buffer[..]);
- sock.send_with_fds(&[ioslice], &descriptor_buffer[0..descriptor_size])
- .map_err(MsgError::Send)?;
- }
- Ok(())
- }
-}
-
-/// Types that could receive a message.
-pub trait MsgReceiver: AsRef<UnixSeqpacket> {
- type M: MsgOnSocket;
- fn recv(&self) -> MsgResult<Self::M> {
- let sock: &UnixSeqpacket = self.as_ref();
-
- let (msg_buffer, descriptor_buffer) = {
- if Self::M::uses_descriptor() {
- sock.recv_as_vec_with_fds()
- .map_err(|e| MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0))))?
- } else {
- (
- sock.recv_as_vec().map_err(|e| {
- MsgError::Recv(SysError::new(e.raw_os_error().unwrap_or(0)))
- })?,
- vec![],
- )
- }
- };
-
- if msg_buffer.is_empty() && Self::M::fixed_size() != Some(0) {
- return Err(MsgError::RecvZero);
- }
-
- if let Some(fixed_size) = Self::M::fixed_size() {
- if fixed_size != msg_buffer.len() {
- return Err(MsgError::BadRecvSize {
- expected: fixed_size,
- actual: msg_buffer.len(),
- });
- }
- }
-
- // Safe because fd buffer is read from socket.
- let (v, read_descriptor_size) =
- unsafe { Self::M::read_from_buffer(&msg_buffer, &descriptor_buffer)? };
- if descriptor_buffer.len() != read_descriptor_size {
- return Err(MsgError::NotExpectDescriptor);
- }
- Ok(v)
- }
-}
-
-impl<I: MsgOnSocket, O: MsgOnSocket> MsgSender for MsgSocket<I, O> {
- type M = I;
-}
-impl<I: MsgOnSocket, O: MsgOnSocket> MsgReceiver for MsgSocket<I, O> {
- type M = O;
-}
-
-impl<I: MsgOnSocket> MsgSender for Sender<I> {
- type M = I;
-}
-impl<O: MsgOnSocket> MsgReceiver for Receiver<O> {
- type M = O;
-}
-
-/// Asynchronous adaptor for `MsgSocket`.
-pub struct AsyncReceiver<'m, I: MsgOnSocket, O: MsgOnSocket> {
- // This weirdness is because we can't directly implement IntoAsync for &MsgSocket because there
- // is no AsRawFd impl for references.
- inner: &'m MsgSocket<I, O>,
- sock: Box<dyn IoSourceExt<&'m UnixSeqpacket> + 'm>,
-}
-
-impl<'m, I: MsgOnSocket, O: MsgOnSocket> AsyncReceiver<'m, I, O> {
- fn new(msg_socket: &'m MsgSocket<I, O>, ex: &Executor) -> MsgResult<Self> {
- let sock = ex
- .async_from(&msg_socket.sock)
- .map_err(MsgError::CreateAsync)?;
- Ok(AsyncReceiver {
- inner: msg_socket,
- sock,
- })
- }
-
- pub async fn next(&mut self) -> MsgResult<O> {
- self.sock.wait_readable().await.unwrap();
- self.inner.recv()
- }
-}
diff --git a/msg_socket/src/msg_on_socket.rs b/msg_socket/src/msg_on_socket.rs
deleted file mode 100644
index 5db521772..000000000
--- a/msg_socket/src/msg_on_socket.rs
+++ /dev/null
@@ -1,438 +0,0 @@
-// Copyright 2018 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-mod slice;
-mod tuple;
-
-use std::fmt::{self, Display};
-use std::mem::{size_of, transmute_copy, MaybeUninit};
-use std::result;
-use std::sync::Arc;
-
-use base::{Error as SysError, RawDescriptor};
-use data_model::*;
-use slice::{slice_read_helper, slice_write_helper};
-
-#[derive(Debug)]
-/// An error during transaction or serialization/deserialization.
-pub enum MsgError {
- /// Error while creating an async socket.
- CreateAsync(cros_async::AsyncError),
- /// Error while sending a request or response.
- Send(SysError),
- /// Error while receiving a request or response.
- Recv(SysError),
- /// The type of a received request or response is unknown.
- InvalidType,
- /// There was not the expected amount of data when receiving a message. The inner
- /// value is how much data is expected and how much data was actually received.
- BadRecvSize { expected: usize, actual: usize },
- /// There was no data received when the socket `recv`-ed.
- RecvZero,
- /// There was no associated file descriptor received for a request that expected it.
- ExpectDescriptor,
- /// There was some associated file descriptor received but not used when deserialize.
- NotExpectDescriptor,
- /// Failed to set flags on the file descriptor.
- SettingDescriptorFlags(SysError),
- /// Trying to serialize/deserialize, but fd buffer size is too small. This typically happens
- /// when max_fd_count() returns a value that is too small.
- WrongDescriptorBufferSize,
- /// Trying to serialize/deserialize, but msg buffer size is too small. This typically happens
- /// when msg_size() returns a value that is too small.
- WrongMsgBufferSize,
-}
-
-pub type MsgResult<T> = result::Result<T, MsgError>;
-
-impl Display for MsgError {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- use self::MsgError::*;
-
- match self {
- CreateAsync(e) => write!(f, "failed to create an async socket: {}", e),
- Send(e) => write!(f, "failed to send request or response: {}", e),
- Recv(e) => write!(f, "failed to receive request or response: {}", e),
- InvalidType => write!(f, "invalid type"),
- BadRecvSize { expected, actual } => write!(
- f,
- "wrong amount of data received; expected {} bytes; got {} bytes",
- expected, actual
- ),
- RecvZero => write!(f, "received zero data"),
- ExpectDescriptor => write!(f, "missing associated file descriptor for request"),
- NotExpectDescriptor => write!(f, "unexpected file descriptor is unused"),
- SettingDescriptorFlags(e) => {
- write!(f, "failed setting flags on the message descriptor: {}", e)
- }
- WrongDescriptorBufferSize => write!(f, "descriptor buffer size too small"),
- WrongMsgBufferSize => write!(f, "msg buffer size too small"),
- }
- }
-}
-
-/// A msg that could be serialized to and deserialize from array in little endian.
-///
-/// For structs, we always have fixed size of bytes and fixed count of fds.
-/// For enums, the size needed might be different for each variant.
-///
-/// e.g.
-/// ```
-/// use base::RawDescriptor;
-/// enum Message {
-/// VariantA(u8),
-/// VariantB(u32, RawDescriptor),
-/// VariantC,
-/// }
-/// ```
-///
-/// For variant A, we need 1 byte to store its inner value.
-/// For variant B, we need 4 bytes and 1 RawDescriptor to store its inner value.
-/// For variant C, we need 0 bytes to store its inner value.
-/// When we serialize Message to (buffer, fd_buffer), we always use fixed number of bytes in
-/// the buffer. Unused buffer bytes will be padded with zero.
-/// However, for fd_buffer, we could not do the same thing. Otherwise, we are essentially sending
-/// fd 0 through the socket.
-/// Thus, read/write functions always the return correct count of fds in this variant. There will be
-/// no padding in fd_buffer.
-pub trait MsgOnSocket: Sized {
- // `true` if this structure can potentially serialize descriptors.
- fn uses_descriptor() -> bool {
- false
- }
-
- // Returns `Some(size)` if this structure always has a fixed size.
- fn fixed_size() -> Option<usize> {
- None
- }
-
- /// Size of message in bytes.
- fn msg_size(&self) -> usize {
- Self::fixed_size().unwrap()
- }
-
- /// Number of FDs in this message. This must be overridden if `uses_descriptor()` returns true.
- fn descriptor_count(&self) -> usize {
- assert!(!Self::uses_descriptor());
- 0
- }
- /// Returns (self, fd read count).
- /// This function is safe only when:
- /// 0. fds contains valid fds, received from socket, serialized by Self::write_to_buffer.
- /// 1. For enum, fds contains correct fd layout of the particular variant.
- /// 2. write_to_buffer is implemented correctly(put valid fds into the buffer, has no padding,
- /// return correct count).
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)>;
-
- /// Serialize self to buffers.
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize>;
-}
-
-impl MsgOnSocket for SysError {
- fn fixed_size() -> Option<usize> {
- Some(size_of::<u32>())
- }
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- let (v, size) = u32::read_from_buffer(buffer, fds)?;
- Ok((SysError::new(v as i32), size))
- }
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- let v = self.errno() as u32;
- v.write_to_buffer(buffer, fds)
- }
-}
-
-impl<T: MsgOnSocket> MsgOnSocket for Option<T> {
- fn uses_descriptor() -> bool {
- T::uses_descriptor()
- }
-
- fn msg_size(&self) -> usize {
- match self {
- Some(v) => v.msg_size() + 1,
- None => 1,
- }
- }
-
- fn descriptor_count(&self) -> usize {
- match self {
- Some(v) => v.descriptor_count(),
- None => 0,
- }
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- match buffer[0] {
- 0 => Ok((None, 0)),
- 1 => {
- let (inner, len) = T::read_from_buffer(&buffer[1..], fds)?;
- Ok((Some(inner), len))
- }
- _ => Err(MsgError::InvalidType),
- }
- }
-
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- match self {
- None => {
- buffer[0] = 0;
- Ok(0)
- }
- Some(inner) => {
- buffer[0] = 1;
- inner.write_to_buffer(&mut buffer[1..], fds)
- }
- }
- }
-}
-
-impl<T: MsgOnSocket> MsgOnSocket for Arc<T> {
- fn uses_descriptor() -> bool {
- T::uses_descriptor()
- }
-
- fn msg_size(&self) -> usize {
- (**self).msg_size()
- }
-
- fn descriptor_count(&self) -> usize {
- (**self).descriptor_count()
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- T::read_from_buffer(buffer, fds).map(|(v, count)| (Arc::new(v), count))
- }
-
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- (**self).write_to_buffer(buffer, fds)
- }
-}
-
-impl MsgOnSocket for () {
- fn fixed_size() -> Option<usize> {
- Some(0)
- }
-
- unsafe fn read_from_buffer(_buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- Ok(((), 0))
- }
-
- fn write_to_buffer(&self, _buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- Ok(0)
- }
-}
-
-// usize could be different sizes on different targets. We always use u64.
-impl MsgOnSocket for usize {
- fn msg_size(&self) -> usize {
- size_of::<u64>()
- }
- unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- if buffer.len() < size_of::<u64>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- let t = u64::from_le_bytes(slice_to_array(buffer));
- Ok((t as usize, 0))
- }
-
- fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- if buffer.len() < size_of::<u64>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- let t: Le64 = (*self as u64).into();
- buffer[0..self.msg_size()].copy_from_slice(t.as_slice());
- Ok(0)
- }
-}
-
-// Encode bool as a u8 of value 0 or 1
-impl MsgOnSocket for bool {
- fn msg_size(&self) -> usize {
- size_of::<u8>()
- }
- unsafe fn read_from_buffer(buffer: &[u8], _fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- if buffer.len() < size_of::<u8>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- let t: u8 = buffer[0];
- match t {
- 0 => Ok((false, 0)),
- 1 => Ok((true, 0)),
- _ => Err(MsgError::InvalidType),
- }
- }
- fn write_to_buffer(&self, buffer: &mut [u8], _fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- if buffer.len() < size_of::<u8>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- buffer[0] = *self as u8;
- Ok(0)
- }
-}
-
-macro_rules! le_impl {
- ($type:ident, $native_type:ident) => {
- impl MsgOnSocket for $type {
- fn fixed_size() -> Option<usize> {
- Some(size_of::<$native_type>())
- }
-
- unsafe fn read_from_buffer(
- buffer: &[u8],
- _fds: &[RawDescriptor],
- ) -> MsgResult<(Self, usize)> {
- if buffer.len() < size_of::<$native_type>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- let t = $native_type::from_le_bytes(slice_to_array(buffer));
- Ok((t.into(), 0))
- }
-
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- _fds: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- if buffer.len() < size_of::<$native_type>() {
- return Err(MsgError::WrongMsgBufferSize);
- }
- let t: $native_type = self.clone().into();
- buffer[0..self.msg_size()].copy_from_slice(&t.to_le_bytes());
- Ok(0)
- }
- }
- };
-}
-
-le_impl!(u8, u8);
-le_impl!(u16, u16);
-le_impl!(u32, u32);
-le_impl!(u64, u64);
-
-le_impl!(Le16, u16);
-le_impl!(Le32, u32);
-le_impl!(Le64, u64);
-
-fn simple_read<T: MsgOnSocket>(buffer: &[u8], offset: &mut usize) -> MsgResult<T> {
- assert!(!T::uses_descriptor());
- // Safety for T::read_from_buffer depends on the given FDs being valid, but we pass no FDs.
- let (v, _) = unsafe { T::read_from_buffer(&buffer[*offset..], &[])? };
- *offset += v.msg_size();
- Ok(v)
-}
-
-fn simple_write<T: MsgOnSocket>(v: T, buffer: &mut [u8], offset: &mut usize) -> MsgResult<()> {
- assert!(!T::uses_descriptor());
- v.write_to_buffer(&mut buffer[*offset..], &mut [])?;
- *offset += v.msg_size();
- Ok(())
-}
-
-// Converts a slice into an array of fixed size inferred from the return value. Panics if the
-// slice is too small, but will tolerate slices that are too large.
-fn slice_to_array<T, O>(s: &[T]) -> O
-where
- T: Copy,
- O: Default + AsMut<[T]>,
-{
- let mut o = O::default();
- let o_slice = o.as_mut();
- let len = o_slice.len();
- o_slice.copy_from_slice(&s[..len]);
- o
-}
-
-macro_rules! array_impls {
- ($N:expr, $t: ident $($ts:ident)*)
- => {
- impl<T: MsgOnSocket + Clone> MsgOnSocket for [T; $N] {
- fn uses_descriptor() -> bool {
- T::uses_descriptor()
- }
-
- fn fixed_size() -> Option<usize> {
- Some(T::fixed_size()? * $N)
- }
-
- fn msg_size(&self) -> usize {
- match T::fixed_size() {
- Some(s) => s * $N,
- None => self.iter().map(|i| i.msg_size()).sum::<usize>() + size_of::<u64>() * $N
- }
- }
-
- fn descriptor_count(&self) -> usize {
- if T::uses_descriptor() {
- self.iter().map(|i| i.descriptor_count()).sum()
- } else {
- 0
- }
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor])
- -> MsgResult<(Self, usize)> {
- // Taken from the canonical example of initializing an array, the `assume_init` can
- // be assumed safe because the array elements (`MaybeUninit<T>` in this case)
- // themselves don't require initializing.
- let mut msgs: [MaybeUninit<T>; $N] = MaybeUninit::uninit().assume_init();
-
- let fd_count = slice_read_helper(buffer, fds, &mut msgs)?;
-
- // Also taken from the canonical example, we initialized every member of the array
- // in the first loop of this function, so it is safe to `transmute_copy` the array
- // of `MaybeUninit` data to plain data. Although `transmute`, which checks the
- // types' sizes, would have been preferred in this code, the compiler complains with
- // "cannot transmute between types of different sizes, or dependently-sized types."
- // Because this function operates on generic data, the type is "dependently-sized"
- // and so the compiler will not check that the size of the input and output match.
- // See this issue for details: https://github.com/rust-lang/rust/issues/61956
- Ok((transmute_copy::<_, [T; $N]>(&msgs), fd_count))
- }
-
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- slice_write_helper(self, buffer, fds)
- }
- }
- #[cfg(test)]
- mod $t {
- use super::MsgOnSocket;
-
- #[test]
- fn read_write_option_array() {
- type ArrayType = [Option<u32>; $N];
- let array = [Some($N); $N];
- let mut buffer = vec![0; array.msg_size()];
- array.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
-
- assert!(array.iter().eq(read_array.iter()));
- }
-
- #[test]
- fn read_write_fixed() {
- type ArrayType = [u32; $N];
- let mut buffer = vec![0; <ArrayType>::fixed_size().unwrap()];
- let array = [$N as u32; $N];
- array.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_array = unsafe { ArrayType::read_from_buffer(&buffer, &[]) }.unwrap().0;
-
- assert!(array.iter().eq(read_array.iter()));
- }
- }
- array_impls!(($N - 1), $($ts)*);
- };
- {$N:expr, } => {};
-}
-
-array_impls! {
- 64, tmp1 tmp2 tmp3 tmp4 tmp5 tmp6 tmp7 tmp8 tmp9 tmp10 tmp11 tmp12 tmp13 tmp14 tmp15 tmp16
- tmp17 tmp18 tmp19 tmp20 tmp21 tmp22 tmp23 tmp24 tmp25 tmp26 tmp27 tmp28 tmp29 tmp30 tmp31
- tmp32 tmp33 tmp34 tmp35 tmp36 tmp37 tmp38 tmp39 tmp40 tmp41 tmp42 tmp43 tmp44 tmp45 tmp46
- tmp47 tmp48 tmp49 tmp50 tmp51 tmp52 tmp53 tmp54 tmp55 tmp56 tmp57 tmp58 tmp59 tmp60 tmp61
- tmp62 tmp63 tmp64
-}
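For reference, a minimal sketch of what a hand-written impl of the deleted MsgOnSocket trait could look like for a hypothetical fixed-size struct; the `Pair` type and its field layout are illustrative only, and in practice the crate generated such impls via #[derive(MsgOnSocket)], as in the deleted tests further down this diff:

    struct Pair {
        a: u32,
        b: u16,
    }

    impl MsgOnSocket for Pair {
        fn fixed_size() -> Option<usize> {
            // Both fields are fixed size, so the whole struct is too (4 + 2 bytes).
            Some(size_of::<u32>() + size_of::<u16>())
        }

        unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
            // Neither field carries descriptors, so the fd read count stays 0.
            let (a, _) = u32::read_from_buffer(buffer, fds)?;
            let (b, _) = u16::read_from_buffer(&buffer[size_of::<u32>()..], fds)?;
            Ok((Pair { a, b }, 0))
        }

        fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> {
            self.a.write_to_buffer(buffer, fds)?;
            self.b.write_to_buffer(&mut buffer[size_of::<u32>()..], fds)?;
            Ok(0)
        }
    }
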
diff --git a/msg_socket/src/msg_on_socket/slice.rs b/msg_socket/src/msg_on_socket/slice.rs
deleted file mode 100644
index 05df5e142..000000000
--- a/msg_socket/src/msg_on_socket/slice.rs
+++ /dev/null
@@ -1,184 +0,0 @@
-// Copyright 2020 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use base::RawDescriptor;
-use std::mem::{size_of, ManuallyDrop, MaybeUninit};
-use std::ptr::drop_in_place;
-
-use crate::{MsgOnSocket, MsgResult};
-
-use super::{simple_read, simple_write};
-
-/// Helper used by the types that read a slice of homogeneously typed data.
-///
-/// # Safety
-/// This function has the same safety requirements as `T::read_from_buffer`, with the additional
-/// requirement that the `msgs` are only used on success of this function.
-pub unsafe fn slice_read_helper<T: MsgOnSocket>(
- buffer: &[u8],
- fds: &[RawDescriptor],
- msgs: &mut [MaybeUninit<T>],
-) -> MsgResult<usize> {
- let mut offset = 0usize;
- let mut fd_offset = 0usize;
-
- // In case of an error, we need to keep track of how many elements got initialized.
- // In order to perform the necessary drops, the below loop is executed in a closure
- // to capture errors without returning.
- let mut last_index = 0;
- let res = (|| {
- for msg in &mut msgs[..] {
- let element_size = match T::fixed_size() {
- Some(s) => s,
- None => simple_read::<u64>(buffer, &mut offset)? as usize,
- };
- // Assuming the unsafe caller gave valid FDs, this call should be safe.
- let (m, fd_size) = T::read_from_buffer(&buffer[offset..], &fds[fd_offset..])?;
- *msg = MaybeUninit::new(m);
- offset += element_size;
- fd_offset += fd_size;
- last_index += 1;
- }
- Ok(())
- })();
-
- // Because `MaybeUninit` will not automatically call drops, we have to drop the
- // partially initialized array manually in the case of an error.
- if let Err(e) = res {
- for msg in &mut msgs[..last_index] {
- // The call to `as_mut_ptr()` turns the `MaybeUninit` element of the array
- // into a pointer, which can be used with `drop_in_place` to call the
- // destructor without moving the element, which is impossible. This is safe
- // because `last_index` prevents this loop from traversing into the
- // uninitialized parts of the array.
- drop_in_place(msg.as_mut_ptr());
- }
- return Err(e);
- }
-
- Ok(fd_offset)
-}
-
-/// Helper used by the types that write a slice of homogeneously typed data.
-pub fn slice_write_helper<T: MsgOnSocket>(
- msgs: &[T],
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
-) -> MsgResult<usize> {
- let mut offset = 0usize;
- let mut fd_offset = 0usize;
- for msg in msgs {
- let element_size = match T::fixed_size() {
- Some(s) => s,
- None => {
- let element_size = msg.msg_size();
- simple_write(element_size as u64, buffer, &mut offset)?;
- element_size as usize
- }
- };
- let fd_size = msg.write_to_buffer(&mut buffer[offset..], &mut fds[fd_offset..])?;
- offset += element_size;
- fd_offset += fd_size;
- }
-
- Ok(fd_offset)
-}
-
-impl<T: MsgOnSocket> MsgOnSocket for Vec<T> {
- fn uses_descriptor() -> bool {
- T::uses_descriptor()
- }
-
- fn fixed_size() -> Option<usize> {
- None
- }
-
- fn msg_size(&self) -> usize {
- let vec_size = match T::fixed_size() {
- Some(s) => s * self.len(),
- None => self.iter().map(|i| i.msg_size() + size_of::<u64>()).sum(),
- };
- size_of::<u64>() + vec_size
- }
-
- fn descriptor_count(&self) -> usize {
- if T::uses_descriptor() {
- self.iter().map(|i| i.descriptor_count()).sum()
- } else {
- 0
- }
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- let mut offset = 0;
- let len = simple_read::<u64>(buffer, &mut offset)? as usize;
- let mut msgs: Vec<MaybeUninit<T>> = Vec::with_capacity(len);
- msgs.set_len(len);
- let fd_count = slice_read_helper(&buffer[offset..], fds, &mut msgs)?;
- let mut msgs = ManuallyDrop::new(msgs);
- Ok((
- Vec::from_raw_parts(msgs.as_mut_ptr() as *mut T, msgs.len(), msgs.capacity()),
- fd_count,
- ))
- }
-
- fn write_to_buffer(&self, buffer: &mut [u8], fds: &mut [RawDescriptor]) -> MsgResult<usize> {
- let mut offset = 0;
- simple_write(self.len() as u64, buffer, &mut offset)?;
- slice_write_helper(self, &mut buffer[offset..], fds)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn read_write_1_fixed() {
- let vec = vec![1u32];
- let mut buffer = vec![0; vec.msg_size()];
- vec.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_vec = unsafe { <Vec<u32>>::read_from_buffer(&buffer, &[]) }
- .unwrap()
- .0;
-
- assert_eq!(vec, read_vec);
- }
-
- #[test]
- fn read_write_8_fixed() {
- let vec = vec![1u16, 1, 3, 5, 8, 13, 21, 34];
- let mut buffer = vec![0; vec.msg_size()];
- vec.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_vec = unsafe { <Vec<u16>>::read_from_buffer(&buffer, &[]) }
- .unwrap()
- .0;
- assert_eq!(vec, read_vec);
- }
-
- #[test]
- fn read_write_1() {
- let vec = vec![Some(1u64)];
- let mut buffer = vec![0; vec.msg_size()];
- println!("{:?}", vec.msg_size());
- vec.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
- .unwrap()
- .0;
-
- assert_eq!(vec, read_vec);
- }
-
- #[test]
- fn read_write_4() {
- let vec = vec![Some(12u16), Some(0), None, None];
- let mut buffer = vec![0; vec.msg_size()];
- vec.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_vec = unsafe { <Vec<_>>::read_from_buffer(&buffer, &[]) }
- .unwrap()
- .0;
-
- assert_eq!(vec, read_vec);
- }
-}
diff --git a/msg_socket/src/msg_on_socket/tuple.rs b/msg_socket/src/msg_on_socket/tuple.rs
deleted file mode 100644
index 90784bfef..000000000
--- a/msg_socket/src/msg_on_socket/tuple.rs
+++ /dev/null
@@ -1,205 +0,0 @@
-// Copyright 2020 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use base::RawDescriptor;
-use std::mem::size_of;
-
-use crate::{MsgOnSocket, MsgResult};
-
-use super::{simple_read, simple_write};
-
-// Returns the size of one part of a tuple.
-fn tuple_size_helper<T: MsgOnSocket>(v: &T) -> usize {
- T::fixed_size().unwrap_or_else(|| v.msg_size() + size_of::<u64>())
-}
-
-unsafe fn tuple_read_helper<T: MsgOnSocket>(
- buffer: &[u8],
- fds: &[RawDescriptor],
- buffer_index: &mut usize,
- fd_index: &mut usize,
-) -> MsgResult<T> {
- let end = match T::fixed_size() {
- Some(_) => buffer.len(),
- None => {
- let len = simple_read::<u64>(buffer, buffer_index)? as usize;
- *buffer_index + len
- }
- };
- let (v, fd_read) = T::read_from_buffer(&buffer[*buffer_index..end], &fds[*fd_index..])?;
- *buffer_index += v.msg_size();
- *fd_index += fd_read;
- Ok(v)
-}
-
-fn tuple_write_helper<T: MsgOnSocket>(
- v: &T,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- buffer_index: &mut usize,
- fd_index: &mut usize,
-) -> MsgResult<()> {
- let end = match T::fixed_size() {
- Some(_) => buffer.len(),
- None => {
- let len = v.msg_size();
- simple_write(len as u64, buffer, buffer_index)?;
- *buffer_index + len
- }
- };
- let fd_written = v.write_to_buffer(&mut buffer[*buffer_index..end], &mut fds[*fd_index..])?;
- *buffer_index += v.msg_size();
- *fd_index += fd_written;
- Ok(())
-}
-
-macro_rules! tuple_impls {
- () => {};
- ($t: ident) => {
- #[allow(unused_variables, non_snake_case)]
- impl<$t: MsgOnSocket> MsgOnSocket for ($t,) {
- fn uses_descriptor() -> bool {
- $t::uses_descriptor()
- }
-
- fn descriptor_count(&self) -> usize {
- self.0.descriptor_count()
- }
-
- fn fixed_size() -> Option<usize> {
- $t::fixed_size()
- }
-
- fn msg_size(&self) -> usize {
- self.0.msg_size()
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- let (t, s) = $t::read_from_buffer(buffer, fds)?;
- Ok(((t,), s))
- }
-
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- self.0.write_to_buffer(buffer, fds)
- }
- }
- };
- ($t: ident, $($ts:ident),*) => {
- #[allow(unused_variables, non_snake_case)]
- impl<$t: MsgOnSocket $(, $ts: MsgOnSocket)*> MsgOnSocket for ($t$(, $ts)*) {
- fn uses_descriptor() -> bool {
- $t::uses_descriptor() $(|| $ts::uses_descriptor())*
- }
-
- fn descriptor_count(&self) -> usize {
- if Self::uses_descriptor() {
- return 0;
- }
- let ($t $(,$ts)*) = self;
- $t.descriptor_count() $(+ $ts.descriptor_count())*
- }
-
- fn fixed_size() -> Option<usize> {
- // Returns None if any element is not fixed size.
- Some($t::fixed_size()? $(+ $ts::fixed_size()?)*)
- }
-
- fn msg_size(&self) -> usize {
- if let Some(size) = Self::fixed_size() {
- return size
- }
-
- let ($t $(,$ts)*) = self;
- tuple_size_helper($t) $(+ tuple_size_helper($ts))*
- }
-
- unsafe fn read_from_buffer(buffer: &[u8], fds: &[RawDescriptor]) -> MsgResult<(Self, usize)> {
- let mut buffer_index = 0;
- let mut fd_index = 0;
- Ok((
- (
- tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?,
- $({
- // Dummy let used to trigger the correct number of iterations.
- let $ts = ();
- tuple_read_helper(buffer, fds, &mut buffer_index, &mut fd_index)?
- },)*
- ),
- fd_index
- ))
- }
-
- fn write_to_buffer(
- &self,
- buffer: &mut [u8],
- fds: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- let mut buffer_index = 0;
- let mut fd_index = 0;
- let ($t $(,$ts)*) = self;
- tuple_write_helper($t, buffer, fds, &mut buffer_index, &mut fd_index)?;
- $(
- tuple_write_helper($ts, buffer, fds, &mut buffer_index, &mut fd_index)?;
- )*
- Ok(fd_index)
- }
- }
- tuple_impls!{ $($ts),* }
- }
-}
-
-// Implement tuples for up to 8 elements.
-tuple_impls! { A, B, C, D, E, F, G, H }
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn read_write_1_fixed() {
- let tuple = (1,);
- let mut buffer = vec![0; tuple.msg_size()];
- tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_tuple = unsafe { <(u32,)>::read_from_buffer(&buffer, &[]) }
- .unwrap()
- .0;
-
- assert_eq!(tuple, read_tuple);
- }
-
- #[test]
- fn read_write_8_fixed() {
- let tuple = (1u32, 2u8, 3u16, 4u64, 5u32, 6u16, 7u8, 8u8);
- let mut buffer = vec![0; tuple.msg_size()];
- tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
-
- assert_eq!(tuple, read_tuple);
- }
-
- #[test]
- fn read_write_1() {
- let tuple = (Some(1u64),);
- let mut buffer = vec![0; tuple.msg_size()];
- tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
-
- assert_eq!(tuple, read_tuple);
- }
-
- #[test]
- fn read_write_4() {
- let tuple = (Some(12u16), Some(false), None::<u8>, None::<u64>);
- let mut buffer = vec![0; tuple.msg_size()];
- println!("{:?}", tuple.msg_size());
- tuple.write_to_buffer(&mut buffer, &mut []).unwrap();
- let read_tuple = unsafe { <_>::read_from_buffer(&buffer, &[]) }.unwrap().0;
-
- assert_eq!(tuple, read_tuple);
- }
-}
diff --git a/msg_socket/src/serializable_descriptors.rs b/msg_socket/src/serializable_descriptors.rs
deleted file mode 100644
index 21ce84b7e..000000000
--- a/msg_socket/src/serializable_descriptors.rs
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2020 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use crate::msg_on_socket::{MsgError, MsgOnSocket, MsgResult};
-use base::{AsRawDescriptor, Event, FromRawDescriptor, RawDescriptor};
-use std::fs::File;
-use std::net::{TcpListener, TcpStream, UdpSocket};
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
-use std::os::unix::net::{UnixDatagram, UnixListener, UnixStream};
-
-macro_rules! rawdescriptor_impl {
- ($type:ident) => {
- impl MsgOnSocket for $type {
- fn uses_descriptor() -> bool {
- true
- }
- fn msg_size(&self) -> usize {
- 0
- }
- fn descriptor_count(&self) -> usize {
- 1
- }
- unsafe fn read_from_buffer(
- _buffer: &[u8],
- descriptors: &[RawDescriptor],
- ) -> MsgResult<(Self, usize)> {
- if descriptors.len() < 1 {
- return Err(MsgError::ExpectDescriptor);
- }
- Ok(($type::from_raw_descriptor(descriptors[0]), 1))
- }
- fn write_to_buffer(
- &self,
- _buffer: &mut [u8],
- descriptors: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- if descriptors.is_empty() {
- return Err(MsgError::WrongDescriptorBufferSize);
- }
- descriptors[0] = self.as_raw_descriptor();
- Ok(1)
- }
- }
- };
-}
-
-rawdescriptor_impl!(Event);
-rawdescriptor_impl!(File);
-
-macro_rules! rawfd_impl {
- ($type:ident) => {
- impl MsgOnSocket for $type {
- fn uses_descriptor() -> bool {
- true
- }
- fn msg_size(&self) -> usize {
- 0
- }
- fn descriptor_count(&self) -> usize {
- 1
- }
- unsafe fn read_from_buffer(_buffer: &[u8], fds: &[RawFd]) -> MsgResult<(Self, usize)> {
- if fds.len() < 1 {
- return Err(MsgError::ExpectDescriptor);
- }
- Ok(($type::from_raw_fd(fds[0]), 1))
- }
- fn write_to_buffer(&self, _buffer: &mut [u8], fds: &mut [RawFd]) -> MsgResult<usize> {
- if fds.is_empty() {
- return Err(MsgError::WrongDescriptorBufferSize);
- }
- fds[0] = self.as_raw_fd();
- Ok(1)
- }
- }
- };
-}
-
-rawfd_impl!(UnixStream);
-rawfd_impl!(TcpStream);
-rawfd_impl!(TcpListener);
-rawfd_impl!(UdpSocket);
-rawfd_impl!(UnixListener);
-rawfd_impl!(UnixDatagram);
diff --git a/msg_socket/tests/enum.rs b/msg_socket/tests/enum.rs
deleted file mode 100644
index 7f5998bdc..000000000
--- a/msg_socket/tests/enum.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2019 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use base::{Event, RawDescriptor};
-
-use msg_socket::*;
-
-#[derive(MsgOnSocket)]
-struct DummyRequest {}
-
-#[derive(MsgOnSocket)]
-enum Response {
- A(u8),
- B,
- C(u32, Event),
- D([u8; 4]),
- E { f0: u8, f1: u32 },
-}
-
-#[test]
-fn sock_send_recv_enum() {
- let (req, res) = pair::<DummyRequest, Response>().unwrap();
- let e0 = Event::new().unwrap();
- let e1 = e0.try_clone().unwrap();
- res.send(&Response::C(0xf0f0, e0)).unwrap();
- let r = req.recv().unwrap();
- match r {
- Response::C(v, efd) => {
- assert_eq!(v, 0xf0f0);
- efd.write(0x0f0f).unwrap();
- }
- _ => panic!("wrong type"),
- };
- assert_eq!(e1.read().unwrap(), 0x0f0f);
-
- res.send(&Response::B).unwrap();
- match req.recv().unwrap() {
- Response::B => {}
- _ => panic!("Wrong enum type"),
- };
-
- res.send(&Response::A(0x3)).unwrap();
- match req.recv().unwrap() {
- Response::A(v) => assert_eq!(v, 0x3),
- _ => panic!("Wrong enum type"),
- };
-
- res.send(&Response::D([0, 1, 2, 3])).unwrap();
- match req.recv().unwrap() {
- Response::D(v) => assert_eq!(v, [0, 1, 2, 3]),
- _ => panic!("Wrong enum type"),
- };
-
- res.send(&Response::E {
- f0: 0x12,
- f1: 0x0f0f,
- })
- .unwrap();
- match req.recv().unwrap() {
- Response::E { f0, f1 } => {
- assert_eq!(f0, 0x12);
- assert_eq!(f1, 0x0f0f);
- }
- _ => panic!("Wrong enum type"),
- };
-}
diff --git a/msg_socket/tests/struct.rs b/msg_socket/tests/struct.rs
deleted file mode 100644
index 8e3c93f4a..000000000
--- a/msg_socket/tests/struct.rs
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2019 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use base::{Event, RawDescriptor};
-
-use msg_socket::*;
-
-#[derive(MsgOnSocket)]
-struct Request {
- field0: u8,
- field1: Event,
- field2: u32,
- field3: bool,
-}
-
-#[derive(MsgOnSocket)]
-struct DummyResponse {}
-
-#[test]
-fn sock_send_recv_struct() {
- let (req, res) = pair::<Request, DummyResponse>().unwrap();
- let e0 = Event::new().unwrap();
- let e1 = e0.try_clone().unwrap();
- req.send(&Request {
- field0: 2,
- field1: e0,
- field2: 0xf0f0,
- field3: true,
- })
- .unwrap();
- let r = res.recv().unwrap();
- assert_eq!(r.field0, 2);
- assert_eq!(r.field2, 0xf0f0);
- assert_eq!(r.field3, true);
- r.field1.write(0x0f0f).unwrap();
- assert_eq!(e1.read().unwrap(), 0x0f0f);
-}
diff --git a/msg_socket/tests/tuple.rs b/msg_socket/tests/tuple.rs
deleted file mode 100644
index ae008135f..000000000
--- a/msg_socket/tests/tuple.rs
+++ /dev/null
@@ -1,22 +0,0 @@
-// Copyright 2019 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use base::{Event, RawDescriptor};
-use msg_socket::*;
-
-#[derive(MsgOnSocket)]
-struct Message(u8, u16, Event);
-
-#[test]
-fn sock_send_recv_tuple() {
- let (req, res) = pair::<Message, Message>().unwrap();
- let e0 = Event::new().unwrap();
- let e1 = e0.try_clone().unwrap();
- req.send(&Message(1, 0x12, e0)).unwrap();
- let r = res.recv().unwrap();
- assert_eq!(r.0, 1);
- assert_eq!(r.1, 0x12);
- r.2.write(0x0f0f).unwrap();
- assert_eq!(e1.read().unwrap(), 0x0f0f);
-}
diff --git a/msg_socket/tests/unit.rs b/msg_socket/tests/unit.rs
deleted file mode 100644
index 9855752f8..000000000
--- a/msg_socket/tests/unit.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-// Copyright 2018 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-use msg_socket::*;
-
-#[test]
-fn sock_send_recv_unit() {
- let (req, res) = pair::<(), ()>().unwrap();
- req.send(&()).unwrap();
- let _ = res.recv().unwrap();
-}
diff --git a/resources/Cargo.toml b/resources/Cargo.toml
index 041da9f3c..ca1d4b882 100644
--- a/resources/Cargo.toml
+++ b/resources/Cargo.toml
@@ -6,5 +6,5 @@ edition = "2018"
[dependencies]
libc = "*"
-msg_socket = { path = "../msg_socket" }
base = { path = "../base" }
+serde = { version = "1", features = ["derive"] }
diff --git a/resources/src/lib.rs b/resources/src/lib.rs
index 0a2518199..0983ffaac 100644
--- a/resources/src/lib.rs
+++ b/resources/src/lib.rs
@@ -4,14 +4,10 @@
//! Manages system resources that can be allocated to VMs and their devices.
-extern crate base;
-extern crate libc;
-extern crate msg_socket;
-
-use base::RawDescriptor;
-use msg_socket::MsgOnSocket;
use std::fmt::Display;
+use serde::{Deserialize, Serialize};
+
pub use crate::address_allocator::AddressAllocator;
pub use crate::system_allocator::{MmioType, SystemAllocator};
@@ -19,7 +15,7 @@ mod address_allocator;
mod system_allocator;
/// Used to tag SystemAllocator allocations.
-#[derive(Debug, Eq, PartialEq, Hash, MsgOnSocket, Copy, Clone)]
+#[derive(Debug, Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)]
pub enum Alloc {
/// An anonymous resource allocation.
/// Should only be instantiated through `SystemAllocator::get_anon_alloc()`.
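With MsgOnSocket gone, types like Alloc are carried by serde derives instead of the hand-rolled byte layout deleted above. A rough sketch of the round trip, using a hypothetical stand-in enum and serde_json purely for illustration (this diff does not show which serde backend the transport actually uses):

    use serde::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    enum ExampleAlloc {
        Anon(usize),
        Mmio(u64),
    }

    fn main() {
        let alloc = ExampleAlloc::Mmio(0x1000);
        // Any serde backend can carry this; serde_json just keeps the example short.
        let bytes = serde_json::to_vec(&alloc).unwrap();
        let back: ExampleAlloc = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(alloc, back);
    }
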
diff --git a/run_tests b/run_tests
index 066b35ae5..95c92b8d0 100755
--- a/run_tests
+++ b/run_tests
@@ -11,10 +11,9 @@ from typing import List, Dict
from ci.test_runner import Requirements, main
# A list of all crates and their requirements
+# See ci/test_runner.py for documentation on the requirements
CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = {
"aarch64": [Requirements.AARCH64],
- "crosvm": [Requirements.DISABLED],
- "aarch64": [Requirements.AARCH64],
"acpi_tables": [],
"arch": [],
"assertions": [],
@@ -22,6 +21,7 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = {
"bit_field": [],
"bit_field_derive": [],
"cros_async": [Requirements.DISABLED],
+ "crosvm": [Requirements.DO_NOT_RUN],
"crosvm_plugin": [Requirements.X86_64],
"data_model": [],
"devices": [
@@ -29,20 +29,24 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = {
Requirements.PRIVILEGED,
Requirements.X86_64,
],
- "disk": [Requirements.DISABLED],
+ "disk": [Requirements.PRIVILEGED],
"enumn": [],
"fuse": [],
"fuzz": [Requirements.DISABLED],
"gpu_display": [],
"hypervisor": [Requirements.PRIVILEGED, Requirements.X86_64],
- "io_uring": [Requirements.DISABLED],
+ "integration_tests": [Requirements.PRIVILEGED, Requirements.X86_64],
+ "io_uring": [
+ Requirements.SEPARATE_WORKSPACE,
+ Requirements.PRIVILEGED,
+ Requirements.SINGLE_THREADED,
+ ],
"kernel_cmdline": [],
"kernel_loader": [Requirements.PRIVILEGED],
"kvm_sys": [Requirements.PRIVILEGED],
"kvm": [Requirements.PRIVILEGED, Requirements.X86_64],
+ "libcrosvm_control": [],
"linux_input_sys": [],
- "msg_socket": [Requirements.PRIVILEGED],
- "msg_on_socket_derive": [],
"net_sys": [],
"net_util": [Requirements.PRIVILEGED],
"power_monitor": [],
@@ -50,11 +54,10 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = {
"qcow_utils": [],
"rand_ish": [],
"resources": [],
- "rutabaga_gfx": [Requirements.CROS_BUILD, Requirements.X86_64],
+ "rutabaga_gfx": [Requirements.CROS_BUILD, Requirements.PRIVILEGED],
"sync": [],
"sys_util": [Requirements.SINGLE_THREADED, Requirements.PRIVILEGED],
"poll_token_derive": [],
- "syscall_defines": [],
"tempfile": [],
"tpm2-sys": [],
"tpm2": [],
@@ -64,7 +67,7 @@ CRATE_REQUIREMENTS: Dict[str, List[Requirements]] = {
"vhost": [Requirements.PRIVILEGED],
"virtio_sys": [],
"vm_control": [],
- "vm_memory": [Requirements.DISABLED],
+ "vm_memory": [Requirements.PRIVILEGED],
"x86_64": [Requirements.X86_64, Requirements.PRIVILEGED],
}
diff --git a/rutabaga_gfx/src/cross_domain/cross_domain.rs b/rutabaga_gfx/src/cross_domain/cross_domain.rs
index 5ab738715..a7b3ac999 100644
--- a/rutabaga_gfx/src/cross_domain/cross_domain.rs
+++ b/rutabaga_gfx/src/cross_domain/cross_domain.rs
@@ -138,6 +138,7 @@ impl RutabagaContext for CrossDomainContext {
&mut self,
resource_id: u32,
resource_create_blob: ResourceCreateBlob,
+ handle: Option<RutabagaHandle>,
) -> RutabagaResult<RutabagaResource> {
let reqs = self
.requirements_blobs
@@ -152,7 +153,11 @@ impl RutabagaContext for CrossDomainContext {
// create blob function, which says "the actual allocation is done via
// VIRTIO_GPU_CMD_SUBMIT_3D." However, atomic resource creation is easiest for the
// cross-domain use case, so whatever.
- let handle = self.gralloc.lock().allocate_memory(*reqs)?;
+ let hnd = match handle {
+ Some(handle) => handle,
+ None => self.gralloc.lock().allocate_memory(*reqs)?,
+ };
+
let info_3d = Resource3DInfo {
width: reqs.info.width,
height: reqs.info.height,
@@ -164,7 +169,7 @@ impl RutabagaContext for CrossDomainContext {
Ok(RutabagaResource {
resource_id,
- handle: Some(Arc::new(handle)),
+ handle: Some(Arc::new(hnd)),
blob: true,
blob_mem: resource_create_blob.blob_mem,
blob_flags: resource_create_blob.blob_flags,
@@ -251,7 +256,7 @@ impl RutabagaContext for CrossDomainContext {
impl RutabagaComponent for CrossDomain {
fn get_capset_info(&self, _capset_id: u32) -> (u32, u32) {
- return (0 as u32, size_of::<CrossDomainCapabilities>() as u32);
+ return (0u32, size_of::<CrossDomainCapabilities>() as u32);
}
fn get_capset(&self, _capset_id: u32, _version: u32) -> Vec<u8> {
diff --git a/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs b/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs
index 3f03ad75e..08b8f7a3e 100644
--- a/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs
+++ b/rutabaga_gfx/src/generated/virgl_renderer_bindings.rs
@@ -11,6 +11,8 @@ pub const VIRGL_RENDERER_USE_GLX: u32 = 4;
pub const VIRGL_RENDERER_USE_SURFACELESS: u32 = 8;
pub const VIRGL_RENDERER_USE_GLES: u32 = 16;
pub const VIRGL_RENDERER_USE_EXTERNAL_BLOB: u32 = 32;
+pub const VIRGL_RENDERER_VENUS: u32 = 64;
+pub const VIRGL_RENDERER_NO_VIRGL: u32 = 128;
pub const VIRGL_RES_BIND_DEPTH_STENCIL: u32 = 1;
pub const VIRGL_RES_BIND_RENDER_TARGET: u32 = 2;
pub const VIRGL_RES_BIND_SAMPLER_VIEW: u32 = 8;
diff --git a/rutabaga_gfx/src/gfxstream.rs b/rutabaga_gfx/src/gfxstream.rs
index 09a6a17a8..81d3b8a20 100644
--- a/rutabaga_gfx/src/gfxstream.rs
+++ b/rutabaga_gfx/src/gfxstream.rs
@@ -221,6 +221,7 @@ impl Gfxstream {
Ok(Box::new(Gfxstream { fence_state }))
}
+ #[allow(clippy::unnecessary_wraps)]
fn map_info(&self, _resource_id: u32) -> RutabagaResult<u32> {
Ok(RUTABAGA_MAP_CACHE_WC)
}
@@ -410,7 +411,7 @@ impl RutabagaComponent for Gfxstream {
_ctx_id: u32,
resource_id: u32,
resource_create_blob: ResourceCreateBlob,
- _iovecs: Vec<RutabagaIovec>,
+ _iovec_opt: Option<Vec<RutabagaIovec>>,
) -> RutabagaResult<RutabagaResource> {
unsafe {
stream_renderer_resource_create_v2(resource_id, resource_create_blob.blob_id);
diff --git a/rutabaga_gfx/src/lib.rs b/rutabaga_gfx/src/lib.rs
index 40c131ba3..dc99fa159 100644
--- a/rutabaga_gfx/src/lib.rs
+++ b/rutabaga_gfx/src/lib.rs
@@ -10,6 +10,7 @@ mod generated;
mod gfxstream;
#[macro_use]
mod macros;
+#[cfg(any(feature = "gfxstream", feature = "virgl_renderer"))]
mod renderer_utils;
mod rutabaga_2d;
mod rutabaga_core;
diff --git a/rutabaga_gfx/src/renderer_utils.rs b/rutabaga_gfx/src/renderer_utils.rs
index 9b3726437..94a271a72 100644
--- a/rutabaga_gfx/src/renderer_utils.rs
+++ b/rutabaga_gfx/src/renderer_utils.rs
@@ -30,7 +30,7 @@ pub struct VirglBox {
* -o vsnprintf.rs
*/
-#[allow(dead_code, non_snake_case, non_camel_case_types)]
+#[allow(non_snake_case, non_camel_case_types)]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
extern "C" {
pub fn vsnprintf(
diff --git a/rutabaga_gfx/src/rutabaga_core.rs b/rutabaga_gfx/src/rutabaga_core.rs
index 9fef434c3..23b30aca9 100644
--- a/rutabaga_gfx/src/rutabaga_core.rs
+++ b/rutabaga_gfx/src/rutabaga_core.rs
@@ -132,7 +132,7 @@ pub trait RutabagaComponent {
_ctx_id: u32,
_resource_id: u32,
_resource_create_blob: ResourceCreateBlob,
- _backing_iovecs: Vec<RutabagaIovec>,
+ _iovec_opt: Option<Vec<RutabagaIovec>>,
) -> RutabagaResult<RutabagaResource> {
Err(RutabagaError::Unsupported)
}
@@ -166,6 +166,7 @@ pub trait RutabagaContext {
&mut self,
_resource_id: u32,
_resource_create_blob: ResourceCreateBlob,
+ _handle: Option<RutabagaHandle>,
) -> RutabagaResult<RutabagaResource> {
Err(RutabagaError::Unsupported)
}
@@ -442,13 +443,15 @@ impl Rutabaga {
}
/// Creates a blob resource with the `ctx_id` and `resource_create_blob` metadata.
- /// Associates `iovecs` with the resource, if there are any.
+ /// Associates `iovecs` with the resource, if there are any. Associates externally
+ /// created `handle` with the resource, if there is any.
pub fn resource_create_blob(
&mut self,
ctx_id: u32,
resource_id: u32,
resource_create_blob: ResourceCreateBlob,
- iovecs: Vec<RutabagaIovec>,
+ iovecs: Option<Vec<RutabagaIovec>>,
+ handle: Option<RutabagaHandle>,
) -> RutabagaResult<()> {
if self.resources.contains_key(&resource_id) {
return Err(RutabagaError::InvalidResourceId);
@@ -463,7 +466,8 @@ impl Rutabaga {
.get_mut(&ctx_id)
.ok_or(RutabagaError::InvalidContextId)?;
- if let Ok(resource) = ctx.context_create_blob(resource_id, resource_create_blob) {
+ if let Ok(resource) = ctx.context_create_blob(resource_id, resource_create_blob, handle)
+ {
self.resources.insert(resource_id, resource);
return Ok(());
}
@@ -501,8 +505,18 @@ impl Rutabaga {
.get(&resource_id)
.ok_or(RutabagaError::InvalidResourceId)?;
- let map_info = resource.map_info.ok_or(RutabagaError::SpecViolation)?;
- Ok(map_info)
+ resource.map_info.ok_or(RutabagaError::SpecViolation)
+ }
+
+ /// Returns the `vulkan_info` of the blob resource, which consists of the physical device
+ /// index and memory index associated with the resource.
+ pub fn vulkan_info(&self, resource_id: u32) -> RutabagaResult<VulkanInfo> {
+ let resource = self
+ .resources
+ .get(&resource_id)
+ .ok_or(RutabagaError::InvalidResourceId)?;
+
+ resource.vulkan_info.ok_or(RutabagaError::Unsupported)
}
/// Returns the 3D info associated with the resource, if any.
@@ -512,8 +526,7 @@ impl Rutabaga {
.get(&resource_id)
.ok_or(RutabagaError::InvalidResourceId)?;
- let info_3d = resource.info_3d.ok_or(RutabagaError::Unsupported)?;
- Ok(info_3d)
+ resource.info_3d.ok_or(RutabagaError::Unsupported)
}
/// Exports a blob resource. See virtio-gpu spec for blob flag use flags.
diff --git a/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs b/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs
index 4222ba945..9b3f25f03 100644
--- a/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs
+++ b/rutabaga_gfx/src/rutabaga_gralloc/gralloc.rs
@@ -7,7 +7,7 @@
use std::collections::BTreeMap as Map;
-use base::round_up_to_page_size;
+use base::{round_up_to_page_size, MappedRegion};
use crate::rutabaga_gralloc::formats::*;
use crate::rutabaga_gralloc::system_gralloc::SystemGralloc;
@@ -89,6 +89,26 @@ impl RutabagaGrallocFlags {
}
}
+ /// Sets the SW write flag's presence.
+ #[inline(always)]
+ pub fn use_sw_write(self, e: bool) -> RutabagaGrallocFlags {
+ if e {
+ RutabagaGrallocFlags(self.0 | RUTABAGA_GRALLOC_USE_SW_WRITE_OFTEN)
+ } else {
+ RutabagaGrallocFlags(self.0 & !RUTABAGA_GRALLOC_USE_SW_WRITE_OFTEN)
+ }
+ }
+
+ /// Sets the SW read flag's presence.
+ #[inline(always)]
+ pub fn use_sw_read(self, e: bool) -> RutabagaGrallocFlags {
+ if e {
+ RutabagaGrallocFlags(self.0 | RUTABAGA_GRALLOC_USE_SW_READ_OFTEN)
+ } else {
+ RutabagaGrallocFlags(self.0 & !RUTABAGA_GRALLOC_USE_SW_READ_OFTEN)
+ }
+ }
+
/// Returns true if the texturing flag is set.
#[inline(always)]
pub fn uses_texturing(self) -> bool {
@@ -167,6 +187,17 @@ pub trait Gralloc {
/// Implementations must allocate memory given the requirements and return a RutabagaHandle
/// upon success.
fn allocate_memory(&mut self, reqs: ImageMemoryRequirements) -> RutabagaResult<RutabagaHandle>;
+
+ /// Implementations must import the given `handle` and return a mapping, suitable for use with
+ /// KVM and other hypervisors. This is optional and only works with the Vulkano backend.
+ fn import_and_map(
+ &mut self,
+ _handle: RutabagaHandle,
+ _vulkan_info: VulkanInfo,
+ _size: u64,
+ ) -> RutabagaResult<Box<dyn MappedRegion>> {
+ Err(RutabagaError::Unsupported)
+ }
}
/// Enumeration of possible allocation backends.
@@ -293,6 +324,22 @@ impl RutabagaGralloc {
gralloc.allocate_memory(reqs)
}
+
+ /// Imports the `handle` using the given `vulkan_info`. Returns a mapping using Vulkano upon
+ /// success. Should not be used with minigbm or system gralloc backends.
+ pub fn import_and_map(
+ &mut self,
+ handle: RutabagaHandle,
+ vulkan_info: VulkanInfo,
+ size: u64,
+ ) -> RutabagaResult<Box<dyn MappedRegion>> {
+ let gralloc = self
+ .grallocs
+ .get_mut(&GrallocBackend::Vulkano)
+ .ok_or(RutabagaError::Unsupported)?;
+
+ gralloc.import_and_map(handle, vulkan_info, size)
+ }
}
#[cfg(test)]
@@ -363,4 +410,44 @@ mod tests {
// Reallocate with same requirements
let _handle2 = gralloc.allocate_memory(reqs).unwrap();
}
+
+ #[test]
+ fn export_and_map() {
+ let gralloc_result = RutabagaGralloc::new();
+ if gralloc_result.is_err() {
+ return;
+ }
+
+ let mut gralloc = gralloc_result.unwrap();
+
+ let info = ImageAllocationInfo {
+ width: 512,
+ height: 1024,
+ drm_format: DrmFormat::new(b'X', b'R', b'2', b'4'),
+ flags: RutabagaGrallocFlags::empty()
+ .use_linear(true)
+ .use_sw_write(true)
+ .use_sw_read(true),
+ };
+
+ let mut reqs = gralloc.get_image_memory_requirements(info).unwrap();
+
+ // Anything else can use the mmap(..) system call.
+ if reqs.vulkan_info.is_none() {
+ return;
+ }
+
+ let handle = gralloc.allocate_memory(reqs).unwrap();
+ let vulkan_info = reqs.vulkan_info.take().unwrap();
+
+ let mapping = gralloc
+ .import_and_map(handle, vulkan_info, reqs.size)
+ .unwrap();
+
+ let addr = mapping.as_ptr();
+ let size = mapping.size();
+
+ assert_eq!(size as u64, reqs.size);
+ assert_ne!(addr as *const u8, std::ptr::null());
+ }
}
diff --git a/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs b/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs
index bdf62e753..c20438de3 100644
--- a/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs
+++ b/rutabaga_gfx/src/rutabaga_gralloc/minigbm.rs
@@ -143,7 +143,7 @@ impl Gralloc for MinigbmDevice {
return Err(RutabagaError::SpecViolation);
}
- let dmabuf = gbm_buffer.export()?;
+ let dmabuf = gbm_buffer.export()?.into();
return Ok(RutabagaHandle {
os_handle: dmabuf,
handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF,
@@ -165,7 +165,7 @@ impl Gralloc for MinigbmDevice {
}
let gbm_buffer = MinigbmBuffer(bo, self.clone());
- let dmabuf = gbm_buffer.export()?;
+ let dmabuf = gbm_buffer.export()?.into();
Ok(RutabagaHandle {
os_handle: dmabuf,
handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF,
diff --git a/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs b/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs
index b74908a08..7c66d210b 100644
--- a/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs
+++ b/rutabaga_gfx/src/rutabaga_gralloc/vulkano_gralloc.rs
@@ -9,22 +9,27 @@
#![cfg(feature = "vulkano")]
+use std::collections::BTreeMap as Map;
+use std::convert::TryInto;
use std::iter::Empty;
use std::sync::Arc;
+use base::MappedRegion;
+
use crate::rutabaga_gralloc::gralloc::{Gralloc, ImageAllocationInfo, ImageMemoryRequirements};
use crate::rutabaga_utils::*;
use vulkano::device::{Device, DeviceCreationError, DeviceExtensions};
-use vulkano::image::{sys, ImageCreationError, ImageDimensions, ImageUsage};
+use vulkano::image::{sys, ImageCreateFlags, ImageCreationError, ImageDimensions, ImageUsage};
use vulkano::instance::{
Instance, InstanceCreationError, InstanceExtensions, MemoryType, PhysicalDevice,
+ PhysicalDeviceType,
};
use vulkano::memory::{
- DedicatedAlloc, DeviceMemoryAllocError, DeviceMemoryBuilder, ExternalMemoryHandleType,
- MemoryRequirements,
+ DedicatedAlloc, DeviceMemoryAllocError, DeviceMemoryBuilder, DeviceMemoryMapping,
+ ExternalMemoryHandleType, MemoryRequirements,
};
use vulkano::memory::pool::AllocFromRequirementsFilter;
@@ -32,7 +37,32 @@ use vulkano::sync::Sharing;
/// A gralloc implementation capable of allocation `VkDeviceMemory`.
pub struct VulkanoGralloc {
- device: Arc<Device>,
+ devices: Map<PhysicalDeviceType, Arc<Device>>,
+ has_integrated_gpu: bool,
+}
+
+struct VulkanoMapping {
+ mapping: DeviceMemoryMapping,
+ size: usize,
+}
+
+impl VulkanoMapping {
+ pub fn new(mapping: DeviceMemoryMapping, size: usize) -> VulkanoMapping {
+ VulkanoMapping { mapping, size }
+ }
+}
+
+unsafe impl MappedRegion for VulkanoMapping {
+ /// Used for passing this region for hypervisor memory mappings. We trust crosvm to use this
+ /// safely.
+ fn as_ptr(&self) -> *mut u8 {
+ unsafe { self.mapping.as_ptr() }
+ }
+
+ /// Returns the size of the memory region in bytes.
+ fn size(&self) -> usize {
+ self.size
+ }
}
impl VulkanoGralloc {
@@ -42,39 +72,52 @@ impl VulkanoGralloc {
// explanation of VK initialization.
let instance = Instance::new(None, &InstanceExtensions::none(), None)?;
- // We should really check for integrated GPU versus dGPU.
- let physical = PhysicalDevice::enumerate(&instance)
- .next()
- .ok_or(RutabagaError::Unsupported)?;
+ let mut devices: Map<PhysicalDeviceType, Arc<Device>> = Default::default();
+ let mut has_integrated_gpu = false;
+
+ for physical in PhysicalDevice::enumerate(&instance) {
+ let queue_family = physical
+ .queue_families()
+ .find(|&q| {
+ // We take the first queue family that supports graphics.
+ q.supports_graphics()
+ })
+ .ok_or(RutabagaError::Unsupported)?;
+
+ let supported_extensions = DeviceExtensions::supported_by_device(physical);
+
+ let desired_extensions = DeviceExtensions {
+ khr_dedicated_allocation: true,
+ khr_get_memory_requirements2: true,
+ khr_external_memory: true,
+ khr_external_memory_fd: true,
+ ext_external_memory_dmabuf: true,
+ ..DeviceExtensions::none()
+ };
- let queue_family = physical
- .queue_families()
- .find(|&q| {
- // We take the first queue family that supports graphics.
- q.supports_graphics()
- })
- .ok_or(RutabagaError::Unsupported)?;
+ let intersection = supported_extensions.intersection(&desired_extensions);
- let supported_extensions = DeviceExtensions::supported_by_device(physical);
- let desired_extensions = DeviceExtensions {
- khr_dedicated_allocation: true,
- khr_get_memory_requirements2: true,
- khr_external_memory: true,
- khr_external_memory_fd: true,
- ext_external_memory_dmabuf: true,
- ..DeviceExtensions::none()
- };
+ let (device, mut _queues) = Device::new(
+ physical,
+ physical.supported_features(),
+ &intersection,
+ [(queue_family, 0.5)].iter().cloned(),
+ )?;
- let intersection = supported_extensions.intersection(&desired_extensions);
+ if device.physical_device().ty() == PhysicalDeviceType::IntegratedGpu {
+ has_integrated_gpu = true
+ }
- let (device, mut _queues) = Device::new(
- physical,
- physical.supported_features(),
- &intersection,
- [(queue_family, 0.5)].iter().cloned(),
- )?;
+ // If we have two devices of the same type (two integrated GPUs), the old value is
+ // dropped. Vulkano is verbose enough such that a keener selection algorithm may
+ // be used, but the need for such complexity does not seem to exist now.
+ devices.insert(device.physical_device().ty(), device);
+ }
- Ok(Box::new(VulkanoGralloc { device }))
+ Ok(Box::new(VulkanoGralloc {
+ devices,
+ has_integrated_gpu,
+ }))
}
// This function is used safely in this module because gralloc does not:
@@ -88,6 +131,16 @@ impl VulkanoGralloc {
&mut self,
info: ImageAllocationInfo,
) -> RutabagaResult<(sys::UnsafeImage, MemoryRequirements)> {
+ let device = if self.has_integrated_gpu {
+ self.devices
+ .get(&PhysicalDeviceType::IntegratedGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ } else {
+ self.devices
+ .get(&PhysicalDeviceType::DiscreteGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ };
+
let usage = match info.flags.uses_rendering() {
true => ImageUsage {
color_attachment: true,
@@ -111,14 +164,14 @@ impl VulkanoGralloc {
let vulkan_format = info.drm_format.vulkan_format()?;
let (unsafe_image, memory_requirements) = sys::UnsafeImage::new(
- self.device.clone(),
+ device.clone(),
usage,
vulkan_format,
+ ImageCreateFlags::none(),
ImageDimensions::Dim2d {
width: info.width,
height: info.height,
array_layers: 1,
- cubemap_compatible: false,
},
1, /* number of samples */
1, /* mipmap count */
@@ -133,11 +186,23 @@ impl VulkanoGralloc {
impl Gralloc for VulkanoGralloc {
fn supports_external_gpu_memory(&self) -> bool {
- self.device.loaded_extensions().khr_external_memory
+ for device in self.devices.values() {
+ if !device.loaded_extensions().khr_external_memory {
+ return false;
+ }
+ }
+
+ true
}
fn supports_dmabuf(&self) -> bool {
- self.device.loaded_extensions().ext_external_memory_dmabuf
+ for device in self.devices.values() {
+ if !device.loaded_extensions().ext_external_memory_dmabuf {
+ return false;
+ }
+ }
+
+ true
}
fn get_image_memory_requirements(
@@ -148,6 +213,16 @@ impl Gralloc for VulkanoGralloc {
let (unsafe_image, memory_requirements) = unsafe { self.create_image(info)? };
+ let device = if self.has_integrated_gpu {
+ self.devices
+ .get(&PhysicalDeviceType::IntegratedGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ } else {
+ self.devices
+ .get(&PhysicalDeviceType::DiscreteGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ };
+
let planar_layout = info.drm_format.planar_layout()?;
// Safe because we created the image with the linear bit set and verified the format is
@@ -188,13 +263,11 @@ impl Gralloc for VulkanoGralloc {
AllocFromRequirementsFilter::Allowed
};
- let first_loop = self
- .device
+ let first_loop = device
.physical_device()
.memory_types()
.map(|t| (t, AllocFromRequirementsFilter::Preferred));
- let second_loop = self
- .device
+ let second_loop = device
.physical_device()
.memory_types()
.map(|t| (t, AllocFromRequirementsFilter::Allowed));
@@ -219,7 +292,7 @@ impl Gralloc for VulkanoGralloc {
reqs.vulkan_info = Some(VulkanInfo {
memory_idx: memory_type.id() as u32,
- physical_device_idx: self.device.physical_device().index() as u32,
+ physical_device_idx: device.physical_device().index() as u32,
});
Ok(reqs)
@@ -227,15 +300,26 @@ impl Gralloc for VulkanoGralloc {
fn allocate_memory(&mut self, reqs: ImageMemoryRequirements) -> RutabagaResult<RutabagaHandle> {
let (unsafe_image, memory_requirements) = unsafe { self.create_image(reqs.info)? };
+
let vulkan_info = reqs.vulkan_info.ok_or(RutabagaError::SpecViolation)?;
- let memory_type = self
- .device
+
+ let device = if self.has_integrated_gpu {
+ self.devices
+ .get(&PhysicalDeviceType::IntegratedGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ } else {
+ self.devices
+ .get(&PhysicalDeviceType::DiscreteGpu)
+ .ok_or(RutabagaError::Unsupported)?
+ };
+
+ let memory_type = device
.physical_device()
.memory_type_by_id(vulkan_info.memory_idx)
.ok_or(RutabagaError::SpecViolation)?;
let (handle_type, rutabaga_type) =
- match self.device.loaded_extensions().ext_external_memory_dmabuf {
+ match device.loaded_extensions().ext_external_memory_dmabuf {
true => (
ExternalMemoryHandleType {
dma_buf: true,
@@ -252,7 +336,7 @@ impl Gralloc for VulkanoGralloc {
),
};
- let dedicated = match self.device.loaded_extensions().khr_dedicated_allocation {
+ let dedicated = match device.loaded_extensions().khr_dedicated_allocation {
true => {
if memory_requirements.prefer_dedicated {
DedicatedAlloc::Image(&unsafe_image)
@@ -264,18 +348,55 @@ impl Gralloc for VulkanoGralloc {
};
let device_memory =
- DeviceMemoryBuilder::new(self.device.clone(), memory_type, reqs.size as usize)
+ DeviceMemoryBuilder::new(device.clone(), memory_type.id(), reqs.size as usize)
.dedicated_info(dedicated)
.export_info(handle_type)
.build()?;
- let file = device_memory.export_fd(handle_type)?;
+ let descriptor = device_memory.export_fd(handle_type)?.into();
Ok(RutabagaHandle {
- os_handle: file,
+ os_handle: descriptor,
handle_type: rutabaga_type,
})
}
+
+ /// Implementations must map the memory associated with the `resource_id` upon success.
+ fn import_and_map(
+ &mut self,
+ handle: RutabagaHandle,
+ vulkan_info: VulkanInfo,
+ size: u64,
+ ) -> RutabagaResult<Box<dyn MappedRegion>> {
+ let device = self
+ .devices
+ .values()
+ .find(|device| {
+ device.physical_device().index() as u32 == vulkan_info.physical_device_idx
+ })
+ .ok_or(RutabagaError::Unsupported)?;
+
+ let handle_type = match handle.handle_type {
+ RUTABAGA_MEM_HANDLE_TYPE_DMABUF => ExternalMemoryHandleType {
+ dma_buf: true,
+ ..ExternalMemoryHandleType::none()
+ },
+ RUTABAGA_MEM_HANDLE_TYPE_OPAQUE_FD => ExternalMemoryHandleType {
+ opaque_fd: true,
+ ..ExternalMemoryHandleType::none()
+ },
+ _ => return Err(RutabagaError::Unsupported),
+ };
+
+ let valid_size: usize = size.try_into()?;
+ let device_memory =
+ DeviceMemoryBuilder::new(device.clone(), vulkan_info.memory_idx, valid_size)
+ .import_info(handle.os_handle.into(), handle_type)
+ .build()?;
+ let mapping = DeviceMemoryMapping::new(device.clone(), device_memory.clone(), 0, size, 0)?;
+
+ Ok(Box::new(VulkanoMapping::new(mapping, valid_size)))
+ }
}
 // Vulkano should really define a universal type that wraps all these errors, say
diff --git a/rutabaga_gfx/src/rutabaga_utils.rs b/rutabaga_gfx/src/rutabaga_utils.rs
index 943f36ccd..06f6fe7d6 100644
--- a/rutabaga_gfx/src/rutabaga_utils.rs
+++ b/rutabaga_gfx/src/rutabaga_utils.rs
@@ -5,13 +5,13 @@
//! rutabaga_utils: Utility enums, structs, and implementations needed by the rest of the crate.
use std::fmt::{self, Display};
-use std::fs::File;
use std::io::Error as IoError;
+use std::num::TryFromIntError;
use std::os::raw::c_void;
use std::path::PathBuf;
use std::str::Utf8Error;
-use base::{Error as SysError, ExternalMappingError};
+use base::{Error as SysError, ExternalMappingError, SafeDescriptor};
use data_model::VolatileMemoryError;
#[cfg(feature = "vulkano")]
@@ -155,6 +155,8 @@ pub enum RutabagaError {
SpecViolation,
/// System error returned as a result of rutabaga library operation.
SysError(SysError),
+ /// An attempted integer conversion failed.
+ TryFromIntError(TryFromIntError),
/// The command is unsupported.
Unsupported,
/// Utf8 error.
@@ -210,6 +212,7 @@ impl Display for RutabagaError {
ComponentError(ret) => write!(f, "rutabaga component failed with error {}", ret),
SpecViolation => write!(f, "violation of the rutabaga spec"),
SysError(e) => write!(f, "rutabaga received a system error: {}", e),
+ TryFromIntError(e) => write!(f, "int conversion failed: {}", e),
Unsupported => write!(f, "feature or function unsupported"),
             Utf8Error(e) => write!(f, "a utf8 error occurred: {}", e),
VolatileMemoryError(e) => write!(f, "noticed a volatile memory error {}", e),
@@ -239,6 +242,12 @@ impl From<SysError> for RutabagaError {
}
}
+impl From<TryFromIntError> for RutabagaError {
+ fn from(e: TryFromIntError) -> RutabagaError {
+ RutabagaError::TryFromIntError(e)
+ }
+}
+
impl From<Utf8Error> for RutabagaError {
fn from(e: Utf8Error) -> RutabagaError {
RutabagaError::Utf8Error(e)
@@ -262,6 +271,8 @@ const VIRGLRENDERER_USE_GLX: u32 = 1 << 2;
const VIRGLRENDERER_USE_SURFACELESS: u32 = 1 << 3;
const VIRGLRENDERER_USE_GLES: u32 = 1 << 4;
const VIRGLRENDERER_USE_EXTERNAL_BLOB: u32 = 1 << 5;
+const VIRGLRENDERER_VENUS: u32 = 1 << 6;
+const VIRGLRENDERER_NO_VIRGL: u32 = 1 << 7;
/// virglrenderer flag struct.
#[derive(Copy, Clone)]
@@ -270,6 +281,8 @@ pub struct VirglRendererFlags(u32);
impl Default for VirglRendererFlags {
fn default() -> VirglRendererFlags {
VirglRendererFlags::new()
+ .use_virgl(true)
+ .use_venus(false)
.use_egl(true)
.use_surfaceless(true)
.use_gles(true)
@@ -296,6 +309,16 @@ impl VirglRendererFlags {
}
}
+ /// Enable virgl support
+ pub fn use_virgl(self, v: bool) -> VirglRendererFlags {
+ self.set_flag(VIRGLRENDERER_NO_VIRGL, !v)
+ }
+
+ /// Enable venus support
+ pub fn use_venus(self, v: bool) -> VirglRendererFlags {
+ self.set_flag(VIRGLRENDERER_VENUS, v)
+ }
+
/// Use EGL for context creation.
pub fn use_egl(self, v: bool) -> VirglRendererFlags {
self.set_flag(VIRGLRENDERER_USE_EGL, v)
@@ -377,7 +400,7 @@ impl GfxstreamFlags {
}
/// Support using Vulkan.
- pub fn support_vulkan(self, v: bool) -> GfxstreamFlags {
+ pub fn use_vulkan(self, v: bool) -> GfxstreamFlags {
self.set_flag(GFXSTREAM_RENDERER_FLAGS_NO_VK_BIT, !v)
}
@@ -463,7 +486,7 @@ pub const RUTABAGE_FENCE_HANDLE_TYPE_OPAQUE_WIN32: u32 = 0x0006;
/// Handle to OS-specific memory or synchronization objects.
pub struct RutabagaHandle {
- pub os_handle: File,
+ pub os_handle: SafeDescriptor,
pub handle_type: u32,
}
diff --git a/rutabaga_gfx/src/virgl_renderer.rs b/rutabaga_gfx/src/virgl_renderer.rs
index ffc35a230..cdea147a3 100644
--- a/rutabaga_gfx/src/virgl_renderer.rs
+++ b/rutabaga_gfx/src/virgl_renderer.rs
@@ -9,7 +9,6 @@
use std::cell::RefCell;
use std::ffi::CString;
-use std::fs::File;
use std::mem::{size_of, transmute};
use std::os::raw::{c_char, c_void};
use std::ptr::null_mut;
@@ -19,7 +18,7 @@ use std::sync::Arc;
use base::{
warn, Error as SysError, ExternalMapping, ExternalMappingError, ExternalMappingResult,
- FromRawDescriptor,
+ FromRawDescriptor, SafeDescriptor,
};
use crate::generated::virgl_renderer_bindings::*;
@@ -181,7 +180,10 @@ impl VirglRenderer {
// Initialize it only once and use the non-send/non-sync Renderer struct to keep things tied
// to whichever thread called this function first.
static INIT_ONCE: AtomicBool = AtomicBool::new(false);
- if INIT_ONCE.compare_and_swap(false, true, Ordering::Acquire) {
+ if INIT_ONCE
+ .compare_exchange(false, true, Ordering::Acquire, Ordering::Acquire)
+ .is_err()
+ {
return Err(RutabagaError::AlreadyInUse);
}
@@ -266,7 +268,7 @@ impl VirglRenderer {
return Err(RutabagaError::Unsupported);
}
- let dmabuf = unsafe { File::from_raw_descriptor(fd) };
+ let dmabuf = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
Ok(Arc::new(RutabagaHandle {
os_handle: dmabuf,
handle_type: RUTABAGA_MEM_HANDLE_TYPE_DMABUF,
@@ -478,10 +480,17 @@ impl RutabagaComponent for VirglRenderer {
ctx_id: u32,
resource_id: u32,
resource_create_blob: ResourceCreateBlob,
- mut iovecs: Vec<RutabagaIovec>,
+ mut iovec_opt: Option<Vec<RutabagaIovec>>,
) -> RutabagaResult<RutabagaResource> {
#[cfg(feature = "virgl_renderer_next")]
{
+ let mut iovec_ptr = null_mut();
+ let mut num_iovecs = 0;
+ if let Some(ref mut iovecs) = iovec_opt {
+ iovec_ptr = iovecs.as_mut_ptr();
+ num_iovecs = iovecs.len();
+ }
+
let resource_create_args = virgl_renderer_resource_create_blob_args {
res_handle: resource_id,
ctx_id,
@@ -489,17 +498,13 @@ impl RutabagaComponent for VirglRenderer {
blob_flags: resource_create_blob.blob_flags,
blob_id: resource_create_blob.blob_id,
size: resource_create_blob.size,
- iovecs: iovecs.as_mut_ptr() as *const iovec,
- num_iovs: iovecs.len() as u32,
+ iovecs: iovec_ptr as *const iovec,
+ num_iovs: num_iovecs as u32,
};
+
let ret = unsafe { virgl_renderer_resource_create_blob(&resource_create_args) };
ret_to_res(ret)?;
- let iovec_opt = match resource_create_blob.blob_mem {
- RUTABAGA_BLOB_MEM_GUEST => Some(iovecs),
- _ => None,
- };
-
Ok(RutabagaResource {
resource_id,
handle: self.export_blob(resource_id).ok(),
@@ -536,7 +541,7 @@ impl RutabagaComponent for VirglRenderer {
// Safe because the FD was just returned by a successful virglrenderer call so it must
// be valid and owned by us.
- let fence = unsafe { File::from_raw_descriptor(fd) };
+ let fence = unsafe { SafeDescriptor::from_raw_descriptor(fd) };
Ok(RutabagaHandle {
os_handle: fence,
handle_type: RUTABAGA_FENCE_HANDLE_TYPE_SYNC_FD,
diff --git a/seccomp/aarch64/9p_device.policy b/seccomp/aarch64/9p_device.policy
index 27c79083f..85344028c 100644
--- a/seccomp/aarch64/9p_device.policy
+++ b/seccomp/aarch64/9p_device.policy
@@ -6,7 +6,6 @@ openat: 1
@include /usr/share/policy/crosvm/common_device.policy
-fcntl: 1
pread64: 1
pwrite64: 1
statx: 1
@@ -23,6 +22,8 @@ unlinkat: 1
socket: arg0 == AF_UNIX
utimensat: 1
ftruncate: 1
-fchown: arg1 == 0xffffffff && arg2 == 0xffffffff
+fchmod: 1
+fchown: 1
fstatfs: 1
newfstatat: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/balloon_device.policy b/seccomp/aarch64/balloon_device.policy
index e1ca95339..57e21cce4 100644
--- a/seccomp/aarch64/balloon_device.policy
+++ b/seccomp/aarch64/balloon_device.policy
@@ -4,5 +4,5 @@
@include /usr/share/policy/crosvm/common_device.policy
-fcntl: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/battery.policy b/seccomp/aarch64/battery.policy
index a4fb9fcb6..f26af9caa 100644
--- a/seccomp/aarch64/battery.policy
+++ b/seccomp/aarch64/battery.policy
@@ -3,3 +3,4 @@
# found in the LICENSE file.
@include /usr/share/policy/crosvm/common_device.policy
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/block_device.policy b/seccomp/aarch64/block_device.policy
index 7697a7e6e..64d5ca5ff 100644
--- a/seccomp/aarch64/block_device.policy
+++ b/seccomp/aarch64/block_device.policy
@@ -5,7 +5,6 @@
@include /usr/share/policy/crosvm/common_device.policy
fallocate: 1
-fcntl: 1
fdatasync: 1
fstat: 1
fsync: 1
@@ -18,3 +17,4 @@ statx: 1
timerfd_create: 1
timerfd_gettime: 1
timerfd_settime: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/common_device.policy b/seccomp/aarch64/common_device.policy
index 841e52d09..349afe9aa 100644
--- a/seccomp/aarch64/common_device.policy
+++ b/seccomp/aarch64/common_device.policy
@@ -25,9 +25,9 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
ppoll: 1
-prctl: arg0 == PR_SET_NAME
read: 1
readv: 1
recvfrom: 1
@@ -43,3 +43,4 @@ set_robust_list: 1
sigaltstack: 1
write: 1
writev: 1
+fcntl: 1
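
The common policy restructuring above pulls fcntl in unconditionally and allows clock_nanosleep next to nanosleep, while prctl moves out into each per-device policy so devices that need more than PR_SET_NAME (for example the fs device's SECUREBITS calls below) can widen it locally. A plausible reason for the clock_nanosleep addition, stated here as an assumption rather than taken from the patch, is that recent glibc services thread sleeps through clock_nanosleep; a trivial sketch of the call path the policy now has to allow:

    use std::thread::sleep;
    use std::time::Duration;

    fn main() {
        // On newer glibc this is serviced by clock_nanosleep rather than
        // nanosleep, so a seccomp policy listing only nanosleep would kill
        // the thread here. (libc behavior is an assumption, not from the patch)
        sleep(Duration::from_millis(1));
    }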
diff --git a/seccomp/aarch64/cras_audio_device.policy b/seccomp/aarch64/cras_audio_device.policy
index 19419fd4f..60797f978 100644
--- a/seccomp/aarch64/cras_audio_device.policy
+++ b/seccomp/aarch64/cras_audio_device.policy
@@ -11,3 +11,4 @@ sched_setscheduler: 1
socketpair: arg0 == AF_UNIX
clock_gettime: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/fs_device.policy b/seccomp/aarch64/fs_device.policy
index 1d8bbd9f7..828003e0c 100644
--- a/seccomp/aarch64/fs_device.policy
+++ b/seccomp/aarch64/fs_device.policy
@@ -48,3 +48,6 @@ symlinkat: 1
umask: 1
unlinkat: 1
utimensat: 1
+prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS
+capget: 1
+capset: 1
diff --git a/seccomp/aarch64/gpu_device.policy b/seccomp/aarch64/gpu_device.policy
index bd1f6481d..4ceac5c46 100644
--- a/seccomp/aarch64/gpu_device.policy
+++ b/seccomp/aarch64/gpu_device.policy
@@ -23,6 +23,7 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == MADV_REMOVE
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
ppoll: 1
prctl: arg0 == PR_SET_NAME || arg0 == PR_GET_NAME
@@ -58,8 +59,8 @@ newfstatat: 1
getdents64: 1
sysinfo: 1
-# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali)
-ioctl: arg1 & 0x6400 || arg1 & 0x8000
+# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali), 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x40087543 == UDMABUF_CREATE_LIST
+ioctl: arg1 & 0x6400 || arg1 & 0x8000 || arg1 == 0x40086200 || arg1 == 0x40087543
## mmap/mprotect differ from the common_device.policy
mmap: arg2 == PROT_READ|PROT_WRITE || arg2 == PROT_NONE || arg2 == PROT_READ|PROT_EXEC || arg2 == PROT_WRITE || arg2 == PROT_READ
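
The widened ioctl filter admits two exact request numbers on top of the DRM and Mali ranges. Those numbers follow the kernel ioctl encoding (direction in bits 30-31, payload size in bits 16-29, type in bits 8-15, command number in bits 0-7); a short sketch of how 0x40086200 and 0x40087543 decompose, where the iow() helper and the 8-byte payload sizes are assumptions mirroring the kernel's _IOW() macro and structs:

    // Sketch of the kernel's _IOW() encoding: write direction | size | type | nr.
    const fn iow(typ: u32, nr: u32, size: u32) -> u32 {
        (1 << 30) | (size << 16) | (typ << 8) | nr
    }

    // struct dma_buf_sync and struct udmabuf_create_list are both 8 bytes.
    const DMA_BUF_IOCTL_SYNC: u32 = iow(b'b' as u32, 0x00, 8); // == 0x40086200
    const UDMABUF_CREATE_LIST: u32 = iow(b'u' as u32, 0x43, 8); // == 0x40087543

The same comment and filter change recur in the arm and x86_64 gpu_device.policy hunks further down.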
diff --git a/seccomp/aarch64/input_device.policy b/seccomp/aarch64/input_device.policy
index 07d3b5f74..728580dc6 100644
--- a/seccomp/aarch64/input_device.policy
+++ b/seccomp/aarch64/input_device.policy
@@ -5,6 +5,6 @@
@include /usr/share/policy/crosvm/common_device.policy
ioctl: 1
-fcntl: 1
getsockname: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/net_device.policy b/seccomp/aarch64/net_device.policy
index a1c2eeff4..b77bdfb82 100644
--- a/seccomp/aarch64/net_device.policy
+++ b/seccomp/aarch64/net_device.policy
@@ -7,3 +7,5 @@
# TUNSETOFFLOAD
ioctl: arg1 == 0x400454d0
openat: return ENOENT
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/null_audio_device.policy b/seccomp/aarch64/null_audio_device.policy
index 7a88fe297..b55aa1e94 100644
--- a/seccomp/aarch64/null_audio_device.policy
+++ b/seccomp/aarch64/null_audio_device.policy
@@ -9,3 +9,4 @@ prlimit64: 1
setrlimit: 1
clock_gettime: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/pmem_device.policy b/seccomp/aarch64/pmem_device.policy
index 77719a997..cbdf83a23 100644
--- a/seccomp/aarch64/pmem_device.policy
+++ b/seccomp/aarch64/pmem_device.policy
@@ -7,3 +7,4 @@
fdatasync: 1
fsync: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/rng_device.policy b/seccomp/aarch64/rng_device.policy
index fa86280a4..57e21cce4 100644
--- a/seccomp/aarch64/rng_device.policy
+++ b/seccomp/aarch64/rng_device.policy
@@ -5,3 +5,4 @@
@include /usr/share/policy/crosvm/common_device.policy
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/serial.policy b/seccomp/aarch64/serial.policy
index 8d23c0f90..3a76beee4 100644
--- a/seccomp/aarch64/serial.policy
+++ b/seccomp/aarch64/serial.policy
@@ -7,3 +7,4 @@
connect: 1
bind: 1
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/tpm_device.policy b/seccomp/aarch64/tpm_device.policy
index a39d61c6a..98d32b6b1 100644
--- a/seccomp/aarch64/tpm_device.policy
+++ b/seccomp/aarch64/tpm_device.policy
@@ -25,6 +25,7 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
ppoll: 1
prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/vhost_net_device.policy b/seccomp/aarch64/vhost_net_device.policy
index 4de1967b3..44247e097 100644
--- a/seccomp/aarch64/vhost_net_device.policy
+++ b/seccomp/aarch64/vhost_net_device.policy
@@ -22,3 +22,5 @@
# arg1 == VHOST_NET_SET_BACKEND
ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30
openat: return ENOENT
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/vhost_vsock_device.policy b/seccomp/aarch64/vhost_vsock_device.policy
index 82b66502d..a2774e16b 100644
--- a/seccomp/aarch64/vhost_vsock_device.policy
+++ b/seccomp/aarch64/vhost_vsock_device.policy
@@ -23,3 +23,5 @@
# arg1 == VHOST_VSOCK_SET_RUNNING
ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af60 || arg1 == 0x4004af61
openat: return ENOENT
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/vios_audio_device.policy b/seccomp/aarch64/vios_audio_device.policy
index df54139f5..d425ab279 100644
--- a/seccomp/aarch64/vios_audio_device.policy
+++ b/seccomp/aarch64/vios_audio_device.policy
@@ -5,9 +5,9 @@
@include /usr/share/policy/crosvm/common_device.policy
clock_gettime: 1
-clock_nanosleep: 1
lseek: 1
openat: return ENOENT
prlimit64: 1
sched_setscheduler: 1
setrlimit: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/wl_device.policy b/seccomp/aarch64/wl_device.policy
index 864aefb88..cd6804637 100644
--- a/seccomp/aarch64/wl_device.policy
+++ b/seccomp/aarch64/wl_device.policy
@@ -16,5 +16,5 @@ memfd_create: arg1 == 3
ftruncate: 1
# Used to determine shm size after recvmsg with fd
lseek: 1
-# Allow F_GETFL only
-fcntl: arg1 == 3
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/aarch64/xhci.policy b/seccomp/aarch64/xhci.policy
index d69514090..684ae0d2f 100644
--- a/seccomp/aarch64/xhci.policy
+++ b/seccomp/aarch64/xhci.policy
@@ -16,7 +16,6 @@ getsockname: 1
openat: 1
setsockopt: 1
bind: 1
-fcntl: 1
socket: arg0 == AF_NETLINK
uname: 1
# The following ioctls are:
@@ -37,3 +36,4 @@ ioctl: arg1 == 0xc0105500 || arg1 == 0x802c550a || arg1 == 0x8004551a || arg1 ==
fstat: 1
getrandom: 1
lseek: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/9p_device.policy b/seccomp/arm/9p_device.policy
index 95d0b320d..a7b877b27 100644
--- a/seccomp/arm/9p_device.policy
+++ b/seccomp/arm/9p_device.policy
@@ -4,7 +4,6 @@
@include /usr/share/policy/crosvm/common_device.policy
-fcntl64: 1
pread64: 1
pwrite64: 1
stat64: 1
@@ -24,7 +23,10 @@ linkat: 1
unlinkat: 1
socket: arg0 == AF_UNIX
utimensat: 1
+utimensat_time64: 1
ftruncate64: 1
-fchown: arg1 == 0xffffffff && arg2 == 0xffffffff
+fchmod: 1
+fchown: 1
fstatfs64: 1
fstatat64: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/balloon_device.policy b/seccomp/arm/balloon_device.policy
index 868ae3110..e0e444270 100644
--- a/seccomp/arm/balloon_device.policy
+++ b/seccomp/arm/balloon_device.policy
@@ -4,6 +4,6 @@
@include /usr/share/policy/crosvm/common_device.policy
-fcntl64: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/battery.policy b/seccomp/arm/battery.policy
index a4fb9fcb6..f26af9caa 100644
--- a/seccomp/arm/battery.policy
+++ b/seccomp/arm/battery.policy
@@ -3,3 +3,4 @@
# found in the LICENSE file.
@include /usr/share/policy/crosvm/common_device.policy
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/block_device.policy b/seccomp/arm/block_device.policy
index 785af4582..75b769f43 100644
--- a/seccomp/arm/block_device.policy
+++ b/seccomp/arm/block_device.policy
@@ -5,7 +5,6 @@
@include /usr/share/policy/crosvm/common_device.policy
fallocate: 1
-fcntl64: 1
fdatasync: 1
fstat64: 1
fsync: 1
@@ -20,4 +19,7 @@ pwritev: 1
statx: 1
timerfd_create: 1
timerfd_gettime: 1
+timerfd_gettime64: 1
timerfd_settime: 1
+timerfd_settime64: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/common_device.policy b/seccomp/arm/common_device.policy
index cbbfd7d43..165bfda6e 100644
--- a/seccomp/arm/common_device.policy
+++ b/seccomp/arm/common_device.policy
@@ -3,6 +3,8 @@
# found in the LICENSE file.
brk: 1
+clock_gettime: 1
+clock_gettime64: 1
clone: arg0 & CLONE_THREAD
close: 1
dup2: 1
@@ -14,6 +16,7 @@ eventfd2: 1
exit: 1
exit_group: 1
futex: 1
+futex_time64: 1
getpid: 1
gettid: 1
gettimeofday: 1
@@ -26,15 +29,18 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
+clock_nanosleep_time64: 1
pipe2: 1
poll: 1
ppoll: 1
-prctl: arg0 == PR_SET_NAME
+ppoll_time64: 1
read: 1
readv: 1
recv: 1
recvfrom: 1
recvmsg: 1
+recvmmsg_time64: 1
restart_syscall: 1
rt_sigaction: 1
rt_sigprocmask: 1
@@ -46,3 +52,4 @@ set_robust_list: 1
sigaltstack: 1
write: 1
writev: 1
+fcntl64: 1
diff --git a/seccomp/arm/cras_audio_device.policy b/seccomp/arm/cras_audio_device.policy
index 505208b15..20bf60e1f 100644
--- a/seccomp/arm/cras_audio_device.policy
+++ b/seccomp/arm/cras_audio_device.policy
@@ -11,4 +11,4 @@ prlimit64: 1
setrlimit: 1
sched_setscheduler: 1
socketpair: arg0 == AF_UNIX
-clock_gettime: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/fs_device.policy b/seccomp/arm/fs_device.policy
index 02dff2916..e84fd08fc 100644
--- a/seccomp/arm/fs_device.policy
+++ b/seccomp/arm/fs_device.policy
@@ -50,4 +50,8 @@ statx: 1
symlinkat: 1
umask: 1
unlinkat: 1
-utimensat: 1
\ No newline at end of file
+utimensat: 1
+utimensat_time64: 1
+prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS
+capget: 1
+capset: 1
diff --git a/seccomp/arm/gpu_device.policy b/seccomp/arm/gpu_device.policy
index 1bdea6d0d..ec5a5b481 100644
--- a/seccomp/arm/gpu_device.policy
+++ b/seccomp/arm/gpu_device.policy
@@ -16,6 +16,7 @@ eventfd2: 1
exit: 1
exit_group: 1
futex: 1
+futex_time64: 1
getpid: 1
gettimeofday: 1
kill: 1
@@ -23,15 +24,19 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == MADV_REMOVE
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
+clock_nanosleep_time64: 1
pipe2: 1
poll: 1
ppoll: 1
+ppoll_time64: 1
prctl: arg0 == PR_SET_NAME || arg0 == PR_GET_NAME
read: 1
readv: 1
recv: 1
recvfrom: 1
recvmsg: 1
+recvmmsg_time64: 1
restart_syscall: 1
rt_sigaction: 1
rt_sigprocmask: 1
@@ -60,8 +65,8 @@ getdents: 1
getdents64: 1
sysinfo: 1
-# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali)
-ioctl: arg1 & 0x6400 || arg1 & 0x8000
+# 0x6400 == DRM_IOCTL_BASE, 0x8000 = KBASE_IOCTL_TYPE (mali), 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x40087543 == UDMABUF_CREATE_LIST
+ioctl: arg1 & 0x6400 || arg1 & 0x8000 || arg1 == 0x40086200 || arg1 == 0x40087543
# Used for sharing memory with wayland. arg1 == MFD_CLOEXEC|MFD_ALLOW_SEALING
memfd_create: arg1 == 3
@@ -81,6 +86,7 @@ gettid: 1
fcntl64: 1
tgkill: 1
clock_gettime: 1
+clock_gettime64: 1
# Rules specific to Mesa.
uname: 1
diff --git a/seccomp/arm/input_device.policy b/seccomp/arm/input_device.policy
index d32c31222..bb6985315 100644
--- a/seccomp/arm/input_device.policy
+++ b/seccomp/arm/input_device.policy
@@ -5,7 +5,7 @@
@include /usr/share/policy/crosvm/common_device.policy
ioctl: 1
-fcntl: 1
getsockname: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/net_device.policy b/seccomp/arm/net_device.policy
index cf0584ce6..4cd4815ca 100644
--- a/seccomp/arm/net_device.policy
+++ b/seccomp/arm/net_device.policy
@@ -8,3 +8,4 @@
ioctl: arg1 == 0x400454d0
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/null_audio_device.policy b/seccomp/arm/null_audio_device.policy
index f89397b9e..c87441c38 100644
--- a/seccomp/arm/null_audio_device.policy
+++ b/seccomp/arm/null_audio_device.policy
@@ -9,4 +9,4 @@ open: return ENOENT
openat: return ENOENT
prlimit64: 1
setrlimit: 1
-clock_gettime: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/pmem_device.policy b/seccomp/arm/pmem_device.policy
index 12a3b04f8..e7321f747 100644
--- a/seccomp/arm/pmem_device.policy
+++ b/seccomp/arm/pmem_device.policy
@@ -8,3 +8,4 @@ fdatasync: 1
fsync: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/rng_device.policy b/seccomp/arm/rng_device.policy
index 0c7d2583a..e0e444270 100644
--- a/seccomp/arm/rng_device.policy
+++ b/seccomp/arm/rng_device.policy
@@ -6,3 +6,4 @@
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/serial.policy b/seccomp/arm/serial.policy
index f0456e931..1d8140d44 100644
--- a/seccomp/arm/serial.policy
+++ b/seccomp/arm/serial.policy
@@ -8,3 +8,4 @@ connect: 1
bind: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/tpm_device.policy b/seccomp/arm/tpm_device.policy
index d17f67cd1..5653f7794 100644
--- a/seccomp/arm/tpm_device.policy
+++ b/seccomp/arm/tpm_device.policy
@@ -4,6 +4,8 @@
# common policy
brk: 1
+clock_gettime: 1
+clock_gettime64: 1
clone: arg0 & CLONE_THREAD
close: 1
dup2: 1
@@ -15,6 +17,7 @@ eventfd2: 1
exit: 1
exit_group: 1
futex: 1
+futex_time64: 1
getpid: 1
getrandom: 1
gettimeofday: 1
@@ -25,14 +28,18 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
+clock_nanosleep_time64: 1
pipe2: 1
poll: 1
ppoll: 1
+ppoll_time64: 1
prctl: arg0 == PR_SET_NAME
read: 1
recv: 1
recvfrom: 1
recvmsg: 1
+recvmmsg_time64: 1
restart_syscall: 1
rt_sigaction: 1
rt_sigprocmask: 1
diff --git a/seccomp/arm/vhost_net_device.policy b/seccomp/arm/vhost_net_device.policy
index 4571a938e..fbee418b2 100644
--- a/seccomp/arm/vhost_net_device.policy
+++ b/seccomp/arm/vhost_net_device.policy
@@ -23,3 +23,4 @@
ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/vhost_vsock_device.policy b/seccomp/arm/vhost_vsock_device.policy
index c6a984c51..a793afc9e 100644
--- a/seccomp/arm/vhost_vsock_device.policy
+++ b/seccomp/arm/vhost_vsock_device.policy
@@ -24,3 +24,4 @@
ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af60 || arg1 == 0x4004af61
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/video_device.policy b/seccomp/arm/video_device.policy
index 784cc7cd4..5c5a4a5a1 100644
--- a/seccomp/arm/video_device.policy
+++ b/seccomp/arm/video_device.policy
@@ -6,12 +6,12 @@
# Syscalls specific to video devices.
clock_getres: 1
-clock_gettime: 1
+clock_getres_time64: 1
connect: 1
-fcntl64: arg1 == F_GETFL || arg1 == F_SETFL || arg1 == F_DUPFD_CLOEXEC || arg1 == F_GETFD || arg1 == F_SETFD
getegid32: 1
geteuid32: 1
getgid32: 1
+getrandom: 1
getresgid32: 1
getresuid32: 1
getsockname: 1
@@ -24,3 +24,4 @@ send: 1
setpriority: 1
socket: arg0 == AF_UNIX
stat64: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/vios_audio_device.policy b/seccomp/arm/vios_audio_device.policy
index ad27b0e36..3a1fb0811 100644
--- a/seccomp/arm/vios_audio_device.policy
+++ b/seccomp/arm/vios_audio_device.policy
@@ -4,11 +4,10 @@
@include /usr/share/policy/crosvm/common_device.policy
-clock_gettime: 1
-clock_nanosleep: 1
lseek: 1
open: return ENOENT
openat: return ENOENT
prlimit64: 1
sched_setscheduler: 1
setrlimit: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/wl_device.policy b/seccomp/arm/wl_device.policy
index 0b84c4b42..0a3de4f7f 100644
--- a/seccomp/arm/wl_device.policy
+++ b/seccomp/arm/wl_device.policy
@@ -17,5 +17,5 @@ memfd_create: arg1 == 3
ftruncate64: 1
# Used to determine shm size after recvmsg with fd
_llseek: 1
-# Allow F_GETFL only
-fcntl64: arg1 == 3
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/arm/xhci.policy b/seccomp/arm/xhci.policy
index 6c51ddf8d..ca1a73dfc 100644
--- a/seccomp/arm/xhci.policy
+++ b/seccomp/arm/xhci.policy
@@ -5,20 +5,17 @@
@include /usr/share/policy/crosvm/common_device.policy
stat64: 1
-fcntl64: 1
lstat64: 1
readlink: 1
readlinkat: 1
getdents64: 1
name_to_handle_at: 1
access: 1
-clock_gettime: 1
timerfd_create: 1
getsockname: 1
pipe: 1
setsockopt: 1
bind: 1
-fcntl: 1
socket: arg0 == AF_NETLINK
stat: 1
statx: 1
@@ -44,3 +41,4 @@ getdents: 1
_llseek: 1
open: return ENOENT
openat: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/9p_device.policy b/seccomp/x86_64/9p_device.policy
index 6f14c0af6..48eb0415d 100644
--- a/seccomp/x86_64/9p_device.policy
+++ b/seccomp/x86_64/9p_device.policy
@@ -7,7 +7,6 @@ openat: 1
@include /usr/share/policy/crosvm/common_device.policy
-fcntl: 1
pwrite64: 1
stat: 1
statx: 1
@@ -25,6 +24,8 @@ fsync: 1
fdatasync: 1
utimensat: 1
ftruncate: 1
-fchown: arg1 == 0xffffffff && arg2 == 0xffffffff
+fchmod: 1
+fchown: 1
fstatfs: 1
newfstatat: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/balloon_device.policy b/seccomp/x86_64/balloon_device.policy
index 49cf785ab..f717ad476 100644
--- a/seccomp/x86_64/balloon_device.policy
+++ b/seccomp/x86_64/balloon_device.policy
@@ -4,6 +4,6 @@
@include /usr/share/policy/crosvm/common_device.policy
-fcntl: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/battery.policy b/seccomp/x86_64/battery.policy
index f6bbe2889..ce6d41271 100644
--- a/seccomp/x86_64/battery.policy
+++ b/seccomp/x86_64/battery.policy
@@ -7,7 +7,6 @@
# Syscalls used by power_monitor's powerd implementation.
clock_getres: 1
connect: 1
-fcntl: 1
getcwd: 1
getegid: 1
geteuid: 1
@@ -19,3 +18,4 @@ openat: 1
readlink: 1
socket: arg0 == AF_UNIX
tgkill: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/block_device.policy b/seccomp/x86_64/block_device.policy
index f1130d90d..8f68c9be5 100644
--- a/seccomp/x86_64/block_device.policy
+++ b/seccomp/x86_64/block_device.policy
@@ -5,7 +5,6 @@
@include /usr/share/policy/crosvm/common_device.policy
fallocate: 1
-fcntl: 1
fdatasync: 1
fstat: 1
fsync: 1
@@ -21,3 +20,4 @@ statx: 1
timerfd_create: 1
timerfd_gettime: 1
timerfd_settime: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/common_device.policy b/seccomp/x86_64/common_device.policy
index bf8dd1581..49d452051 100644
--- a/seccomp/x86_64/common_device.policy
+++ b/seccomp/x86_64/common_device.policy
@@ -27,10 +27,10 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
poll: 1
ppoll: 1
-prctl: arg0 == PR_SET_NAME
read: 1
readv: 1
recvfrom: 1
@@ -46,3 +46,4 @@ set_robust_list: 1
sigaltstack: 1
write: 1
writev: 1
+fcntl: 1
diff --git a/seccomp/x86_64/cras_audio_device.policy b/seccomp/x86_64/cras_audio_device.policy
index 505208b15..bbaffb080 100644
--- a/seccomp/x86_64/cras_audio_device.policy
+++ b/seccomp/x86_64/cras_audio_device.policy
@@ -12,3 +12,4 @@ setrlimit: 1
sched_setscheduler: 1
socketpair: arg0 == AF_UNIX
clock_gettime: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/fs_device.policy b/seccomp/x86_64/fs_device.policy
index dea28aef1..bd03307a2 100644
--- a/seccomp/x86_64/fs_device.policy
+++ b/seccomp/x86_64/fs_device.policy
@@ -50,4 +50,7 @@ symlinkat: 1
statx: 1
umask: 1
unlinkat: 1
-utimensat: 1
\ No newline at end of file
+utimensat: 1
+prctl: arg0 == PR_SET_NAME || arg0 == PR_SET_SECUREBITS || arg0 == PR_GET_SECUREBITS
+capget: 1
+capset: 1
diff --git a/seccomp/x86_64/gpu_device.policy b/seccomp/x86_64/gpu_device.policy
index 7f167d728..2b4f4b2bb 100644
--- a/seccomp/x86_64/gpu_device.policy
+++ b/seccomp/x86_64/gpu_device.policy
@@ -25,6 +25,7 @@ madvise: arg2 == MADV_DONTNEED || arg2 == MADV_DONTDUMP || arg2 == MADV_REMOVE
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
poll: 1
ppoll: 1
@@ -53,10 +54,12 @@ fstat: 1
# Used to set of size new memfd.
ftruncate: 1
getdents: 1
+getdents64: 1
geteuid: 1
getrandom: 1
getuid: 1
-ioctl: arg1 == FIONBIO || arg1 == FIOCLEX || arg1 == 0x40086200 || arg1 & 0x6400
+# 0x40086200 = DMA_BUF_IOCTL_SYNC, 0x6400 == DRM_IOCTL_BASE, 0x40087543 == UDMABUF_CREATE_LIST
+ioctl: arg1 == FIONBIO || arg1 == FIOCLEX || arg1 == 0x40086200 || arg1 & 0x6400 || arg1 == 0x40087543
lseek: 1
lstat: 1
# Used for sharing memory with wayland. Also internally by Intel anv.
diff --git a/seccomp/x86_64/input_device.policy b/seccomp/x86_64/input_device.policy
index d32c31222..bb6985315 100644
--- a/seccomp/x86_64/input_device.policy
+++ b/seccomp/x86_64/input_device.policy
@@ -5,7 +5,7 @@
@include /usr/share/policy/crosvm/common_device.policy
ioctl: 1
-fcntl: 1
getsockname: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/net_device.policy b/seccomp/x86_64/net_device.policy
index 5d6535a94..b8c9d41ca 100644
--- a/seccomp/x86_64/net_device.policy
+++ b/seccomp/x86_64/net_device.policy
@@ -8,3 +8,4 @@
ioctl: arg1 == 0x400454d0
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/null_audio_device.policy b/seccomp/x86_64/null_audio_device.policy
index f118d88de..5c360c945 100644
--- a/seccomp/x86_64/null_audio_device.policy
+++ b/seccomp/x86_64/null_audio_device.policy
@@ -9,3 +9,5 @@ open: return ENOENT
openat: return ENOENT
prlimit64: 1
setrlimit: 1
+sched_setscheduler: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/pmem_device.policy b/seccomp/x86_64/pmem_device.policy
index 12a3b04f8..e7321f747 100644
--- a/seccomp/x86_64/pmem_device.policy
+++ b/seccomp/x86_64/pmem_device.policy
@@ -8,3 +8,4 @@ fdatasync: 1
fsync: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/rng_device.policy b/seccomp/x86_64/rng_device.policy
index c6681634f..f717ad476 100644
--- a/seccomp/x86_64/rng_device.policy
+++ b/seccomp/x86_64/rng_device.policy
@@ -6,3 +6,4 @@
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/serial.policy b/seccomp/x86_64/serial.policy
index f0456e931..1d8140d44 100644
--- a/seccomp/x86_64/serial.policy
+++ b/seccomp/x86_64/serial.policy
@@ -8,3 +8,4 @@ connect: 1
bind: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/tpm_device.policy b/seccomp/x86_64/tpm_device.policy
index 50536f8aa..bfd64a838 100644
--- a/seccomp/x86_64/tpm_device.policy
+++ b/seccomp/x86_64/tpm_device.policy
@@ -25,6 +25,7 @@ mprotect: arg2 in ~PROT_EXEC
mremap: 1
munmap: 1
nanosleep: 1
+clock_nanosleep: 1
pipe2: 1
poll: 1
ppoll: 1
diff --git a/seccomp/x86_64/vfio_device.policy b/seccomp/x86_64/vfio_device.policy
index aa28d1ad4..bf7a00d12 100644
--- a/seccomp/x86_64/vfio_device.policy
+++ b/seccomp/x86_64/vfio_device.policy
@@ -10,3 +10,4 @@ openat: return ENOENT
readlink: 1
pread64: 1
pwrite64: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/vhost_net_device.policy b/seccomp/x86_64/vhost_net_device.policy
index c9182e6ec..55a3d188d 100644
--- a/seccomp/x86_64/vhost_net_device.policy
+++ b/seccomp/x86_64/vhost_net_device.policy
@@ -23,3 +23,4 @@
ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 == 0x0000af02 || arg1 == 0x4008af03 || arg1 == 0x4008af04 || arg1 == 0x4004af07 || arg1 == 0x4008af10 || arg1 == 0x4028af11 || arg1 == 0x4008af12 || arg1 == 0xc008af12 || arg1 == 0x4008af20 || arg1 == 0x4008af21 || arg1 == 0x4008af22 || arg1 == 0x4008af30
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/vhost_vsock_device.policy b/seccomp/x86_64/vhost_vsock_device.policy
index 69fca47fb..e558c4e79 100644
--- a/seccomp/x86_64/vhost_vsock_device.policy
+++ b/seccomp/x86_64/vhost_vsock_device.policy
@@ -25,3 +25,4 @@ ioctl: arg1 == 0x8008af00 || arg1 == 0x4008af00 || arg1 == 0x0000af01 || arg1 ==
connect: 1
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/video_device.policy b/seccomp/x86_64/video_device.policy
index e43900ae1..4c54d9d17 100644
--- a/seccomp/x86_64/video_device.policy
+++ b/seccomp/x86_64/video_device.policy
@@ -7,8 +7,8 @@
# Syscalls specific to video devices.
clock_getres: 1
connect: 1
-fcntl: arg1 == F_GETFL || arg1 == F_SETFL || arg1 == F_DUPFD_CLOEXEC || arg1 == F_GETFD || arg1 == F_SETFD
getdents: 1
+getdents64: 1
getegid: 1
geteuid: 1
getgid: 1
@@ -38,3 +38,5 @@ uname: 1
# Required by mesa on AMD GPU
sysinfo: 1
+
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/vios_audio_device.policy b/seccomp/x86_64/vios_audio_device.policy
index ad27b0e36..a3b7f1961 100644
--- a/seccomp/x86_64/vios_audio_device.policy
+++ b/seccomp/x86_64/vios_audio_device.policy
@@ -5,10 +5,10 @@
@include /usr/share/policy/crosvm/common_device.policy
clock_gettime: 1
-clock_nanosleep: 1
lseek: 1
open: return ENOENT
openat: return ENOENT
prlimit64: 1
sched_setscheduler: 1
setrlimit: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/wl_device.policy b/seccomp/x86_64/wl_device.policy
index f79b08aab..f2cda7f24 100644
--- a/seccomp/x86_64/wl_device.policy
+++ b/seccomp/x86_64/wl_device.policy
@@ -15,7 +15,6 @@ memfd_create: arg1 == 3
ftruncate: 1
# Used to determine shm size after recvmsg with fd
lseek: 1
-# Allow F_GETFL only
-fcntl: arg1 == 3
open: return ENOENT
openat: return ENOENT
+prctl: arg0 == PR_SET_NAME
diff --git a/seccomp/x86_64/xhci.policy b/seccomp/x86_64/xhci.policy
index a548d9ea0..9ef376698 100644
--- a/seccomp/x86_64/xhci.policy
+++ b/seccomp/x86_64/xhci.policy
@@ -14,7 +14,6 @@ getsockname: 1
pipe: 1
setsockopt: 1
bind: 1
-fcntl: 1
open: return ENOENT
openat: 1
socket: arg0 == AF_NETLINK
@@ -39,4 +38,6 @@ ioctl: arg1 == 0xc0185500 || arg1 == 0x41045508 || arg1 == 0x8004550f || arg1 ==
fstat: 1
getrandom: 1
getdents: 1
+getdents64: 1
lseek: 1
+prctl: arg0 == PR_SET_NAME
diff --git a/src/argument.rs b/src/argument.rs
index 525506df7..64be18227 100644
--- a/src/argument.rs
+++ b/src/argument.rs
@@ -364,7 +364,7 @@ where
/// Prints command line usage information to stdout.
///
/// Usage information is printed according to the help fields in `args` with a leading usage line.
-/// The usage line is of the format "`program_name` [ARGUMENTS] `required_arg`".
+/// The usage line is of the format "`program_name` \[ARGUMENTS\] `required_arg`".
pub fn print_help(program_name: &str, required_arg: &str, args: &[Argument]) {
println!(
"Usage: {} {}{}\n",
diff --git a/src/crosvm.rs b/src/crosvm.rs
index 1a44031bf..7a4318ab3 100644
--- a/src/crosvm.rs
+++ b/src/crosvm.rs
@@ -29,6 +29,9 @@ use devices::ProtectionType;
use libc::{getegid, geteuid};
use vm_control::BatteryType;
+static KVM_PATH: &str = "/dev/kvm";
+static VHOST_VSOCK_PATH: &str = "/dev/vhost-vsock";
+static VHOST_NET_PATH: &str = "/dev/vhost-net";
static SECCOMP_POLICY_DIR: &str = "/usr/share/policy/crosvm";
/// Indicates the location and kind of executable kernel for a VM.
@@ -55,6 +58,15 @@ pub struct DiskOption {
pub id: Option<[u8; DISK_ID_LEN]>,
}
+pub struct VhostUserOption {
+ pub socket: PathBuf,
+}
+
+pub struct VhostUserFsOption {
+ pub socket: PathBuf,
+ pub tag: String,
+}
+
/// A bind mount for directories in the plugin process.
pub struct BindMount {
pub src: PathBuf,
@@ -69,6 +81,13 @@ pub struct GidMap {
pub count: u32,
}
+/// Direct IO forwarding options
+#[cfg(feature = "direct")]
+pub struct DirectIoOption {
+ pub path: PathBuf,
+ pub ranges: Vec<(u64, u64)>,
+}
+
pub const DEFAULT_TOUCH_DEVICE_HEIGHT: u32 = 1024;
pub const DEFAULT_TOUCH_DEVICE_WIDTH: u32 = 1280;
@@ -174,11 +193,16 @@ impl Default for SharedDir {
/// Aggregate of all configurable options for a running VM.
pub struct Config {
+ pub kvm_device_path: PathBuf,
+ pub vhost_vsock_device_path: PathBuf,
+ pub vhost_net_device_path: PathBuf,
pub vcpu_count: Option<usize>,
pub rt_cpus: Vec<usize>,
pub vcpu_affinity: Option<VcpuAffinity>,
pub no_smt: bool,
pub memory: Option<u64>,
+ pub hugepages: bool,
+ pub memory_file: Option<PathBuf>,
pub executable_path: Option<Executable>,
pub android_fstab: Option<PathBuf>,
pub initrd_path: Option<PathBuf>,
@@ -230,16 +254,31 @@ pub struct Config {
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
pub gdb: Option<u32>,
pub balloon_bias: i64,
+ pub vhost_user_blk: Vec<VhostUserOption>,
+ pub vhost_user_fs: Vec<VhostUserFsOption>,
+ pub vhost_user_net: Vec<VhostUserOption>,
+ #[cfg(feature = "direct")]
+ pub direct_pmio: Option<DirectIoOption>,
+ #[cfg(feature = "direct")]
+ pub direct_level_irq: Vec<u32>,
+ #[cfg(feature = "direct")]
+ pub direct_edge_irq: Vec<u32>,
+ pub dmi_path: Option<PathBuf>,
}
impl Default for Config {
fn default() -> Config {
Config {
+ kvm_device_path: PathBuf::from(KVM_PATH),
+ vhost_vsock_device_path: PathBuf::from(VHOST_VSOCK_PATH),
+ vhost_net_device_path: PathBuf::from(VHOST_NET_PATH),
vcpu_count: None,
rt_cpus: Vec::new(),
vcpu_affinity: None,
no_smt: false,
memory: None,
+ hugepages: false,
+ memory_file: None,
executable_path: None,
android_fstab: None,
initrd_path: None,
@@ -291,6 +330,16 @@ impl Default for Config {
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
gdb: None,
balloon_bias: 0,
+ vhost_user_blk: Vec::new(),
+ vhost_user_fs: Vec::new(),
+ vhost_user_net: Vec::new(),
+ #[cfg(feature = "direct")]
+ direct_pmio: None,
+ #[cfg(feature = "direct")]
+ direct_level_irq: Vec::new(),
+ #[cfg(feature = "direct")]
+ direct_edge_irq: Vec::new(),
+ dmi_path: None,
}
}
}
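
All the fields this hunk adds to Config come with defaults, so construction sites that go through Default keep compiling and only callers that opt into the new behavior set the extra knobs. A minimal sketch of such a call site (the function is illustrative; the command-line flags that feed these fields are not shown in this hunk and are not assumed):

    fn example_config() -> Config {
        // Struct-update syntax: every field not listed here, including the new
        // hugepages, dmi_path and vhost_user_* knobs, falls back to Default.
        Config {
            vcpu_count: Some(2),
            hugepages: true,
            ..Default::default()
        }
    }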
diff --git a/src/gdb.rs b/src/gdb.rs
index 65626f594..6a18f5564 100644
--- a/src/gdb.rs
+++ b/src/gdb.rs
@@ -6,12 +6,11 @@ use std::net::TcpListener;
use std::sync::mpsc;
use std::time::Duration;
-use base::{error, info};
-use msg_socket::{MsgReceiver, MsgSender};
+use base::{error, info, Tube, TubeError};
+
use sync::Mutex;
use vm_control::{
- VcpuControl, VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmControlRequestSocket,
- VmRequest, VmResponse,
+ VcpuControl, VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmRequest, VmResponse,
};
use vm_memory::GuestAddress;
@@ -82,15 +81,15 @@ enum Error {
VcpuResponse(mpsc::RecvTimeoutError),
/// Failed to send a VM request.
#[error("failed to send a VM request: {0}")]
- VmRequest(msg_socket::MsgError),
+ VmRequest(TubeError),
/// Failed to receive a VM request.
#[error("failed to receive a VM response: {0}")]
- VmResponse(msg_socket::MsgError),
+ VmResponse(TubeError),
}
type GdbResult<T> = std::result::Result<T, Error>;
pub struct GdbStub {
- vm_socket: Mutex<VmControlRequestSocket>,
+ vm_tube: Mutex<Tube>,
vcpu_com: Vec<mpsc::Sender<VcpuControl>>,
from_vcpu: mpsc::Receiver<VcpuDebugStatusMessage>,
@@ -99,12 +98,12 @@ pub struct GdbStub {
impl GdbStub {
pub fn new(
- vm_socket: VmControlRequestSocket,
+ vm_tube: Tube,
vcpu_com: Vec<mpsc::Sender<VcpuControl>>,
from_vcpu: mpsc::Receiver<VcpuDebugStatusMessage>,
) -> Self {
GdbStub {
- vm_socket: Mutex::new(vm_socket),
+ vm_tube: Mutex::new(vm_tube),
vcpu_com,
from_vcpu,
hw_breakpoints: Default::default(),
@@ -122,9 +121,9 @@ impl GdbStub {
}
fn vm_request(&self, request: VmRequest) -> GdbResult<()> {
- let vm_socket = self.vm_socket.lock();
- vm_socket.send(&request).map_err(Error::VmRequest)?;
- match vm_socket.recv() {
+ let vm_tube = self.vm_tube.lock();
+ vm_tube.send(&request).map_err(Error::VmRequest)?;
+ match vm_tube.recv() {
Ok(VmResponse::Ok) => Ok(()),
Ok(r) => Err(Error::UnexpectedVmResponse(r)),
Err(e) => Err(Error::VmResponse(e)),
diff --git a/src/linux.rs b/src/linux.rs
index b89ce6df5..5efa2bb1f 100644
--- a/src/linux.rs
+++ b/src/linux.rs
@@ -32,7 +32,11 @@ use libc::{self, c_int, gid_t, uid_t};
use acpi_tables::sdt::SDT;
-use base::net::{UnixSeqpacket, UnixSeqpacketListener, UnlinkUnixSeqpacketListener};
+use base::net::{UnixSeqpacketListener, UnlinkUnixSeqpacketListener};
+use base::*;
+use devices::virtio::vhost::user::{
+ Block as VhostUserBlock, Error as VhostUserError, Fs as VhostUserFs, Net as VhostUserNet,
+};
#[cfg(feature = "gpu")]
use devices::virtio::EventDevice;
use devices::virtio::{self, Console, VirtioDevice};
@@ -45,38 +49,20 @@ use devices::{
use hypervisor::kvm::{Kvm, KvmVcpu, KvmVm};
use hypervisor::{HypervisorCap, Vcpu, VcpuExit, VcpuRunHandle, Vm, VmCap};
use minijail::{self, Minijail};
-use msg_socket::{MsgError, MsgReceiver, MsgSender, MsgSocket};
use net_util::{Error as NetError, MacAddress, Tap};
use remain::sorted;
use resources::{Alloc, MmioType, SystemAllocator};
use rutabaga_gfx::RutabagaGralloc;
use sync::Mutex;
-
-use base::{
- self, block_signal, clear_signal, drop_capabilities, error, flock, get_blocked_signals,
- get_group_id, get_user_id, getegid, geteuid, info, register_rt_signal_handler,
- set_cpu_affinity, set_rt_prio_limit, set_rt_round_robin, signal, validate_raw_descriptor, warn,
- AsRawDescriptor, Event, EventType, ExternalMapping, FlockOperation, FromRawDescriptor,
- Killable, MemoryMappingArena, PollToken, Protection, RawDescriptor, ScopedEvent, SignalFd,
- Terminal, Timer, WaitContext, SIGRTMIN,
-};
-use vm_control::{
- BalloonControlCommand, BalloonControlRequestSocket, BalloonControlResponseSocket,
- BalloonControlResult, BalloonStats, DiskControlCommand, DiskControlRequestSocket,
- DiskControlResponseSocket, DiskControlResult, FsMappingRequest, FsMappingRequestSocket,
- FsMappingResponseSocket, IrqSetup, UsbControlSocket, VcpuControl, VmControlResponseSocket,
- VmIrqRequest, VmIrqRequestSocket, VmIrqResponse, VmIrqResponseSocket,
- VmMemoryControlRequestSocket, VmMemoryControlResponseSocket, VmMemoryRequest, VmMemoryResponse,
- VmMsyncRequest, VmMsyncRequestSocket, VmMsyncResponse, VmMsyncResponseSocket, VmResponse,
- VmRunMode,
-};
-#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
-use vm_control::{VcpuDebug, VcpuDebugStatus, VcpuDebugStatusMessage, VmRequest};
-use vm_memory::{GuestAddress, GuestMemory};
+use vm_control::*;
+use vm_memory::{GuestAddress, GuestMemory, MemoryPolicy};
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
use crate::gdb::{gdb_thread, GdbStub};
-use crate::{Config, DiskOption, Executable, SharedDir, SharedDirKind, TouchDeviceOption};
+use crate::{
+ Config, DiskOption, Executable, SharedDir, SharedDirKind, TouchDeviceOption, VhostUserFsOption,
+ VhostUserOption,
+};
use arch::{
self, LinuxArch, RunnableLinuxVm, SerialHardware, SerialParameters, VcpuAffinity,
VirtioDeviceStub, VmComponents, VmImage,
@@ -115,20 +101,28 @@ pub enum Error {
#[cfg(feature = "audio")]
CreateAc97(devices::PciDeviceError),
CreateConsole(arch::serial::Error),
+ CreateControlServer(io::Error),
CreateDiskError(disk::Error),
CreateEvent(base::Error),
CreateGrallocError(rutabaga_gfx::RutabagaError),
+ CreateKvm(base::Error),
CreateSignalFd(base::SignalFdError),
CreateSocket(io::Error),
CreateTapDevice(NetError),
CreateTimer(base::Error),
CreateTpmStorage(PathBuf, io::Error),
+ CreateTube(TubeError),
CreateUsbProvider(devices::usb::host_backend::error::Error),
CreateVcpu(base::Error),
CreateVfioDevice(devices::vfio::VfioError),
+ CreateVm(base::Error),
CreateWaitContext(base::Error),
DeviceJail(minijail::Error),
DevicePivotRoot(minijail::Error),
+ #[cfg(feature = "direct")]
+ DirectIo(io::Error),
+ #[cfg(feature = "direct")]
+ DirectIrq(devices::DirectIrqError),
Disk(PathBuf, io::Error),
DiskImageLock(base::Error),
DropCapabilities(base::Error),
@@ -139,6 +133,7 @@ pub enum Error {
GuestCachedTooLarge(std::num::TryFromIntError),
GuestFreeMissing(),
GuestFreeTooLarge(std::num::TryFromIntError),
+ GuestMemoryLayout(<Arch as LinuxArch>::Error),
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
HandleDebugCommand(<Arch as LinuxArch>::Error),
InputDeviceNew(virtio::InputError),
@@ -189,6 +184,10 @@ pub enum Error {
Timer(base::Error),
ValidateRawDescriptor(base::Error),
VhostNetDeviceNew(virtio::vhost::Error),
+ VhostUserBlockDeviceNew(VhostUserError),
+ VhostUserFsDeviceNew(VhostUserError),
+ VhostUserNetDeviceNew(VhostUserError),
+ VhostUserNetWithNetArgs,
VhostVsockDeviceNew(virtio::vhost::Error),
VirtioPciDev(base::Error),
WaitContextAdd(base::Error),
@@ -222,9 +221,11 @@ impl Display for Error {
#[cfg(feature = "audio")]
CreateAc97(e) => write!(f, "failed to create ac97 device: {}", e),
CreateConsole(e) => write!(f, "failed to create console device: {}", e),
+ CreateControlServer(e) => write!(f, "failed to create control server: {}", e),
CreateDiskError(e) => write!(f, "failed to create virtual disk: {}", e),
CreateEvent(e) => write!(f, "failed to create event: {}", e),
CreateGrallocError(e) => write!(f, "failed to create gralloc: {}", e),
+ CreateKvm(e) => write!(f, "failed to create kvm: {}", e),
CreateSignalFd(e) => write!(f, "failed to create signalfd: {}", e),
CreateSocket(e) => write!(f, "failed to create socket: {}", e),
CreateTapDevice(e) => write!(f, "failed to create tap device: {}", e),
@@ -232,12 +233,18 @@ impl Display for Error {
CreateTpmStorage(p, e) => {
write!(f, "failed to create tpm storage dir {}: {}", p.display(), e)
}
+ CreateTube(e) => write!(f, "failed to create tube: {}", e),
CreateUsbProvider(e) => write!(f, "failed to create usb provider: {}", e),
CreateVcpu(e) => write!(f, "failed to create vcpu: {}", e),
CreateVfioDevice(e) => write!(f, "Failed to create vfio device {}", e),
+ CreateVm(e) => write!(f, "failed to create vm: {}", e),
CreateWaitContext(e) => write!(f, "failed to create wait context: {}", e),
DeviceJail(e) => write!(f, "failed to jail device: {}", e),
DevicePivotRoot(e) => write!(f, "failed to pivot root device: {}", e),
+ #[cfg(feature = "direct")]
+ DirectIo(e) => write!(f, "failed to open direct io device: {}", e),
+ #[cfg(feature = "direct")]
+ DirectIrq(e) => write!(f, "failed to enable interrupt forwarding: {}", e),
Disk(p, e) => write!(f, "failed to load disk image {}: {}", p.display(), e),
DiskImageLock(e) => write!(f, "failed to lock disk image: {}", e),
DropCapabilities(e) => write!(f, "failed to drop process capabilities: {}", e),
@@ -248,6 +255,7 @@ impl Display for Error {
GuestCachedTooLarge(e) => write!(f, "guest cached is too large: {}", e),
GuestFreeMissing() => write!(f, "guest free is missing from balloon stats"),
GuestFreeTooLarge(e) => write!(f, "guest free is too large: {}", e),
+ GuestMemoryLayout(e) => write!(f, "failed to create guest memory layout: {}", e),
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
HandleDebugCommand(e) => write!(f, "failed to handle a gdb command: {}", e),
InputDeviceNew(e) => write!(f, "failed to set up input device: {}", e),
@@ -309,6 +317,15 @@ impl Display for Error {
Timer(e) => write!(f, "failed to read timer fd: {}", e),
ValidateRawDescriptor(e) => write!(f, "failed to validate raw descriptor: {}", e),
VhostNetDeviceNew(e) => write!(f, "failed to set up vhost networking: {}", e),
+ VhostUserBlockDeviceNew(e) => {
+ write!(f, "failed to set up vhost-user block device: {}", e)
+ }
+ VhostUserFsDeviceNew(e) => write!(f, "failed to set up vhost-user fs device: {}", e),
+ VhostUserNetDeviceNew(e) => write!(f, "failed to set up vhost-user net device: {}", e),
+ VhostUserNetWithNetArgs => write!(
+ f,
+ "vhost-user-net cannot be used with any of --host_ip, --netmask or --mac"
+ ),
VhostVsockDeviceNew(e) => write!(f, "failed to set up virtual socket device: {}", e),
VirtioPciDev(e) => write!(f, "failed to create virtio pci dev: {}", e),
WaitContextAdd(e) => write!(f, "failed to add descriptor to wait context: {}", e),
@@ -330,28 +347,24 @@ impl std::error::Error for Error {}
type Result<T> = std::result::Result<T, Error>;
-enum TaggedControlSocket {
- Fs(FsMappingResponseSocket),
- Vm(VmControlResponseSocket),
- VmMemory(VmMemoryControlResponseSocket),
- VmIrq(VmIrqResponseSocket),
- VmMsync(VmMsyncResponseSocket),
+enum TaggedControlTube {
+ Fs(Tube),
+ Vm(Tube),
+ VmMemory(Tube),
+ VmIrq(Tube),
+ VmMsync(Tube),
}
-impl AsRef<UnixSeqpacket> for TaggedControlSocket {
- fn as_ref(&self) -> &UnixSeqpacket {
- use self::TaggedControlSocket::*;
+impl AsRef<Tube> for TaggedControlTube {
+ fn as_ref(&self) -> &Tube {
+ use self::TaggedControlTube::*;
match &self {
- Fs(ref socket) => socket.as_ref(),
- Vm(ref socket) => socket.as_ref(),
- VmMemory(ref socket) => socket.as_ref(),
- VmIrq(ref socket) => socket.as_ref(),
- VmMsync(ref socket) => socket.as_ref(),
+ Fs(tube) | Vm(tube) | VmMemory(tube) | VmIrq(tube) | VmMsync(tube) => tube,
}
}
}
-impl AsRawDescriptor for TaggedControlSocket {
+impl AsRawDescriptor for TaggedControlTube {
fn as_raw_descriptor(&self) -> RawDescriptor {
self.as_ref().as_raw_descriptor()
}
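
Because every TaggedControlTube variant now wraps the same Tube type, the five per-variant match arms collapse into one or-pattern binding a single name. A small self-contained sketch of the idiom with illustrative types:

    struct Tube; // stand-in for base::Tube

    enum Tagged {
        Fs(Tube),
        Vm(Tube),
        VmMemory(Tube),
    }

    fn inner(t: &Tagged) -> &Tube {
        // One arm covers every variant because each alternative binds `tube`
        // to a value of the same type.
        match t {
            Tagged::Fs(tube) | Tagged::Vm(tube) | Tagged::VmMemory(tube) => tube,
        }
    }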
@@ -477,11 +490,7 @@ fn simple_jail(cfg: &Config, policy: &str) -> Result<Option<Minijail>> {
type DeviceResult<T = VirtioDeviceStub> = std::result::Result<T, Error>;
-fn create_block_device(
- cfg: &Config,
- disk: &DiskOption,
- disk_device_socket: DiskControlResponseSocket,
-) -> DeviceResult {
+fn create_block_device(cfg: &Config, disk: &DiskOption, disk_device_tube: Tube) -> DeviceResult {
// Special case '/proc/self/fd/*' paths. The FD is already open, just use it.
let raw_image: File = if disk.path.parent() == Some(Path::new("/proc/self/fd")) {
// Safe because we will validate |raw_fd|.
@@ -510,7 +519,8 @@ fn create_block_device(
disk.read_only,
disk.sparse,
disk.block_size,
- Some(disk_device_socket),
+ disk.id,
+ Some(disk_device_tube),
)
.map_err(Error::BlockDeviceNew)?,
) as Box<dyn VirtioDevice>
@@ -524,7 +534,7 @@ fn create_block_device(
disk.sparse,
disk.block_size,
disk.id,
- Some(disk_device_socket),
+ Some(disk_device_tube),
)
.map_err(Error::BlockDeviceNew)?,
) as Box<dyn VirtioDevice>
@@ -536,6 +546,32 @@ fn create_block_device(
})
}
+fn create_vhost_user_block_device(cfg: &Config, opt: &VhostUserOption) -> DeviceResult {
+ let dev = VhostUserBlock::new(virtio::base_features(cfg.protected_vm), &opt.socket)
+ .map_err(Error::VhostUserBlockDeviceNew)?;
+
+ Ok(VirtioDeviceStub {
+ dev: Box::new(dev),
+ // no sandbox here because virtqueue handling is exported to a different process.
+ jail: None,
+ })
+}
+
+fn create_vhost_user_fs_device(cfg: &Config, option: &VhostUserFsOption) -> DeviceResult {
+ let dev = VhostUserFs::new(
+ virtio::base_features(cfg.protected_vm),
+ &option.socket,
+ &option.tag,
+ )
+ .map_err(Error::VhostUserFsDeviceNew)?;
+
+ Ok(VirtioDeviceStub {
+ dev: Box::new(dev),
+ // no sandbox here because virtqueue handling is exported to a different process.
+ jail: None,
+ })
+}
+
fn create_rng_device(cfg: &Config) -> DeviceResult {
let dev =
virtio::Rng::new(virtio::base_features(cfg.protected_vm)).map_err(Error::RngDeviceNew)?;
@@ -548,7 +584,6 @@ fn create_rng_device(cfg: &Config) -> DeviceResult {
#[cfg(feature = "tpm")]
fn create_tpm_device(cfg: &Config) -> DeviceResult {
- use base::chown;
use std::ffi::CString;
use std::fs;
use std::process;
@@ -724,8 +759,8 @@ fn create_vinput_device(cfg: &Config, dev_path: &Path) -> DeviceResult {
})
}
-fn create_balloon_device(cfg: &Config, socket: BalloonControlResponseSocket) -> DeviceResult {
- let dev = virtio::Balloon::new(virtio::base_features(cfg.protected_vm), socket)
+fn create_balloon_device(cfg: &Config, tube: Tube) -> DeviceResult {
+ let dev = virtio::Balloon::new(virtio::base_features(cfg.protected_vm), tube)
.map_err(Error::BalloonDeviceNew)?;
Ok(VirtioDeviceStub {
@@ -775,6 +810,7 @@ fn create_net_device(
let features = virtio::base_features(cfg.protected_vm);
let dev = if cfg.vhost_net {
let dev = virtio::vhost::Net::<Tap, vhost::Net<Tap>>::new(
+ &cfg.vhost_net_device_path,
features,
host_ip,
netmask,
@@ -801,16 +837,28 @@ fn create_net_device(
})
}
+fn create_vhost_user_net_device(cfg: &Config, opt: &VhostUserOption) -> DeviceResult {
+ let dev = VhostUserNet::new(virtio::base_features(cfg.protected_vm), &opt.socket)
+ .map_err(Error::VhostUserNetDeviceNew)?;
+
+ Ok(VirtioDeviceStub {
+ dev: Box::new(dev),
+ // no sandbox here because virtqueue handling is exported to a different process.
+ jail: None,
+ })
+}
+
#[cfg(feature = "gpu")]
fn create_gpu_device(
cfg: &Config,
exit_evt: &Event,
- gpu_device_socket: VmMemoryControlRequestSocket,
- gpu_sockets: Vec<virtio::resource_bridge::ResourceResponseSocket>,
+ gpu_device_tube: Tube,
+ resource_bridges: Vec<Tube>,
wayland_socket_path: Option<&PathBuf>,
x_display: Option<String>,
event_devices: Vec<EventDevice>,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
+ mem: &GuestMemory,
) -> DeviceResult {
let jailed_wayland_path = Path::new("/wayland-0");
@@ -832,9 +880,9 @@ fn create_gpu_device(
let dev = virtio::Gpu::new(
exit_evt.try_clone().map_err(Error::CloneEvent)?,
- Some(gpu_device_socket),
+ Some(gpu_device_tube),
NonZeroU8::new(1).unwrap(), // number of scanouts
- gpu_sockets,
+ resource_bridges,
display_backends,
cfg.gpu_parameters.as_ref().unwrap(),
event_devices,
@@ -842,6 +890,7 @@ fn create_gpu_device(
cfg.sandbox,
virtio::base_features(cfg.protected_vm),
cfg.wayland_socket_paths.clone(),
+ mem.clone(),
);
let jail = match simple_jail(&cfg, "gpu_device")? {
@@ -902,6 +951,12 @@ fn create_gpu_device(
jail.mount_bind(pvr_sync_path, pvr_sync_path, true)?;
}
+ // If the udmabuf driver exists on the host, bind mount it in.
+ let udmabuf_path = Path::new("/dev/udmabuf");
+ if udmabuf_path.exists() {
+ jail.mount_bind(udmabuf_path, udmabuf_path, true)?;
+ }
+
// Libraries that are required when mesa drivers are dynamically loaded.
let lib_dirs = &[
"/usr/lib",
@@ -957,8 +1012,8 @@ fn create_gpu_device(
fn create_wayland_device(
cfg: &Config,
- socket: VmMemoryControlRequestSocket,
- resource_bridge: Option<virtio::resource_bridge::ResourceRequestSocket>,
+ control_tube: Tube,
+ resource_bridge: Option<Tube>,
) -> DeviceResult {
let wayland_socket_dirs = cfg
.wayland_socket_paths
@@ -971,7 +1026,7 @@ fn create_wayland_device(
let dev = virtio::Wl::new(
features,
cfg.wayland_socket_paths.clone(),
- socket,
+ control_tube,
resource_bridge,
)
.map_err(Error::WaylandDeviceNew)?;
@@ -1012,7 +1067,7 @@ fn create_wayland_device(
fn create_video_device(
cfg: &Config,
typ: devices::virtio::VideoDeviceType,
- resource_bridge: virtio::resource_bridge::ResourceRequestSocket,
+ resource_bridge: Tube,
) -> DeviceResult {
let jail = match simple_jail(&cfg, "video_device")? {
Some(mut jail) => {
@@ -1075,20 +1130,18 @@ fn create_video_device(
#[cfg(any(feature = "video-decoder", feature = "video-encoder"))]
fn register_video_device(
devs: &mut Vec<VirtioDeviceStub>,
- resource_bridges: &mut Vec<virtio::resource_bridge::ResourceResponseSocket>,
+ video_tube: Tube,
cfg: &Config,
typ: devices::virtio::VideoDeviceType,
) -> std::result::Result<(), Error> {
- let (video_socket, gpu_socket) =
- virtio::resource_bridge::pair().map_err(Error::CreateSocket)?;
- resource_bridges.push(gpu_socket);
- devs.push(create_video_device(cfg, typ, video_socket)?);
+ devs.push(create_video_device(cfg, typ, video_tube)?);
Ok(())
}
fn create_vhost_vsock_device(cfg: &Config, cid: u64, mem: &GuestMemory) -> DeviceResult {
let features = virtio::base_features(cfg.protected_vm);
- let dev = virtio::vhost::Vsock::new(features, cid, mem).map_err(Error::VhostVsockDeviceNew)?;
+ let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid, mem)
+ .map_err(Error::VhostVsockDeviceNew)?;
Ok(VirtioDeviceStub {
dev: Box::new(dev),
@@ -1103,7 +1156,7 @@ fn create_fs_device(
src: &Path,
tag: &str,
fs_cfg: virtio::fs::passthrough::Config,
- device_socket: FsMappingRequestSocket,
+ device_tube: Tube,
) -> DeviceResult {
let max_open_files = get_max_open_files()?;
let j = if cfg.sandbox {
@@ -1129,7 +1182,7 @@ fn create_fs_device(
// TODO(chirantan): Use more than one worker once the kernel driver has been fixed to not panic
// when num_queues > 1.
let dev =
- virtio::fs::Fs::new(features, tag, 1, fs_cfg, device_socket).map_err(Error::FsDeviceNew)?;
+ virtio::fs::Fs::new(features, tag, 1, fs_cfg, device_tube).map_err(Error::FsDeviceNew)?;
Ok(VirtioDeviceStub {
dev: Box::new(dev),
@@ -1186,13 +1239,19 @@ fn create_pmem_device(
resources: &mut SystemAllocator,
disk: &DiskOption,
index: usize,
- pmem_device_socket: VmMsyncRequestSocket,
+ pmem_device_tube: Tube,
) -> DeviceResult {
- let fd = OpenOptions::new()
- .read(true)
- .write(!disk.read_only)
- .open(&disk.path)
- .map_err(|e| Error::Disk(disk.path.to_path_buf(), e))?;
+ // Special case '/proc/self/fd/*' paths. The FD is already open, just use it.
+ let fd: File = if disk.path.parent() == Some(Path::new("/proc/self/fd")) {
+ // Safe because we will validate |raw_fd|.
+ unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&disk.path)?) }
+ } else {
+ OpenOptions::new()
+ .read(true)
+ .write(!disk.read_only)
+ .open(&disk.path)
+ .map_err(|e| Error::Disk(disk.path.to_path_buf(), e))?
+ };
let arena_size = {
let metadata =
@@ -1259,7 +1318,7 @@ fn create_pmem_device(
GuestAddress(mapping_address),
slot,
arena_size,
- Some(pmem_device_socket),
+ Some(pmem_device_tube),
)
.map_err(Error::PmemDeviceNew)?;
@@ -1304,7 +1363,7 @@ fn create_console_device(cfg: &Config, param: &SerialParameters) -> DeviceResult
})
}
-// gpu_device_socket is not used when GPU support is disabled.
+// gpu_device_tube is not used when GPU support is disabled.
#[cfg_attr(not(feature = "gpu"), allow(unused_variables))]
fn create_virtio_devices(
cfg: &Config,
@@ -1312,13 +1371,13 @@ fn create_virtio_devices(
vm: &mut impl Vm,
resources: &mut SystemAllocator,
_exit_evt: &Event,
- wayland_device_socket: VmMemoryControlRequestSocket,
- gpu_device_socket: VmMemoryControlRequestSocket,
- balloon_device_socket: BalloonControlResponseSocket,
- disk_device_sockets: &mut Vec<DiskControlResponseSocket>,
- pmem_device_sockets: &mut Vec<VmMsyncRequestSocket>,
+ wayland_device_tube: Tube,
+ gpu_device_tube: Tube,
+ balloon_device_tube: Tube,
+ disk_device_tubes: &mut Vec<Tube>,
+ pmem_device_tubes: &mut Vec<Tube>,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
- fs_device_sockets: &mut Vec<FsMappingRequestSocket>,
+ fs_device_tubes: &mut Vec<Tube>,
) -> DeviceResult<Vec<VirtioDeviceStub>> {
let mut devs = Vec::new();
@@ -1332,19 +1391,23 @@ fn create_virtio_devices(
}
for disk in &cfg.disks {
- let disk_device_socket = disk_device_sockets.remove(0);
- devs.push(create_block_device(cfg, disk, disk_device_socket)?);
+ let disk_device_tube = disk_device_tubes.remove(0);
+ devs.push(create_block_device(cfg, disk, disk_device_tube)?);
+ }
+
+ for blk in &cfg.vhost_user_blk {
+ devs.push(create_vhost_user_block_device(cfg, blk)?);
}
for (index, pmem_disk) in cfg.pmem_devices.iter().enumerate() {
- let pmem_device_socket = pmem_device_sockets.remove(0);
+ let pmem_device_tube = pmem_device_tubes.remove(0);
devs.push(create_pmem_device(
cfg,
vm,
resources,
pmem_disk,
index,
- pmem_device_socket,
+ pmem_device_tube,
)?);
}
@@ -1385,7 +1448,7 @@ fn create_virtio_devices(
devs.push(create_vinput_device(cfg, dev_path)?);
}
- devs.push(create_balloon_device(cfg, balloon_device_socket)?);
+ devs.push(create_balloon_device(cfg, balloon_device_tube)?);
// We checked above that if the IP is defined, then the netmask is, too.
for tap_fd in &cfg.tap_fd {
@@ -1395,21 +1458,27 @@ fn create_virtio_devices(
if let (Some(host_ip), Some(netmask), Some(mac_address)) =
(cfg.host_ip, cfg.netmask, cfg.mac_address)
{
+ if !cfg.vhost_user_net.is_empty() {
+ return Err(Error::VhostUserNetWithNetArgs);
+ }
devs.push(create_net_device(cfg, host_ip, netmask, mac_address, mem)?);
}
+ for net in &cfg.vhost_user_net {
+ devs.push(create_vhost_user_net_device(cfg, net)?);
+ }
+
#[cfg_attr(not(feature = "gpu"), allow(unused_mut))]
- let mut resource_bridges = Vec::<virtio::resource_bridge::ResourceResponseSocket>::new();
+ let mut resource_bridges = Vec::<Tube>::new();
if !cfg.wayland_socket_paths.is_empty() {
#[cfg_attr(not(feature = "gpu"), allow(unused_mut))]
- let mut wl_resource_bridge = None::<virtio::resource_bridge::ResourceRequestSocket>;
+ let mut wl_resource_bridge = None::<Tube>;
#[cfg(feature = "gpu")]
{
if cfg.gpu_parameters.is_some() {
- let (wl_socket, gpu_socket) =
- virtio::resource_bridge::pair().map_err(Error::CreateSocket)?;
+ let (wl_socket, gpu_socket) = Tube::pair().map_err(Error::CreateTube)?;
resource_bridges.push(gpu_socket);
wl_resource_bridge = Some(wl_socket);
}
@@ -1417,34 +1486,28 @@ fn create_virtio_devices(
devs.push(create_wayland_device(
cfg,
- wayland_device_socket,
+ wayland_device_tube,
wl_resource_bridge,
)?);
}
#[cfg(feature = "video-decoder")]
- {
- if cfg.video_dec {
- register_video_device(
- &mut devs,
- &mut resource_bridges,
- cfg,
- devices::virtio::VideoDeviceType::Decoder,
- )?;
- }
- }
+ let video_dec_tube = if cfg.video_dec {
+ let (video_tube, gpu_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ resource_bridges.push(gpu_tube);
+ Some(video_tube)
+ } else {
+ None
+ };
#[cfg(feature = "video-encoder")]
- {
- if cfg.video_enc {
- register_video_device(
- &mut devs,
- &mut resource_bridges,
- cfg,
- devices::virtio::VideoDeviceType::Encoder,
- )?;
- }
- }
+ let video_enc_tube = if cfg.video_enc {
+ let (video_tube, gpu_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ resource_bridges.push(gpu_tube);
+ Some(video_tube)
+ } else {
+ None
+ };
#[cfg(feature = "gpu")]
{
@@ -1488,21 +1551,50 @@ fn create_virtio_devices(
devs.push(create_gpu_device(
cfg,
_exit_evt,
- gpu_device_socket,
+ gpu_device_tube,
resource_bridges,
// Use the unnamed socket for GPU display screens.
cfg.wayland_socket_paths.get(""),
cfg.x_display.clone(),
event_devices,
map_request,
+ mem,
)?);
}
}
+ #[cfg(feature = "video-decoder")]
+ {
+ if let Some(video_dec_tube) = video_dec_tube {
+ register_video_device(
+ &mut devs,
+ video_dec_tube,
+ cfg,
+ devices::virtio::VideoDeviceType::Decoder,
+ )?;
+ }
+ }
+
+ #[cfg(feature = "video-encoder")]
+ {
+ if let Some(video_enc_tube) = video_enc_tube {
+ register_video_device(
+ &mut devs,
+ video_enc_tube,
+ cfg,
+ devices::virtio::VideoDeviceType::Encoder,
+ )?;
+ }
+ }
+
if let Some(cid) = cfg.cid {
devs.push(create_vhost_vsock_device(cfg, cid, mem)?);
}
+ for vhost_user_fs in &cfg.vhost_user_fs {
+ devs.push(create_vhost_user_fs_device(cfg, &vhost_user_fs)?);
+ }
+
for shared_dir in &cfg.shared_dirs {
let SharedDir {
src,
@@ -1516,16 +1608,8 @@ fn create_virtio_devices(
let dev = match kind {
SharedDirKind::FS => {
- let device_socket = fs_device_sockets.remove(0);
- create_fs_device(
- cfg,
- uid_map,
- gid_map,
- src,
- tag,
- fs_cfg.clone(),
- device_socket,
- )?
+ let device_tube = fs_device_tubes.remove(0);
+ create_fs_device(cfg, uid_map, gid_map, src, tag, fs_cfg.clone(), device_tube)?
}
SharedDirKind::P9 => create_9p_device(cfg, uid_map, gid_map, src, tag, p9_cfg.clone())?,
};
@@ -1541,13 +1625,13 @@ fn create_devices(
vm: &mut impl Vm,
resources: &mut SystemAllocator,
exit_evt: &Event,
- control_sockets: &mut Vec<TaggedControlSocket>,
- wayland_device_socket: VmMemoryControlRequestSocket,
- gpu_device_socket: VmMemoryControlRequestSocket,
- balloon_device_socket: BalloonControlResponseSocket,
- disk_device_sockets: &mut Vec<DiskControlResponseSocket>,
- pmem_device_sockets: &mut Vec<VmMsyncRequestSocket>,
- fs_device_sockets: &mut Vec<FsMappingRequestSocket>,
+ control_tubes: &mut Vec<TaggedControlTube>,
+ wayland_device_tube: Tube,
+ gpu_device_tube: Tube,
+ balloon_device_tube: Tube,
+ disk_device_tubes: &mut Vec<Tube>,
+ pmem_device_tubes: &mut Vec<Tube>,
+ fs_device_tubes: &mut Vec<Tube>,
usb_provider: HostBackendDeviceProvider,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
) -> DeviceResult<Vec<(Box<dyn PciDevice>, Option<Minijail>)>> {
@@ -1557,22 +1641,21 @@ fn create_devices(
vm,
resources,
exit_evt,
- wayland_device_socket,
- gpu_device_socket,
- balloon_device_socket,
- disk_device_sockets,
- pmem_device_sockets,
+ wayland_device_tube,
+ gpu_device_tube,
+ balloon_device_tube,
+ disk_device_tubes,
+ pmem_device_tubes,
map_request,
- fs_device_sockets,
+ fs_device_tubes,
)?;
let mut pci_devices = Vec::new();
for stub in stubs {
- let (msi_host_socket, msi_device_socket) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmIrq(msi_host_socket));
- let dev = VirtioPciDevice::new(mem.clone(), stub.dev, msi_device_socket)
+ let (msi_host_tube, msi_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmIrq(msi_host_tube));
+ let dev = VirtioPciDevice::new(mem.clone(), stub.dev, msi_device_tube)
.map_err(Error::VirtioPciDev)?;
let dev = Box::new(dev) as Box<dyn PciDevice>;
pci_devices.push((dev, stub.jail));
@@ -1596,26 +1679,25 @@ fn create_devices(
for vfio_path in &cfg.vfio {
// create MSI, MSI-X, and Mem request sockets for each vfio device
- let (vfio_host_socket_msi, vfio_device_socket_msi) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmIrq(vfio_host_socket_msi));
+ let (vfio_host_tube_msi, vfio_device_tube_msi) =
+ Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmIrq(vfio_host_tube_msi));
- let (vfio_host_socket_msix, vfio_device_socket_msix) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmIrq(vfio_host_socket_msix));
+ let (vfio_host_tube_msix, vfio_device_tube_msix) =
+ Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmIrq(vfio_host_tube_msix));
- let (vfio_host_socket_mem, vfio_device_socket_mem) =
- msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>()
- .map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmMemory(vfio_host_socket_mem));
+ let (vfio_host_tube_mem, vfio_device_tube_mem) =
+ Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmMemory(vfio_host_tube_mem));
let vfiodevice = VfioDevice::new(vfio_path.as_path(), vm, mem, vfio_container.clone())
.map_err(Error::CreateVfioDevice)?;
let mut vfiopcidevice = Box::new(VfioPciDevice::new(
vfiodevice,
- vfio_device_socket_msi,
- vfio_device_socket_msix,
- vfio_device_socket_mem,
+ vfio_device_tube_msi,
+ vfio_device_tube_msix,
+ vfio_device_tube_mem,
));
// early reservation for pass-through PCI devices.
if vfiopcidevice.allocate_address(resources).is_err() {
@@ -1713,7 +1795,7 @@ impl IntoUnixStream for UnixStream {
fn setup_vcpu_signal_handler<T: Vcpu>(use_hypervisor_signals: bool) -> Result<()> {
if use_hypervisor_signals {
unsafe {
- extern "C" fn handle_signal() {}
+ extern "C" fn handle_signal(_: c_int) {}
// Our signal handler does nothing and is trivially async signal safe.
register_rt_signal_handler(SIGRTMIN() + 0, handle_signal)
.map_err(Error::RegisterSignalHandler)?;
@@ -1721,7 +1803,7 @@ fn setup_vcpu_signal_handler<T: Vcpu>(use_hypervisor_signals: bool) -> Result<()
block_signal(SIGRTMIN() + 0).map_err(Error::BlockSignal)?;
} else {
unsafe {
- extern "C" fn handle_signal<T: Vcpu>() {
+ extern "C" fn handle_signal<T: Vcpu>(_: c_int) {
T::set_local_immediate_exit(true);
}
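// The handlers now take the signal number argument (`c_int`), presumably so
// their signature matches the C signal-handler ABI expected by
// register_rt_signal_handler; the argument itself is unused.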
register_rt_signal_handler(SIGRTMIN() + 0, handle_signal::<T>)
@@ -1818,7 +1900,7 @@ fn handle_debug_msg<V>(
vcpu: &V,
guest_mem: &GuestMemory,
d: VcpuDebug,
- reply_channel: &mpsc::Sender<VcpuDebugStatusMessage>,
+ reply_tube: &mpsc::Sender<VcpuDebugStatusMessage>,
) -> Result<()>
where
V: VcpuArch + 'static,
@@ -1831,13 +1913,13 @@ where
Arch::debug_read_registers(vcpu as &V).map_err(Error::HandleDebugCommand)?,
),
};
- reply_channel
+ reply_tube
.send(msg)
.map_err(|e| Error::SendDebugStatus(Box::new(e)))
}
VcpuDebug::WriteRegs(regs) => {
Arch::debug_write_registers(vcpu as &V, &regs).map_err(Error::HandleDebugCommand)?;
- reply_channel
+ reply_tube
.send(VcpuDebugStatusMessage {
cpu: cpu_id as usize,
msg: VcpuDebugStatus::CommandComplete,
@@ -1852,14 +1934,14 @@ where
.unwrap_or(Vec::new()),
),
};
- reply_channel
+ reply_tube
.send(msg)
.map_err(|e| Error::SendDebugStatus(Box::new(e)))
}
VcpuDebug::WriteMem(vaddr, buf) => {
Arch::debug_write_memory(vcpu as &V, guest_mem, vaddr, &buf)
.map_err(Error::HandleDebugCommand)?;
- reply_channel
+ reply_tube
.send(VcpuDebugStatusMessage {
cpu: cpu_id as usize,
msg: VcpuDebugStatus::CommandComplete,
@@ -1868,7 +1950,7 @@ where
}
VcpuDebug::EnableSinglestep => {
Arch::debug_enable_singlestep(vcpu as &V).map_err(Error::HandleDebugCommand)?;
- reply_channel
+ reply_tube
.send(VcpuDebugStatusMessage {
cpu: cpu_id as usize,
msg: VcpuDebugStatus::CommandComplete,
@@ -1878,7 +1960,7 @@ where
VcpuDebug::SetHwBreakPoint(addrs) => {
Arch::debug_set_hw_breakpoints(vcpu as &V, &addrs)
.map_err(Error::HandleDebugCommand)?;
- reply_channel
+ reply_tube
.send(VcpuDebugStatusMessage {
cpu: cpu_id as usize,
msg: VcpuDebugStatus::CommandComplete,
@@ -1903,9 +1985,9 @@ fn run_vcpu<V>(
mmio_bus: devices::Bus,
exit_evt: Event,
requires_pvclock_ctrl: bool,
- from_main_channel: mpsc::Receiver<VcpuControl>,
+ from_main_tube: mpsc::Receiver<VcpuControl>,
use_hypervisor_signals: bool,
- #[cfg(all(target_arch = "x86_64", feature = "gdb"))] to_gdb_channel: Option<
+ #[cfg(all(target_arch = "x86_64", feature = "gdb"))] to_gdb_tube: Option<
mpsc::Sender<VcpuDebugStatusMessage>,
>,
) -> Result<JoinHandle<()>>
@@ -1946,7 +2028,7 @@ where
let mut run_mode = VmRunMode::Running;
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
- if to_gdb_channel.is_some() {
+ if to_gdb_tube.is_some() {
// Wait until a GDB client attaches
run_mode = VmRunMode::Breakpoint;
}
@@ -1960,7 +2042,7 @@ where
if interrupted_by_signal || run_mode != VmRunMode::Running {
'state_loop: loop {
// Tries to get a pending message without blocking first.
- let msg = match from_main_channel.try_recv() {
+ let msg = match from_main_tube.try_recv() {
Ok(m) => m,
Err(mpsc::TryRecvError::Empty) if run_mode == VmRunMode::Running => {
// If the VM is running and no message is pending, the state won't
@@ -1969,23 +2051,23 @@ where
}
Err(mpsc::TryRecvError::Empty) => {
// If the VM is not running, wait until a message is ready.
- match from_main_channel.recv() {
+ match from_main_tube.recv() {
Ok(m) => m,
Err(mpsc::RecvError) => {
- error!("Failed to read from main channel in vcpu");
+ error!("Failed to read from main tube in vcpu");
break 'vcpu_loop;
}
}
}
Err(mpsc::TryRecvError::Disconnected) => {
- error!("Failed to read from main channel in vcpu");
+ error!("Failed to read from main tube in vcpu");
break 'vcpu_loop;
}
};
// Collect all pending messages.
let mut messages = vec![msg];
- messages.append(&mut from_main_channel.try_iter().collect());
+ messages.append(&mut from_main_tube.try_iter().collect());
for msg in messages {
match msg {
@@ -2015,7 +2097,7 @@ where
}
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
VcpuControl::Debug(d) => {
- match &to_gdb_channel {
+ match &to_gdb_tube {
Some(ref ch) => {
if let Err(e) = handle_debug_msg(
cpu_id, &vcpu, &guest_mem, d, &ch,
@@ -2112,7 +2194,7 @@ where
cpu: cpu_id as usize,
msg: VcpuDebugStatus::HitBreakPoint,
};
- if let Some(ref ch) = to_gdb_channel {
+ if let Some(ref ch) = to_gdb_tube {
if let Err(e) = ch.send(msg) {
error!("failed to notify breakpoint to GDB thread: {}", e);
break;
@@ -2183,16 +2265,10 @@ fn file_to_i64<P: AsRef<Path>>(path: P, nth: usize) -> io::Result<i64> {
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "empty file"))
}
-fn create_kvm(mem: GuestMemory) -> base::Result<KvmVm> {
- let kvm = Kvm::new()?;
- let vm = KvmVm::new(&kvm, mem)?;
- Ok(vm)
-}
-
fn create_kvm_kernel_irq_chip(
vm: &KvmVm,
vcpu_count: usize,
- _ioapic_device_socket: VmIrqRequestSocket,
+ _ioapic_device_tube: Tube,
) -> base::Result<impl IrqChipArch> {
let irq_chip = KvmKernelIrqChip::new(vm.try_clone()?, vcpu_count)?;
Ok(irq_chip)
@@ -2202,13 +2278,27 @@ fn create_kvm_kernel_irq_chip(
fn create_kvm_split_irq_chip(
vm: &KvmVm,
vcpu_count: usize,
- ioapic_device_socket: VmIrqRequestSocket,
+ ioapic_device_tube: Tube,
) -> base::Result<impl IrqChipArch> {
- let irq_chip = KvmSplitIrqChip::new(vm.try_clone()?, vcpu_count, ioapic_device_socket)?;
+ let irq_chip =
+ KvmSplitIrqChip::new(vm.try_clone()?, vcpu_count, ioapic_device_tube, Some(120))?;
Ok(irq_chip)
}
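// The added `Some(120)` argument presumably sets the number of IOAPIC pins
// exposed by the split irqchip; this is an assumption based on the call site.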
pub fn run_config(cfg: Config) -> Result<()> {
+ let components = setup_vm_components(&cfg)?;
+
+ let guest_mem_layout =
+ Arch::guest_memory_layout(&components).map_err(Error::GuestMemoryLayout)?;
+ let guest_mem = GuestMemory::new(&guest_mem_layout).unwrap();
+ let mut mem_policy = MemoryPolicy::empty();
+ if components.hugepages {
+ mem_policy |= MemoryPolicy::USE_HUGEPAGES;
+ }
+ guest_mem.set_memory_policy(mem_policy);
+ let kvm = Kvm::new_with_path(&cfg.kvm_device_path).map_err(Error::CreateKvm)?;
+ let vm = KvmVm::new(&kvm, guest_mem).map_err(Error::CreateVm)?;
+
if cfg.split_irqchip {
#[cfg(any(target_arch = "arm", target_arch = "aarch64"))]
{
@@ -2217,39 +2307,14 @@ pub fn run_config(cfg: Config) -> Result<()> {
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
{
- run_vm::<_, KvmVcpu, _, _, _>(cfg, create_kvm, create_kvm_split_irq_chip)
+ run_vm::<KvmVcpu, _, _, _>(cfg, components, vm, create_kvm_split_irq_chip)
}
} else {
- run_vm::<_, KvmVcpu, _, _, _>(cfg, create_kvm, create_kvm_kernel_irq_chip)
+ run_vm::<KvmVcpu, _, _, _>(cfg, components, vm, create_kvm_kernel_irq_chip)
}
}
-fn run_vm<V, Vcpu, I, FV, FI>(cfg: Config, create_vm: FV, create_irq_chip: FI) -> Result<()>
-where
- V: VmArch + 'static,
- Vcpu: VcpuArch + 'static,
- I: IrqChipArch + 'static,
- FV: FnOnce(GuestMemory) -> base::Result<V>,
- FI: FnOnce(
- &V,
- usize, // vcpu_count
- VmIrqRequestSocket, // ioapic_device_socket
- ) -> base::Result<I>,
-{
- if cfg.sandbox {
- // Printing something to the syslog before entering minijail so that libc's syslogger has a
- // chance to open files necessary for its operation, like `/etc/localtime`. After jailing,
- // access to those files will not be possible.
- info!("crosvm entering multiprocess mode");
- }
-
- let (usb_control_socket, usb_provider) =
- HostBackendDeviceProvider::new().map_err(Error::CreateUsbProvider)?;
- // Masking signals is inherently dangerous, since this can persist across clones/execs. Do this
- // before any jailed devices have been spawned, so that we can catch any of them that fail very
- // quickly.
- let sigchld_fd = SignalFd::new(libc::SIGCHLD).map_err(Error::CreateSignalFd)?;
-
+fn setup_vm_components(cfg: &Config) -> Result<VmComponents> {
let initrd_image = if let Some(initrd_path) = &cfg.initrd_path {
Some(File::open(initrd_path).map_err(|e| Error::OpenInitrd(initrd_path.clone(), e))?)
} else {
@@ -2266,19 +2331,7 @@ where
_ => panic!("Did not receive a bios or kernel, should be impossible."),
};
- let mut control_sockets = Vec::new();
- #[cfg(all(target_arch = "x86_64", feature = "gdb"))]
- let gdb_socket = if let Some(port) = cfg.gdb {
- // GDB needs a control socket to interrupt vcpus.
- let (gdb_host_socket, gdb_control_socket) =
- msg_socket::pair::<VmResponse, VmRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::Vm(gdb_host_socket));
- Some((port, gdb_control_socket))
- } else {
- None
- };
-
- let components = VmComponents {
+ Ok(VmComponents {
memory_size: cfg
.memory
.unwrap_or(256)
@@ -2287,6 +2340,7 @@ where
vcpu_count: cfg.vcpu_count.unwrap_or(1),
vcpu_affinity: cfg.vcpu_affinity.clone(),
no_smt: cfg.no_smt,
+ hugepages: cfg.hugepages,
vm_image,
android_fstab: cfg
.android_fstab
@@ -2305,52 +2359,86 @@ where
rt_cpus: cfg.rt_cpus.clone(),
protected_vm: cfg.protected_vm,
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
- gdb: gdb_socket,
- };
+ gdb: None,
+ dmi_path: cfg.dmi_path.clone(),
+ })
+}
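// `gdb` is intentionally left as None here; run_vm() below fills it in, along
// with the matching control tube, when a --gdb port is configured.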
+
+fn run_vm<Vcpu, V, I, FI>(
+ cfg: Config,
+ #[allow(unused_mut)] mut components: VmComponents,
+ vm: V,
+ create_irq_chip: FI,
+) -> Result<()>
+where
+ Vcpu: VcpuArch + 'static,
+ V: VmArch + 'static,
+ I: IrqChipArch + 'static,
+ FI: FnOnce(
+ &V,
+ usize, // vcpu_count
+ Tube, // ioapic_device_tube
+ ) -> base::Result<I>,
+{
+ if cfg.sandbox {
+ // Printing something to the syslog before entering minijail so that libc's syslogger has a
+ // chance to open files necessary for its operation, like `/etc/localtime`. After jailing,
+ // access to those files will not be possible.
+ info!("crosvm entering multiprocess mode");
+ }
+
+ let (usb_control_tube, usb_provider) =
+ HostBackendDeviceProvider::new().map_err(Error::CreateUsbProvider)?;
+ // Masking signals is inherently dangerous, since this can persist across clones/execs. Do this
+ // before any jailed devices have been spawned, so that we can catch any of them that fail very
+ // quickly.
+ let sigchld_fd = SignalFd::new(libc::SIGCHLD).map_err(Error::CreateSignalFd)?;
let control_server_socket = match &cfg.socket_path {
Some(path) => Some(UnlinkUnixSeqpacketListener(
- UnixSeqpacketListener::bind(path).map_err(Error::CreateSocket)?,
+ UnixSeqpacketListener::bind(path).map_err(Error::CreateControlServer)?,
)),
None => None,
};
- let (wayland_host_socket, wayland_device_socket) =
- msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmMemory(wayland_host_socket));
+ let mut control_tubes = Vec::new();
+
+ #[cfg(all(target_arch = "x86_64", feature = "gdb"))]
+ if let Some(port) = cfg.gdb {
+ // GDB needs a control socket to interrupt vcpus.
+ let (gdb_host_tube, gdb_control_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::Vm(gdb_host_tube));
+ components.gdb = Some((port, gdb_control_tube));
+ }
+
+ let (wayland_host_tube, wayland_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmMemory(wayland_host_tube));
// Balloon gets a special socket so balloon requests can be forwarded from the main process.
- let (balloon_host_socket, balloon_device_socket) =
- msg_socket::pair::<BalloonControlCommand, BalloonControlResult>()
- .map_err(Error::CreateSocket)?;
+ let (balloon_host_tube, balloon_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
// Create one control socket per disk.
- let mut disk_device_sockets = Vec::new();
- let mut disk_host_sockets = Vec::new();
+ let mut disk_device_tubes = Vec::new();
+ let mut disk_host_tubes = Vec::new();
let disk_count = cfg.disks.len();
for _ in 0..disk_count {
- let (disk_host_socket, disk_device_socket) =
- msg_socket::pair::<DiskControlCommand, DiskControlResult>()
- .map_err(Error::CreateSocket)?;
- disk_host_sockets.push(disk_host_socket);
- disk_device_sockets.push(disk_device_socket);
+ let (disk_host_tube, disk_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ disk_host_tubes.push(disk_host_tube);
+ disk_device_tubes.push(disk_device_tube);
}
- let mut pmem_device_sockets = Vec::new();
+ let mut pmem_device_tubes = Vec::new();
let pmem_count = cfg.pmem_devices.len();
for _ in 0..pmem_count {
- let (pmem_host_socket, pmem_device_socket) =
- msg_socket::pair::<VmMsyncResponse, VmMsyncRequest>().map_err(Error::CreateSocket)?;
- pmem_device_sockets.push(pmem_device_socket);
- control_sockets.push(TaggedControlSocket::VmMsync(pmem_host_socket));
+ let (pmem_host_tube, pmem_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ pmem_device_tubes.push(pmem_device_tube);
+ control_tubes.push(TaggedControlTube::VmMsync(pmem_host_tube));
}
- let (gpu_host_socket, gpu_device_socket) =
- msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmMemory(gpu_host_socket));
+ let (gpu_host_tube, gpu_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmMemory(gpu_host_tube));
- let (ioapic_host_socket, ioapic_device_socket) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::VmIrq(ioapic_host_socket));
+ let (ioapic_host_tube, ioapic_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::VmIrq(ioapic_host_tube));
let battery = if cfg.battery_type.is_some() {
let jail = match simple_jail(&cfg, "battery")? {
@@ -2390,19 +2478,20 @@ where
.iter()
.filter(|sd| sd.kind == SharedDirKind::FS)
.count();
- let mut fs_device_sockets = Vec::with_capacity(fs_count);
+ let mut fs_device_tubes = Vec::with_capacity(fs_count);
for _ in 0..fs_count {
- let (fs_host_socket, fs_device_socket) =
- msg_socket::pair::<VmResponse, FsMappingRequest>().map_err(Error::CreateSocket)?;
- control_sockets.push(TaggedControlSocket::Fs(fs_host_socket));
- fs_device_sockets.push(fs_device_socket);
+ let (fs_host_tube, fs_device_tube) = Tube::pair().map_err(Error::CreateTube)?;
+ control_tubes.push(TaggedControlTube::Fs(fs_host_tube));
+ fs_device_tubes.push(fs_device_tube);
}
- let linux: RunnableLinuxVm<_, Vcpu, _> = Arch::build_vm(
+ #[cfg_attr(not(feature = "direct"), allow(unused_mut))]
+ let mut linux: RunnableLinuxVm<_, Vcpu, _> = Arch::build_vm(
components,
&cfg.serial_parameters,
simple_jail(&cfg, "serial")?,
battery,
+ vm,
|mem, vm, sys_allocator, exit_evt| {
create_devices(
&cfg,
@@ -2410,29 +2499,75 @@ where
vm,
sys_allocator,
exit_evt,
- &mut control_sockets,
- wayland_device_socket,
- gpu_device_socket,
- balloon_device_socket,
- &mut disk_device_sockets,
- &mut pmem_device_sockets,
- &mut fs_device_sockets,
+ &mut control_tubes,
+ wayland_device_tube,
+ gpu_device_tube,
+ balloon_device_tube,
+ &mut disk_device_tubes,
+ &mut pmem_device_tubes,
+ &mut fs_device_tubes,
usb_provider,
Arc::clone(&map_request),
)
},
- create_vm,
- |vm, vcpu_count| create_irq_chip(vm, vcpu_count, ioapic_device_socket),
+ |vm, vcpu_count| create_irq_chip(vm, vcpu_count, ioapic_device_tube),
)
.map_err(Error::BuildVm)?;
+ #[cfg(feature = "direct")]
+ if let Some(pmio) = &cfg.direct_pmio {
+ let direct_io =
+ Arc::new(devices::DirectIo::new(&pmio.path, false).map_err(Error::DirectIo)?);
+ for range in pmio.ranges.iter() {
+ linux
+ .io_bus
+ .insert_sync(direct_io.clone(), range.0, range.1)
+ .unwrap();
+ }
+ };
+
+ #[cfg(feature = "direct")]
+ let mut irqs = Vec::new();
+
+ #[cfg(feature = "direct")]
+ for irq in &cfg.direct_level_irq {
+ if !linux.resources.reserve_irq(*irq) {
+ warn!("irq {} already reserved.", irq);
+ }
+ let trigger = Event::new().map_err(Error::CreateEvent)?;
+ let resample = Event::new().map_err(Error::CreateEvent)?;
+ linux
+ .irq_chip
+ .register_irq_event(*irq, &trigger, Some(&resample))
+ .unwrap();
+ let direct_irq =
+ devices::DirectIrq::new(trigger, Some(resample)).map_err(Error::DirectIrq)?;
+ direct_irq.irq_enable(*irq).map_err(Error::DirectIrq)?;
+ irqs.push(direct_irq);
+ }
+
+ #[cfg(feature = "direct")]
+ for irq in &cfg.direct_edge_irq {
+ if !linux.resources.reserve_irq(*irq) {
+ warn!("irq {} already reserved.", irq);
+ }
+ let trigger = Event::new().map_err(Error::CreateEvent)?;
+ linux
+ .irq_chip
+ .register_irq_event(*irq, &trigger, None)
+ .unwrap();
+ let direct_irq = devices::DirectIrq::new(trigger, None).map_err(Error::DirectIrq)?;
+ direct_irq.irq_enable(*irq).map_err(Error::DirectIrq)?;
+ irqs.push(direct_irq);
+ }
+
run_control(
linux,
control_server_socket,
- control_sockets,
- balloon_host_socket,
- &disk_host_sockets,
- usb_control_socket,
+ control_tubes,
+ balloon_host_tube,
+ &disk_host_tubes,
+ usb_control_tube,
sigchld_fd,
cfg.sandbox,
Arc::clone(&map_request),
@@ -2441,8 +2576,8 @@ where
)
}
-/// Signals all running VCPUs to vmexit, sends VmRunMode message to each VCPU channel, and tells
-/// `irq_chip` to stop blocking halted VCPUs. The channel message is set first because both the
+/// Signals all running VCPUs to vmexit, sends VmRunMode message to each VCPU tube, and tells
+/// `irq_chip` to stop blocking halted VCPUs. The tube message is set first because both the
/// signal and the irq_chip kick could cause the VCPU thread to continue through the VCPU run
/// loop.
fn kick_all_vcpus(
@@ -2450,8 +2585,8 @@ fn kick_all_vcpus(
irq_chip: &impl IrqChip,
run_mode: &VmRunMode,
) {
- for (handle, channel) in vcpu_handles {
- if let Err(e) = channel.send(VcpuControl::RunState(run_mode.clone())) {
+ for (handle, tube) in vcpu_handles {
+ if let Err(e) = tube.send(VcpuControl::RunState(run_mode.clone())) {
error!("failed to send VmRunMode: {}", e);
}
let _ = handle.kill(SIGRTMIN() + 0);
@@ -2626,10 +2761,10 @@ impl BalloonPolicy {
fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + 'static>(
mut linux: RunnableLinuxVm<V, Vcpu, I>,
control_server_socket: Option<UnlinkUnixSeqpacketListener>,
- mut control_sockets: Vec<TaggedControlSocket>,
- balloon_host_socket: BalloonControlRequestSocket,
- disk_host_sockets: &[DiskControlRequestSocket],
- usb_control_socket: UsbControlSocket,
+ mut control_tubes: Vec<TaggedControlTube>,
+ balloon_host_tube: Tube,
+ disk_host_tubes: &[Tube],
+ usb_control_tube: Tube,
sigchld_fd: SignalFd,
sandbox: bool,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
@@ -2664,7 +2799,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
.add(socket_server, Token::VmControlServer)
.map_err(Error::WaitContextAdd)?;
}
- for (index, socket) in control_sockets.iter().enumerate() {
+ for (index, socket) in control_tubes.iter().enumerate() {
wait_ctx
.add(socket.as_ref(), Token::VmControl { index })
.map_err(Error::WaitContextAdd)?;
@@ -2696,7 +2831,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
// Listen for balloon statistics from the guest so we can balance.
wait_ctx
- .add(&balloon_host_socket, Token::BalloonResult)
+ .add(&balloon_host_tube, Token::BalloonResult)
.map_err(Error::WaitContextAdd)?;
Some(BalloonPolicy::new(
linux.vm.get_memory().memory_size() as i64,
@@ -2766,13 +2901,13 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
// Spawn GDB thread.
- if let Some((gdb_port_num, gdb_control_socket)) = linux.gdb.take() {
+ if let Some((gdb_port_num, gdb_control_tube)) = linux.gdb.take() {
let to_vcpu_channels = vcpu_handles
.iter()
.map(|(_handle, channel)| channel.clone())
.collect();
let target = GdbStub::new(
- gdb_control_socket,
+ gdb_control_tube,
to_vcpu_channels,
from_vcpu_channel.unwrap(), // Must succeed to unwrap()
);
@@ -2834,12 +2969,12 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
Token::BalanceMemory => {
balancemem_timer.wait().map_err(Error::Timer)?;
let command = BalloonControlCommand::Stats {};
- if let Err(e) = balloon_host_socket.send(&command) {
+ if let Err(e) = balloon_host_tube.send(&command) {
warn!("failed to send stats request to balloon device: {}", e);
}
}
Token::BalloonResult => {
- match balloon_host_socket.recv() {
+ match balloon_host_tube.recv() {
Ok(BalloonControlResult::Stats {
stats,
balloon_actual: balloon_actual_u,
@@ -2860,7 +2995,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
let target = max((balloon_actual_u as i64) + delta, 0) as u64;
let command =
BalloonControlCommand::Adjust { num_bytes: target };
- if let Err(e) = balloon_host_socket.send(&command) {
+ if let Err(e) = balloon_host_tube.send(&command) {
warn!(
"failed to send memory value to balloon device: {}",
e
@@ -2883,31 +3018,30 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
.add(
&socket,
Token::VmControl {
- index: control_sockets.len(),
+ index: control_tubes.len(),
},
)
.map_err(Error::WaitContextAdd)?;
- control_sockets
- .push(TaggedControlSocket::Vm(MsgSocket::new(socket)));
+ control_tubes.push(TaggedControlTube::Vm(Tube::new(socket)));
}
Err(e) => error!("failed to accept socket: {}", e),
}
}
}
Token::VmControl { index } => {
- if let Some(socket) = control_sockets.get(index) {
+ if let Some(socket) = control_tubes.get(index) {
match socket {
- TaggedControlSocket::Vm(socket) => match socket.recv() {
+ TaggedControlTube::Vm(tube) => match tube.recv::<VmRequest>() {
Ok(request) => {
let mut run_mode_opt = None;
let response = request.execute(
&mut run_mode_opt,
- &balloon_host_socket,
- disk_host_sockets,
- &usb_control_socket,
+ &balloon_host_tube,
+ disk_host_tubes,
+ &usb_control_tube,
&mut linux.bat_control,
);
- if let Err(e) = socket.send(&response) {
+ if let Err(e) = tube.send(&response) {
error!("failed to send VmResponse: {}", e);
}
if let Some(run_mode) = run_mode_opt {
@@ -2930,34 +3064,36 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
}
}
Err(e) => {
- if let MsgError::RecvZero = e {
+ if let TubeError::Disconnected = e {
vm_control_indices_to_remove.push(index);
} else {
error!("failed to recv VmRequest: {}", e);
}
}
},
- TaggedControlSocket::VmMemory(socket) => match socket.recv() {
- Ok(request) => {
- let response = request.execute(
- &mut linux.vm,
- &mut linux.resources,
- Arc::clone(&map_request),
- &mut gralloc,
- );
- if let Err(e) = socket.send(&response) {
- error!("failed to send VmMemoryControlResponse: {}", e);
+ TaggedControlTube::VmMemory(tube) => {
+ match tube.recv::<VmMemoryRequest>() {
+ Ok(request) => {
+ let response = request.execute(
+ &mut linux.vm,
+ &mut linux.resources,
+ Arc::clone(&map_request),
+ &mut gralloc,
+ );
+ if let Err(e) = tube.send(&response) {
+ error!("failed to send VmMemoryControlResponse: {}", e);
+ }
}
- }
- Err(e) => {
- if let MsgError::RecvZero = e {
- vm_control_indices_to_remove.push(index);
- } else {
- error!("failed to recv VmMemoryControlRequest: {}", e);
+ Err(e) => {
+ if let TubeError::Disconnected = e {
+ vm_control_indices_to_remove.push(index);
+ } else {
+ error!("failed to recv VmMemoryControlRequest: {}", e);
+ }
}
}
- },
- TaggedControlSocket::VmIrq(socket) => match socket.recv() {
+ }
+ TaggedControlTube::VmIrq(tube) => match tube.recv::<VmIrqRequest>() {
Ok(request) => {
let response = {
let irq_chip = &mut linux.irq_chip;
@@ -2990,43 +3126,45 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
&mut linux.resources,
)
};
- if let Err(e) = socket.send(&response) {
+ if let Err(e) = tube.send(&response) {
error!("failed to send VmIrqResponse: {}", e);
}
}
Err(e) => {
- if let MsgError::RecvZero = e {
+ if let TubeError::Disconnected = e {
vm_control_indices_to_remove.push(index);
} else {
error!("failed to recv VmIrqRequest: {}", e);
}
}
},
- TaggedControlSocket::VmMsync(socket) => match socket.recv() {
- Ok(request) => {
- let response = request.execute(&mut linux.vm);
- if let Err(e) = socket.send(&response) {
- error!("failed to send VmMsyncResponse: {}", e);
+ TaggedControlTube::VmMsync(tube) => {
+ match tube.recv::<VmMsyncRequest>() {
+ Ok(request) => {
+ let response = request.execute(&mut linux.vm);
+ if let Err(e) = tube.send(&response) {
+ error!("failed to send VmMsyncResponse: {}", e);
+ }
}
- }
- Err(e) => {
- if let MsgError::BadRecvSize { actual: 0, .. } = e {
- vm_control_indices_to_remove.push(index);
- } else {
- error!("failed to recv VmMsyncRequest: {}", e);
+ Err(e) => {
+ if let TubeError::Disconnected = e {
+ vm_control_indices_to_remove.push(index);
+ } else {
+ error!("failed to recv VmMsyncRequest: {}", e);
+ }
}
}
- },
- TaggedControlSocket::Fs(socket) => match socket.recv() {
+ }
+ TaggedControlTube::Fs(tube) => match tube.recv::<FsMappingRequest>() {
Ok(request) => {
let response =
request.execute(&mut linux.vm, &mut linux.resources);
- if let Err(e) = socket.send(&response) {
+ if let Err(e) = tube.send(&response) {
error!("failed to send VmResponse: {}", e);
}
}
Err(e) => {
- if let MsgError::BadRecvSize { actual: 0, .. } = e {
+ if let TubeError::Disconnected = e {
vm_control_indices_to_remove.push(index);
} else {
error!("failed to recv VmResponse: {}", e);
@@ -3050,15 +3188,14 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
Token::VmControlServer => {}
Token::VmControl { index } => {
// It's possible more data is readable and buffered while the socket is hungup,
- // so don't delete the socket from the poll context until we're sure all the
+ // so don't delete the tube from the poll context until we're sure all the
// data is read.
- match control_sockets
+ if control_tubes
.get(index)
- .map(|s| s.as_ref().get_readable_bytes())
+ .map(|s| !s.as_ref().is_packet_ready())
+ .unwrap_or(false)
{
- Some(Ok(0)) | Some(Err(_)) => vm_control_indices_to_remove.push(index),
- Some(Ok(x)) => info!("control index {} has {} bytes readable", index, x),
- _ => {}
+ vm_control_indices_to_remove.push(index);
}
}
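// In other words, the tube is only scheduled for removal once no complete
// packet remains buffered; is_packet_ready() presumably reports whether a
// full message can still be read from the hung-up peer.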
}
@@ -3077,7 +3214,7 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
// now belongs to a different socket, the control loop will start to interact with
// sockets that might not be ready to use. This can cause incorrect hangup detection or
// blocking on a socket that will never be ready. See also: crbug.com/1019986
- if let Some(socket) = control_sockets.get(index) {
+ if let Some(socket) = control_tubes.get(index) {
wait_ctx.delete(socket).map_err(Error::WaitContextDelete)?;
}
@@ -3085,10 +3222,10 @@ fn run_control<V: VmArch + 'static, Vcpu: VcpuArch + 'static, I: IrqChipArch + '
// `swap_remove`. After this line, the socket at `index` is not the one from
// `vm_control_indices_to_remove`. Because of this socket's change in index, we need to
// use `wait_ctx.modify` to change the associated index in its `Token::VmControl`.
- control_sockets.swap_remove(index);
- if let Some(socket) = control_sockets.get(index) {
+ control_tubes.swap_remove(index);
+ if let Some(tube) = control_tubes.get(index) {
wait_ctx
- .modify(socket, EventType::Read, Token::VmControl { index })
+ .modify(tube, EventType::Read, Token::VmControl { index })
.map_err(Error::WaitContextAdd)?;
}
}
diff --git a/src/main.rs b/src/main.rs
index 4fedde676..e86845a65 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,10 +8,8 @@ pub mod panic_hook;
use std::collections::BTreeMap;
use std::default::Default;
-use std::fmt;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader};
-use std::num::ParseIntError;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::string::String;
@@ -22,15 +20,13 @@ use arch::{
set_default_serial_parameters, Pstore, SerialHardware, SerialParameters, SerialType,
VcpuAffinity,
};
-use base::{
- debug, error, getpid, info, kill_process_group, net::UnixSeqpacket, reap_child, syslog,
- validate_raw_descriptor, warn, FromRawDescriptor, IntoRawDescriptor, RawDescriptor,
- SafeDescriptor,
-};
+use base::{debug, error, getpid, info, kill_process_group, reap_child, syslog, warn};
+#[cfg(feature = "direct")]
+use crosvm::DirectIoOption;
use crosvm::{
argument::{self, print_help, set_arguments, Argument},
platform, BindMount, Config, DiskOption, Executable, GidMap, SharedDir, TouchDeviceOption,
- DISK_ID_LEN,
+ VhostUserFsOption, VhostUserOption, DISK_ID_LEN,
};
#[cfg(feature = "gpu")]
use devices::virtio::gpu::{GpuMode, GpuParameters};
@@ -38,11 +34,12 @@ use devices::ProtectionType;
#[cfg(feature = "audio")]
use devices::{Ac97Backend, Ac97Parameters};
use disk::QcowFile;
-use msg_socket::{MsgReceiver, MsgSender, MsgSocket};
use vm_control::{
- BalloonControlCommand, BatControlCommand, BatControlResult, BatteryType, DiskControlCommand,
- MaybeOwnedDescriptor, UsbControlCommand, UsbControlResult, VmControlRequestSocket, VmRequest,
- VmResponse, USB_CONTROL_MAX_PORTS,
+ client::{
+ do_modify_battery, do_usb_attach, do_usb_detach, do_usb_list, handle_request, vms_request,
+ ModifyUsbError, ModifyUsbResult,
+ },
+ BalloonControlCommand, BatteryType, DiskControlCommand, UsbControlResult, VmRequest,
};
fn executable_is_plugin(executable: &Option<Executable>) -> bool {
@@ -175,12 +172,12 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
for (k, v) in opts {
match k {
// Deprecated: Specifying --gpu=<mode> Not great as the mode can be set multiple
- // times if the user specifies several modes (--gpu=2d,3d,gfxstream)
+ // times if the user specifies several modes (--gpu=2d,virglrenderer,gfxstream)
"2d" | "2D" => {
gpu_params.mode = GpuMode::Mode2D;
}
- "3d" | "3D" => {
- gpu_params.mode = GpuMode::Mode3D;
+ "3d" | "3D" | "virglrenderer" => {
+ gpu_params.mode = GpuMode::ModeVirglRenderer;
}
#[cfg(feature = "gfxstream")]
"gfxstream" => {
@@ -191,8 +188,8 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
"2d" | "2D" => {
gpu_params.mode = GpuMode::Mode2D;
}
- "3d" | "3D" => {
- gpu_params.mode = GpuMode::Mode3D;
+ "3d" | "3D" | "virglrenderer" => {
+ gpu_params.mode = GpuMode::ModeVirglRenderer;
}
#[cfg(feature = "gfxstream")]
"gfxstream" => {
@@ -202,7 +199,7 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
return Err(argument::Error::InvalidValue {
value: v.to_string(),
expected: String::from(
- "gpu parameter 'backend' should be one of (2d|3d|gfxstream)",
+ "gpu parameter 'backend' should be one of (2d|virglrenderer|gfxstream)",
),
});
}
@@ -303,15 +300,17 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
}
}
}
- #[cfg(feature = "gfxstream")]
"vulkan" => {
- vulkan_specified = true;
+ #[cfg(feature = "gfxstream")]
+ {
+ vulkan_specified = true;
+ }
match v {
"true" | "" => {
- gpu_params.gfxstream_support_vulkan = true;
+ gpu_params.use_vulkan = true;
}
"false" => {
- gpu_params.gfxstream_support_vulkan = false;
+ gpu_params.use_vulkan = false;
}
_ => {
return Err(argument::Error::InvalidValue {
@@ -345,6 +344,20 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
}
"cache-path" => gpu_params.cache_path = Some(v.to_string()),
"cache-size" => gpu_params.cache_size = Some(v.to_string()),
+ "udmabuf" => match v {
+ "true" | "" => {
+ gpu_params.udmabuf = true;
+ }
+ "false" => {
+ gpu_params.udmabuf = false;
+ }
+ _ => {
+ return Err(argument::Error::InvalidValue {
+ value: v.to_string(),
+ expected: String::from("gpu parameter 'udmabuf' should be a boolean"),
+ });
+ }
+ },
"" => {}
_ => {
return Err(argument::Error::UnknownArgument(format!(
@@ -358,12 +371,16 @@ fn parse_gpu_options(s: Option<&str>) -> argument::Result<GpuParameters> {
#[cfg(feature = "gfxstream")]
{
- if vulkan_specified || syncfd_specified || angle_specified {
+ if !vulkan_specified && gpu_params.mode == GpuMode::ModeGfxstream {
+ gpu_params.use_vulkan = true;
+ }
+
+ if syncfd_specified || angle_specified {
match gpu_params.mode {
GpuMode::ModeGfxstream => {}
_ => {
return Err(argument::Error::UnknownArgument(
- "gpu parameter vulkan and syncfd are only supported for gfxstream backend"
+ "gpu parameter syncfd and angle are only supported for gfxstream backend"
.to_string(),
));
}
@@ -398,7 +415,15 @@ fn parse_ac97_options(s: &str) -> argument::Result<Ac97Parameters> {
argument::Error::Syntax(format!("invalid capture option: {}", e))
})?;
}
- #[cfg(target_os = "linux")]
+ "client_type" => {
+ ac97_params
+ .set_client_type(v)
+ .map_err(|e| argument::Error::InvalidValue {
+ value: v.to_string(),
+ expected: e.to_string(),
+ })?;
+ }
+ #[cfg(any(target_os = "linux", target_os = "android"))]
"server" => {
ac97_params.vios_server_path =
Some(
@@ -418,7 +443,7 @@ fn parse_ac97_options(s: &str) -> argument::Result<Ac97Parameters> {
}
// server is required for and exclusive to vios backend
- #[cfg(target_os = "linux")]
+ #[cfg(any(target_os = "linux", target_os = "android"))]
match ac97_params.backend {
Ac97Backend::VIOS => {
if ac97_params.vios_server_path.is_none() {
@@ -658,6 +683,64 @@ fn parse_battery_options(s: Option<&str>) -> argument::Result<BatteryType> {
Ok(battery_type)
}
+#[cfg(feature = "direct")]
+fn parse_direct_io_options(s: Option<&str>) -> argument::Result<DirectIoOption> {
+ let s = s.ok_or(argument::Error::ExpectedValue(String::from(
+ "expected path@range[,range] value",
+ )))?;
+ let parts: Vec<&str> = s.splitn(2, '@').collect();
+ if parts.len() != 2 {
+ return Err(argument::Error::InvalidValue {
+ value: s.to_string(),
+ expected: String::from("missing port range, use /path@X-Y,Z,.. syntax"),
+ });
+ }
+ let path = PathBuf::from(parts[0]);
+ if !path.exists() {
+ return Err(argument::Error::InvalidValue {
+ value: parts[0].to_owned(),
+ expected: String::from("the path does not exist"),
+ });
+ };
+ let ranges: argument::Result<Vec<(u64, u64)>> = parts[1]
+ .split(',')
+ .map(|frag| frag.split('-'))
+ .map(|mut range| {
+ let base = range
+ .next()
+ .map(|v| v.parse::<u64>())
+ .map_or(Ok(None), |r| r.map(Some));
+ let last = range
+ .next()
+ .map(|v| v.parse::<u64>())
+ .map_or(Ok(None), |r| r.map(Some));
+ (base, last)
+ })
+ .map(|range| match range {
+ (Ok(Some(base)), Ok(None)) => Ok((base, 1)),
+ (Ok(Some(base)), Ok(Some(last))) => {
+ Ok((base, last.saturating_sub(base).saturating_add(1)))
+ }
+ (Err(e), _) => Err(argument::Error::InvalidValue {
+ value: e.to_string(),
+ expected: String::from("invalid base range value"),
+ }),
+ (_, Err(e)) => Err(argument::Error::InvalidValue {
+ value: e.to_string(),
+ expected: String::from("invalid last range value"),
+ }),
+ _ => Err(argument::Error::InvalidValue {
+ value: s.to_owned(),
+ expected: String::from("invalid range format"),
+ }),
+ })
+ .collect();
+ Ok(DirectIoOption {
+ path,
+ ranges: ranges?,
+ })
+}
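// A worked example of the parsing above (the device path is hypothetical):
//
//     parse_direct_io_options(Some("/dev/direct-io@368-371,881"))
//         // => DirectIoOption {
//         //        path: "/dev/direct-io",
//         //        ranges: vec![(368, 4), (881, 1)],  // (base, count): 368-371 covers 4 ports
//         //    }
//
// Range values are decimal u64s, a bare value such as `881` is treated as a
// one-port range, and the path component must already exist.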
+
fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument::Result<()> {
match name {
"" => {
@@ -676,6 +759,39 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument::
}
cfg.executable_path = Some(Executable::Kernel(kernel_path));
}
+ "kvm-device" => {
+ let kvm_device_path = PathBuf::from(value.unwrap());
+ if !kvm_device_path.exists() {
+ return Err(argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from("this kvm device path does not exist"),
+ });
+ }
+
+ cfg.kvm_device_path = kvm_device_path;
+ }
+ "vhost-vsock-device" => {
+ let vhost_vsock_device_path = PathBuf::from(value.unwrap());
+ if !vhost_vsock_device_path.exists() {
+ return Err(argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from("this vhost-vsock device path does not exist"),
+ });
+ }
+
+ cfg.vhost_vsock_device_path = vhost_vsock_device_path;
+ }
+ "vhost-net-device" => {
+ let vhost_net_device_path = PathBuf::from(value.unwrap());
+ if !vhost_net_device_path.exists() {
+ return Err(argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from("this vhost-vsock device path does not exist"),
+ });
+ }
+
+ cfg.vhost_net_device_path = vhost_net_device_path;
+ }
"android-fstab" => {
if cfg.android_fstab.is_some()
&& !cfg.android_fstab.as_ref().unwrap().as_os_str().is_empty()
@@ -750,6 +866,9 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument::
})?,
)
}
+ "hugepages" => {
+ cfg.hugepages = true;
+ }
#[cfg(feature = "audio")]
"ac97" => {
let ac97_params = parse_ac97_options(value.unwrap())?;
@@ -1544,6 +1663,94 @@ fn set_argument(cfg: &mut Config, name: &str, value: Option<&str>) -> argument::
* 1024
* 1024; // cfg.balloon_bias is in bytes.
}
+ "vhost-user-blk" => cfg.vhost_user_blk.push(VhostUserOption {
+ socket: PathBuf::from(value.unwrap()),
+ }),
+ "vhost-user-net" => cfg.vhost_user_net.push(VhostUserOption {
+ socket: PathBuf::from(value.unwrap()),
+ }),
+ "vhost-user-fs" => {
+ // (socket:tag)
+ let param = value.unwrap();
+ let mut components = param.split(':');
+ let socket =
+ PathBuf::from(
+ components
+ .next()
+ .ok_or_else(|| argument::Error::InvalidValue {
+ value: param.to_owned(),
+ expected: String::from("missing socket path for `vhost-user-fs`"),
+ })?,
+ );
+ let tag = components
+ .next()
+ .ok_or_else(|| argument::Error::InvalidValue {
+ value: param.to_owned(),
+ expected: String::from("missing tag for `vhost-user-fs`"),
+ })?
+ .to_owned();
+ cfg.vhost_user_fs.push(VhostUserFsOption { socket, tag });
+ }
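// For illustration, the vhost-user options above might be passed as follows
// (socket paths and the tag are hypothetical):
//
//     crosvm run --vhost-user-blk /run/vhost-blk.sock \
//                --vhost-user-net /run/vhost-net.sock \
//                --vhost-user-fs /run/vhost-fs.sock:shared ...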
+ #[cfg(feature = "direct")]
+ "direct-pmio" => {
+ if cfg.direct_pmio.is_some() {
+ return Err(argument::Error::TooManyArguments(
+ "`direct_pmio` already given".to_owned(),
+ ));
+ }
+ cfg.direct_pmio = Some(parse_direct_io_options(value)?);
+ }
+ #[cfg(feature = "direct")]
+ "direct-level-irq" => {
+ cfg.direct_level_irq
+ .push(
+ value
+ .unwrap()
+ .parse()
+ .map_err(|_| argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from(
+ "this value for `direct-level-irq` must be an unsigned integer",
+ ),
+ })?,
+ );
+ }
+ #[cfg(feature = "direct")]
+ "direct-edge-irq" => {
+ cfg.direct_edge_irq
+ .push(
+ value
+ .unwrap()
+ .parse()
+ .map_err(|_| argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from(
+ "this value for `direct-edge-irq` must be an unsigned integer",
+ ),
+ })?,
+ );
+ }
+ "dmi" => {
+ if cfg.dmi_path.is_some() {
+ return Err(argument::Error::TooManyArguments(
+ "`dmi` already given".to_owned(),
+ ));
+ }
+ let dmi_path = PathBuf::from(value.unwrap());
+ if !dmi_path.exists() {
+ return Err(argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from("the dmi path does not exist"),
+ });
+ }
+ if !dmi_path.is_dir() {
+ return Err(argument::Error::InvalidValue {
+ value: value.unwrap().to_owned(),
+ expected: String::from("the dmi path should be directory"),
+ });
+ }
+ cfg.dmi_path = Some(dmi_path);
+ }
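// For illustration, a hypothetical invocation could pass the host's DMI table
// directory, e.g. `--dmi /sys/firmware/dmi/tables`; per the checks above the
// value must name an existing directory.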
"help" => return Err(argument::Error::PrintHelp),
_ => unreachable!(),
}
@@ -1603,6 +1810,9 @@ fn validate_arguments(cfg: &mut Config) -> std::result::Result<(), argument::Err
fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> {
let arguments =
&[Argument::positional("KERNEL", "bzImage of kernel to run"),
+ Argument::value("kvm-device", "PATH", "Path to the KVM device. (default /dev/kvm)"),
+ Argument::value("vhost-vsock-device", "PATH", "Path to the vhost-vsock device. (default /dev/vhost-vsock)"),
+ Argument::value("vhost-net-device", "PATH", "Path to the vhost-net device. (default /dev/vhost-net)"),
Argument::value("android-fstab", "PATH", "Path to Android fstab"),
Argument::short_value('i', "initrd", "PATH", "Initial ramdisk to load."),
Argument::short_value('p',
@@ -1618,6 +1828,7 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> {
"mem",
"N",
"Amount of guest memory in MiB. (default: 256)"),
+ Argument::flag("hugepages", "Advise the kernel to use Huge Pages for guest memory mappings."),
Argument::short_value('r',
"root",
"PATH[,key=value[,key=value[,...]]",
@@ -1644,13 +1855,14 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> {
Argument::value("net-vq-pairs", "N", "virtio net virtual queue paris. (default: 1)"),
#[cfg(feature = "audio")]
Argument::value("ac97",
- "[backend=BACKEND,capture=true,capture_effect=EFFECT,shm-fd=FD,client-fd=FD,server-fd=FD]",
+ "[backend=BACKEND,capture=true,capture_effect=EFFECT,client_type=TYPE,shm-fd=FD,client-fd=FD,server-fd=FD]",
"Comma separated key=value pairs for setting up Ac97 devices. Can be given more than once .
Possible key values:
backend=(null, cras, vios) - Where to route the audio device. If not provided, backend will default to null.
`null` for /dev/null, cras for CRAS server and vios for VioS server.
capture - Enable audio capture
capture_effects - '|'-separated effects to be enabled for recording. The only supported effect values are currently EchoCancellation or aec.
+ client_type - Set the client type for the cras backend.
server - The path to the VIOS server (unix socket)."),
Argument::value("serial",
"type=TYPE,[hardware=HW,num=NUM,path=PATH,input=PATH,console,earlycon,stdin]",
@@ -1712,15 +1924,15 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa
"[width=INT,height=INT]",
"(EXPERIMENTAL) Comma separated key=value pairs for setting up a virtio-gpu device
Possible key values:
- backend=(2d|3d|gfxstream) - Which backend to use for virtio-gpu (determining rendering protocol)
+ backend=(2d|virglrenderer|gfxstream) - Which backend to use for virtio-gpu (determining rendering protocol)
width=INT - The width of the virtual display connected to the virtio-gpu.
height=INT - The height of the virtual display connected to the virtio-gpu.
- egl[=true|=false] - If the virtio-gpu backend should use a EGL context for rendering.
- glx[=true|=false] - If the virtio-gpu backend should use a GLX context for rendering.
- surfaceless[=true|=false] - If the virtio-gpu backend should use a surfaceless context for rendering.
- angle[=true|=false] - If the guest is using ANGLE (OpenGL on Vulkan) as its native OpenGL driver.
+ egl[=true|=false] - If the backend should use an EGL context for rendering.
+ glx[=true|=false] - If the backend should use a GLX context for rendering.
+ surfaceless[=true|=false] - If the backend should use a surfaceless context for rendering.
+ angle[=true|=false] - If the gfxstream backend should use ANGLE (OpenGL on Vulkan) as its native OpenGL driver.
syncfd[=true|=false] - If the gfxstream backend should support EGL_ANDROID_native_fence_sync
- vulkan[=true|=false] - If the gfxstream backend should support vulkan
+ vulkan[=true|=false] - If the backend should support vulkan
"),
#[cfg(feature = "tpm")]
Argument::flag("software-tpm", "enable a software emulated trusted platform module device"),
@@ -1749,6 +1961,17 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa
"),
Argument::value("gdb", "PORT", "(EXPERIMENTAL) gdb on the given port"),
Argument::value("balloon_bias_mib", "N", "Amount to bias balance of memory between host and guest as the balloon inflates, in MiB."),
+ Argument::value("vhost-user-blk", "SOCKET_PATH", "Path to a socket for vhost-user block"),
+ Argument::value("vhost-user-net", "SOCKET_PATH", "Path to a socket for vhost-user net"),
+ Argument::value("vhost-user-fs", "SOCKET_PATH:TAG",
+ "Path to a socket path for vhost-user fs, and tag for the shared dir"),
+ #[cfg(feature = "direct")]
+ Argument::value("direct-pmio", "PATH@RANGE[,RANGE[,...]]", "Path and ranges for direct port I/O access"),
+ #[cfg(feature = "direct")]
+ Argument::value("direct-level-irq", "irq", "Enable interrupt passthrough"),
+ #[cfg(feature = "direct")]
+ Argument::value("direct-edge-irq", "irq", "Enable interrupt passthrough"),
+ Argument::value("dmi", "DIR", "Directory with smbios_entry_point/DMI files"),
Argument::short_flag('h', "help", "Print help message.")];
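// A hypothetical command line combining several of the options documented above
// (all paths and sizes are illustrative, not taken from this change):
//
//     crosvm run --hugepages --kvm-device /dev/kvm \
//                --gpu backend=virglrenderer,width=1280,height=720,vulkan=true,udmabuf=true \
//                --mem 2048 vmlinux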
let mut cfg = Config::default();
@@ -1792,76 +2015,37 @@ writeback=BOOL - Indicates whether the VM can use writeback caching (default: fa
}
}
-fn handle_request(
- request: &VmRequest,
- args: std::env::Args,
-) -> std::result::Result<VmResponse, ()> {
- let mut return_result = Err(());
- for socket_path in args {
- match UnixSeqpacket::connect(&socket_path) {
- Ok(s) => {
- let socket: VmControlRequestSocket = MsgSocket::new(s);
- if let Err(e) = socket.send(request) {
- error!(
- "failed to send request to socket at '{}': {}",
- socket_path, e
- );
- return_result = Err(());
- continue;
- }
- match socket.recv() {
- Ok(response) => return_result = Ok(response),
- Err(e) => {
- error!(
- "failed to send request to socket at2 '{}': {}",
- socket_path, e
- );
- return_result = Err(());
- continue;
- }
- }
- }
- Err(e) => {
- error!("failed to connect to socket at '{}': {}", socket_path, e);
- return_result = Err(());
- }
- }
- }
-
- return_result
-}
-
-fn vms_request(request: &VmRequest, args: std::env::Args) -> std::result::Result<(), ()> {
- let response = handle_request(request, args)?;
- info!("request response was {}", response);
- Ok(())
-}
-
-fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> {
+fn stop_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
if args.len() == 0 {
print_help("crosvm stop", "VM_SOCKET...", &[]);
println!("Stops the crosvm instance listening on each `VM_SOCKET` given.");
return Err(());
}
- vms_request(&VmRequest::Exit, args)
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ vms_request(&VmRequest::Exit, socket_path)
}
-fn suspend_vms(args: std::env::Args) -> std::result::Result<(), ()> {
+fn suspend_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
if args.len() == 0 {
print_help("crosvm suspend", "VM_SOCKET...", &[]);
println!("Suspends the crosvm instance listening on each `VM_SOCKET` given.");
return Err(());
}
- vms_request(&VmRequest::Suspend, args)
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ vms_request(&VmRequest::Suspend, socket_path)
}
-fn resume_vms(args: std::env::Args) -> std::result::Result<(), ()> {
+fn resume_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
if args.len() == 0 {
print_help("crosvm resume", "VM_SOCKET...", &[]);
println!("Resumes the crosvm instance listening on each `VM_SOCKET` given.");
return Err(());
}
- vms_request(&VmRequest::Resume, args)
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ vms_request(&VmRequest::Resume, socket_path)
}
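// Note that stop/suspend/resume now consume only the first socket argument via
// args.next(), even though the help text still advertises `VM_SOCKET...`. A
// hypothetical invocation:
//
//     crosvm suspend /run/crosvm.sock
//     crosvm resume /run/crosvm.sock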
fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
@@ -1879,10 +2063,12 @@ fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
};
let command = BalloonControlCommand::Adjust { num_bytes };
- vms_request(&VmRequest::BalloonCommand(command), args)
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ vms_request(&VmRequest::BalloonCommand(command), socket_path)
}
-fn balloon_stats(args: std::env::Args) -> std::result::Result<(), ()> {
+fn balloon_stats(mut args: std::env::Args) -> std::result::Result<(), ()> {
if args.len() != 1 {
print_help("crosvm balloon_stats", "VM_SOCKET", &[]);
println!("Prints virtio balloon statistics for a `VM_SOCKET`.");
@@ -1890,7 +2076,9 @@ fn balloon_stats(args: std::env::Args) -> std::result::Result<(), ()> {
}
let command = BalloonControlCommand::Stats {};
let request = &VmRequest::BalloonCommand(command);
- let response = handle_request(request, args)?;
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ let response = handle_request(request, socket_path)?;
println!("{}", response);
Ok(())
}
@@ -2013,47 +2201,11 @@ fn disk_cmd(mut args: std::env::Args) -> std::result::Result<(), ()> {
}
};
- vms_request(&request, args)
-}
-
-enum ModifyUsbError {
- ArgMissing(&'static str),
- ArgParse(&'static str, String),
- ArgParseInt(&'static str, String, ParseIntError),
- FailedDescriptorValidate(base::Error),
- PathDoesNotExist(PathBuf),
- SocketFailed,
- UnexpectedResponse(VmResponse),
- UnknownCommand(String),
- UsbControl(UsbControlResult),
-}
-
-impl fmt::Display for ModifyUsbError {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- use self::ModifyUsbError::*;
-
- match self {
- ArgMissing(a) => write!(f, "argument missing: {}", a),
- ArgParse(name, value) => {
- write!(f, "failed to parse argument {} value `{}`", name, value)
- }
- ArgParseInt(name, value, e) => write!(
- f,
- "failed to parse integer argument {} value `{}`: {}",
- name, value, e
- ),
- FailedDescriptorValidate(e) => write!(f, "failed to validate file descriptor: {}", e),
- PathDoesNotExist(p) => write!(f, "path `{}` does not exist", p.display()),
- SocketFailed => write!(f, "socket failed"),
- UnexpectedResponse(r) => write!(f, "unexpected response: {}", r),
- UnknownCommand(c) => write!(f, "unknown command: `{}`", c),
- UsbControl(e) => write!(f, "{}", e),
- }
- }
+ let socket_path = &args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
+ vms_request(&request, socket_path)
}
-type ModifyUsbResult<T> = std::result::Result<T, ModifyUsbError>;
-
fn parse_bus_id_addr(v: &str) -> ModifyUsbResult<(u8, u8, u16, u16)> {
debug!("parse_bus_id_addr: {}", v);
let mut ids = v.split(':');
@@ -2078,27 +2230,6 @@ fn parse_bus_id_addr(v: &str) -> ModifyUsbResult<(u8, u8, u16, u16)> {
}
}
-fn raw_descriptor_from_path(path: &Path) -> ModifyUsbResult<RawDescriptor> {
- if !path.exists() {
- return Err(ModifyUsbError::PathDoesNotExist(path.to_owned()));
- }
- let raw_descriptor = path
- .file_name()
- .and_then(|fd_osstr| fd_osstr.to_str())
- .map_or(
- Err(ModifyUsbError::ArgParse(
- "USB_DEVICE_PATH",
- path.to_string_lossy().into_owned(),
- )),
- |fd_str| {
- fd_str.parse::<libc::c_int>().map_err(|e| {
- ModifyUsbError::ArgParseInt("USB_DEVICE_PATH", fd_str.to_owned(), e)
- })
- },
- )?;
- validate_raw_descriptor(raw_descriptor).map_err(ModifyUsbError::FailedDescriptorValidate)
-}
-
fn usb_attach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
let val = args
.next()
@@ -2108,39 +2239,13 @@ fn usb_attach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
args.next()
.ok_or(ModifyUsbError::ArgMissing("usb device path"))?,
);
- let usb_file: Option<File> = if dev_path == Path::new("-") {
- None
- } else if dev_path.parent() == Some(Path::new("/proc/self/fd")) {
- // Special case '/proc/self/fd/*' paths. The FD is already open, just use it.
- // Safe because we will validate |raw_fd|.
- Some(unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&dev_path)?) })
- } else {
- Some(
- OpenOptions::new()
- .read(true)
- .write(true)
- .open(&dev_path)
- .map_err(|_| ModifyUsbError::UsbControl(UsbControlResult::FailedToOpenDevice))?,
- )
- };
- let request = VmRequest::UsbCommand(UsbControlCommand::AttachDevice {
- bus,
- addr,
- vid,
- pid,
- // Safe because we are transferring ownership to the rawdescriptor
- descriptor: usb_file.map(|file| {
- MaybeOwnedDescriptor::Owned(unsafe {
- SafeDescriptor::from_raw_descriptor(file.into_raw_descriptor())
- })
- }),
- });
- let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?;
- match response {
- VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
- r => Err(ModifyUsbError::UnexpectedResponse(r)),
- }
+ let socket_path = args
+ .next()
+ .ok_or(ModifyUsbError::ArgMissing("control socket path"))?;
+ let socket_path = Path::new(&socket_path);
+
+ do_usb_attach(&socket_path, bus, addr, vid, pid, &dev_path)
}
fn usb_detach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
@@ -2150,25 +2255,19 @@ fn usb_detach(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
p.parse::<u8>()
.map_err(|e| ModifyUsbError::ArgParseInt("PORT", p.to_owned(), e))
})?;
- let request = VmRequest::UsbCommand(UsbControlCommand::DetachDevice { port });
- let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?;
- match response {
- VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
- r => Err(ModifyUsbError::UnexpectedResponse(r)),
- }
+ let socket_path = args
+ .next()
+ .ok_or(ModifyUsbError::ArgMissing("control socket path"))?;
+ let socket_path = Path::new(&socket_path);
+ do_usb_detach(&socket_path, port)
}
-fn usb_list(args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
- let mut ports: [u8; USB_CONTROL_MAX_PORTS] = Default::default();
- for (index, port) in ports.iter_mut().enumerate() {
- *port = index as u8
- }
- let request = VmRequest::UsbCommand(UsbControlCommand::ListDevice { ports });
- let response = handle_request(&request, args).map_err(|_| ModifyUsbError::SocketFailed)?;
- match response {
- VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
- r => Err(ModifyUsbError::UnexpectedResponse(r)),
- }
+fn usb_list(mut args: std::env::Args) -> ModifyUsbResult<UsbControlResult> {
+ let socket_path = args
+ .next()
+ .ok_or(ModifyUsbError::ArgMissing("control socket path"))?;
+ let socket_path = Path::new(&socket_path);
+ do_usb_list(&socket_path)
}
fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> {
@@ -2179,7 +2278,7 @@ fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> {
}
// This unwrap will not panic because of the above length check.
- let command = args.next().unwrap();
+ let command = &args.next().unwrap();
let result = match command.as_ref() {
"attach" => usb_attach(args),
"detach" => usb_detach(args),
@@ -2199,16 +2298,22 @@ fn modify_usb(mut args: std::env::Args) -> std::result::Result<(), ()> {
}
fn print_usage() {
- print_help("crosvm", "[stop|run]", &[]);
+ print_help("crosvm", "[command]", &[]);
println!("Commands:");
- println!(" stop - Stops crosvm instances via their control sockets.");
- println!(" run - Start a new crosvm instance.");
+ println!(" balloon - Set balloon size of the crosvm instance.");
+ println!(" balloon_stats - Prints virtio balloon statistics.");
+ println!(" battery - Modify battery.");
println!(" create_qcow2 - Create a new qcow2 disk image file.");
println!(" disk - Manage attached virtual disk devices.");
+ println!(" resume - Resumes the crosvm instance.");
+ println!(" run - Start a new crosvm instance.");
+ println!(" stop - Stops crosvm instances via their control sockets.");
+ println!(" suspend - Suspends the crosvm instance.");
println!(" usb - Manage attached virtual USB devices.");
println!(" version - Show package version.");
}
+#[allow(clippy::unnecessary_wraps)]
fn pkg_version() -> std::result::Result<(), ()> {
const VERSION: Option<&'static str> = option_env!("CARGO_PKG_VERSION");
const PKG_VERSION: Option<&'static str> = option_env!("PKG_VERSION");
@@ -2221,20 +2326,6 @@ fn pkg_version() -> std::result::Result<(), ()> {
Ok(())
}
-enum ModifyBatError {
- BatControlErr(BatControlResult),
-}
-
-impl fmt::Display for ModifyBatError {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- use self::ModifyBatError::*;
-
- match self {
- BatControlErr(e) => write!(f, "{}", e),
- }
- }
-}
-
fn modify_battery(mut args: std::env::Args) -> std::result::Result<(), ()> {
if args.len() < 4 {
print_help("crosvm battery BATTERY_TYPE ",
@@ -2247,27 +2338,10 @@ fn modify_battery(mut args: std::env::Args) -> std::result::Result<(), ()> {
let property = args.next().unwrap();
let target = args.next().unwrap();
- let response = match battery_type.parse::<BatteryType>() {
- Ok(type_) => match BatControlCommand::new(property, target) {
- Ok(cmd) => {
- let request = VmRequest::BatCommand(type_, cmd);
- Ok(handle_request(&request, args)?)
- }
- Err(e) => Err(ModifyBatError::BatControlErr(e)),
- },
- Err(e) => Err(ModifyBatError::BatControlErr(e)),
- };
+ let socket_path = args.next().unwrap();
+ let socket_path = Path::new(&socket_path);
- match response {
- Ok(response) => {
- println!("{}", response);
- Ok(())
- }
- Err(e) => {
- println!("error {}", e);
- Err(())
- }
- }
+ do_modify_battery(&socket_path, &*battery_type, &*property, &*target)
}
fn crosvm_main() -> std::result::Result<(), ()> {
@@ -2444,6 +2518,17 @@ mod tests {
#[cfg(feature = "audio")]
#[test]
+ fn parse_ac97_client_type() {
+ parse_ac97_options("backend=cras,capture=true,client_type=crosvm")
+            .expect("parse should have succeeded");
+        parse_ac97_options("backend=cras,capture=true,client_type=arcvm")
+            .expect("parse should have succeeded");
+ parse_ac97_options("backend=cras,capture=true,client_type=none")
+ .expect_err("parse should have failed");
+ }
+
+ #[cfg(feature = "audio")]
+ #[test]
fn parse_ac97_vios_valid() {
parse_ac97_options("backend=vios,server=/path/to/server")
            .expect("parse should have succeeded");
@@ -2717,41 +2802,54 @@ mod tests {
validate_arguments(&mut config).unwrap();
assert_eq!(
config.virtio_switches.unwrap(),
- PathBuf::from("/dev/switches-test"));
+ PathBuf::from("/dev/switches-test")
+ );
}
- #[cfg(all(feature = "gpu", feature = "gfxstream"))]
+ #[cfg(feature = "gpu")]
+ #[test]
+ fn parse_gpu_options_default_vulkan_support() {
+ assert!(
+ !parse_gpu_options(Some("backend=virglrenderer"))
+ .unwrap()
+ .use_vulkan
+ );
+
+ #[cfg(feature = "gfxstream")]
+ assert!(
+ parse_gpu_options(Some("backend=gfxstream"))
+ .unwrap()
+ .use_vulkan
+ );
+ }
+
+ #[cfg(feature = "gpu")]
#[test]
- fn parse_gpu_options_gfxstream_with_vulkan_specified() {
+ fn parse_gpu_options_with_vulkan_specified() {
+ assert!(parse_gpu_options(Some("vulkan=true")).unwrap().use_vulkan);
assert!(
- parse_gpu_options(Some("backend=gfxstream,vulkan=true"))
+ parse_gpu_options(Some("backend=virglrenderer,vulkan=true"))
.unwrap()
- .gfxstream_support_vulkan
+ .use_vulkan
);
assert!(
- parse_gpu_options(Some("vulkan=true,backend=gfxstream"))
+ parse_gpu_options(Some("vulkan=true,backend=virglrenderer"))
.unwrap()
- .gfxstream_support_vulkan
+ .use_vulkan
);
+ assert!(!parse_gpu_options(Some("vulkan=false")).unwrap().use_vulkan);
assert!(
- !parse_gpu_options(Some("backend=gfxstream,vulkan=false"))
+ !parse_gpu_options(Some("backend=virglrenderer,vulkan=false"))
.unwrap()
- .gfxstream_support_vulkan
+ .use_vulkan
);
assert!(
- !parse_gpu_options(Some("vulkan=false,backend=gfxstream"))
+ !parse_gpu_options(Some("vulkan=false,backend=virglrenderer"))
.unwrap()
- .gfxstream_support_vulkan
+ .use_vulkan
);
- assert!(parse_gpu_options(Some("backend=gfxstream,vulkan=invalid_value")).is_err());
- assert!(parse_gpu_options(Some("vulkan=invalid_value,backend=gfxstream")).is_err());
- }
-
- #[cfg(all(feature = "gpu", feature = "gfxstream"))]
- #[test]
- fn parse_gpu_options_not_gfxstream_with_vulkan_specified() {
- assert!(parse_gpu_options(Some("backend=3d,vulkan=true")).is_err());
- assert!(parse_gpu_options(Some("vulkan=true,backend=3d")).is_err());
+ assert!(parse_gpu_options(Some("backend=virglrenderer,vulkan=invalid_value")).is_err());
+ assert!(parse_gpu_options(Some("vulkan=invalid_value,backend=virglrenderer")).is_err());
}
#[cfg(all(feature = "gpu", feature = "gfxstream"))]
@@ -2784,8 +2882,8 @@ mod tests {
#[cfg(all(feature = "gpu", feature = "gfxstream"))]
#[test]
fn parse_gpu_options_not_gfxstream_with_syncfd_specified() {
- assert!(parse_gpu_options(Some("backend=3d,syncfd=true")).is_err());
- assert!(parse_gpu_options(Some("syncfd=true,backend=3d")).is_err());
+ assert!(parse_gpu_options(Some("backend=virglrenderer,syncfd=true")).is_err());
+ assert!(parse_gpu_options(Some("syncfd=true,backend=virglrenderer")).is_err());
}
#[test]
diff --git a/src/plugin/mod.rs b/src/plugin/mod.rs
index f900bf599..8236425c0 100644
--- a/src/plugin/mod.rs
+++ b/src/plugin/mod.rs
@@ -34,7 +34,7 @@ use base::{
use kvm::{Cap, Datamatch, IoeventAddress, Kvm, Vcpu, VcpuExit, Vm};
use minijail::{self, Minijail};
use net_util::{Error as TapError, Tap, TapT};
-use vm_memory::GuestMemory;
+use vm_memory::{GuestMemory, MemoryPolicy};
use self::process::*;
use self::vcpu::*;
@@ -418,7 +418,7 @@ pub fn run_vcpus(
if use_kvm_signals {
unsafe {
- extern "C" fn handle_signal() {}
+ extern "C" fn handle_signal(_: c_int) {}
// Our signal handler does nothing and is trivially async signal safe.
// We need to install this signal handler even though we do block
// the signal below, to ensure that this signal will interrupt
@@ -430,7 +430,7 @@ pub fn run_vcpus(
block_signal(SIGRTMIN() + 0).expect("failed to block signal");
} else {
unsafe {
- extern "C" fn handle_signal() {
+ extern "C" fn handle_signal(_: c_int) {
Vcpu::set_local_immediate_exit(true);
}
register_rt_signal_handler(SIGRTMIN() + 0, handle_signal)
@@ -691,7 +691,12 @@ pub fn run_config(cfg: Config) -> Result<()> {
};
let vcpu_count = cfg.vcpu_count.unwrap_or(1) as u32;
let mem = GuestMemory::new(&[]).unwrap();
- let kvm = Kvm::new().map_err(Error::CreateKvm)?;
+ let mut mem_policy = MemoryPolicy::empty();
+ if cfg.hugepages {
+ mem_policy |= MemoryPolicy::USE_HUGEPAGES;
+ }
+ mem.set_memory_policy(mem_policy);
+ let kvm = Kvm::new_with_path(&cfg.kvm_device_path).map_err(Error::CreateKvm)?;
let mut vm = Vm::new(&kvm, mem).map_err(Error::CreateVm)?;
vm.create_irq_chip().map_err(Error::CreateIrqChip)?;
vm.create_pit().map_err(Error::CreatePIT)?;
diff --git a/src/plugin/process.rs b/src/plugin/process.rs
index 7c61c0908..c34b708bb 100644
--- a/src/plugin/process.rs
+++ b/src/plugin/process.rs
@@ -365,7 +365,7 @@ impl Process {
_ => {}
}
let mem = MemoryMappingBuilder::new(length as usize)
- .from_descriptor(&shm)
+ .from_shared_memory(&shm)
.offset(offset)
.build()
.map_err(mmap_to_sys_err)?;
diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml
index 99b994cdb..7edffd9f5 100644
--- a/sys_util/Cargo.toml
+++ b/sys_util/Cargo.toml
@@ -9,8 +9,9 @@ include = ["src/**/*", "Cargo.toml"]
data_model = { path = "../data_model" } # provided by ebuild
libc = "*"
poll_token_derive = { version = "*", path = "poll_token_derive" }
+serde = { version = "1", features = [ "derive" ] }
+serde_json = "1"
sync = { path = "../sync" } # provided by ebuild
-syscall_defines = { path = "../syscall_defines" } # provided by ebuild
tempfile = { path = "../tempfile" } # provided by ebuild
[target.'cfg(target_os = "android")'.dependencies]
diff --git a/sys_util/src/descriptor.rs b/sys_util/src/descriptor.rs
index 5ee075a42..325b2b450 100644
--- a/sys_util/src/descriptor.rs
+++ b/sys_util/src/descriptor.rs
@@ -2,6 +2,7 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
+use std::convert::TryFrom;
use std::fs::File;
use std::io::{Stderr, Stdin, Stdout};
use std::mem;
@@ -10,6 +11,8 @@ use std::ops::Drop;
use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::os::unix::net::{UnixDatagram, UnixStream};
+use serde::{Deserialize, Serialize};
+
use crate::net::UnlinkUnixSeqpacketListener;
use crate::{errno_result, PollToken, Result};
@@ -33,9 +36,24 @@ pub trait FromRawDescriptor {
unsafe fn from_raw_descriptor(descriptor: RawDescriptor) -> Self;
}
+/// Clones `fd`, returning a new file descriptor that refers to the same open file description as
+/// `fd`. The cloned fd will have the `FD_CLOEXEC` flag set but will not share any other file
+/// descriptor flags with `fd`.
+pub fn clone_fd(fd: &dyn AsRawFd) -> Result<RawFd> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::fcntl(fd.as_raw_fd(), libc::F_DUPFD_CLOEXEC, 0) };
+ if ret < 0 {
+ errno_result()
+ } else {
+ Ok(ret)
+ }
+}
+
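A minimal usage sketch (illustrative, not part of this patch): `clone_fd` takes anything implementing `AsRawFd`, so an open `File` can be duplicated directly.

fn dup_file_fd(f: &std::fs::File) -> Result<RawFd> {
    // The returned fd refers to the same open file description as `f`, with FD_CLOEXEC set.
    clone_fd(f)
}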
/// Wraps a RawDescriptor and safely closes it when self falls out of scope.
-#[derive(Debug, Eq)]
+#[derive(Serialize, Deserialize, Debug, Eq)]
+#[serde(transparent)]
pub struct SafeDescriptor {
+ #[serde(with = "crate::with_raw_descriptor")]
descriptor: RawDescriptor,
}
@@ -98,18 +116,41 @@ impl AsRawFd for SafeDescriptor {
}
}
+impl TryFrom<&dyn AsRawFd> for SafeDescriptor {
+ type Error = std::io::Error;
+
+ fn try_from(fd: &dyn AsRawFd) -> std::result::Result<Self, Self::Error> {
+ Ok(SafeDescriptor {
+ descriptor: clone_fd(fd)?,
+ })
+ }
+}
+
impl SafeDescriptor {
/// Clones this descriptor, internally creating a new descriptor. The new SafeDescriptor will
/// share the same underlying count within the kernel.
pub fn try_clone(&self) -> Result<SafeDescriptor> {
- // Safe because self.as_raw_descriptor() returns a valid value
- let copy_fd = unsafe { libc::dup(self.as_raw_descriptor()) };
- if copy_fd < 0 {
- return errno_result();
+ // Safe because this doesn't modify any memory and we check the return value.
+ let descriptor = unsafe { libc::fcntl(self.descriptor, libc::F_DUPFD_CLOEXEC, 0) };
+ if descriptor < 0 {
+ errno_result()
+ } else {
+ Ok(SafeDescriptor { descriptor })
}
- // Safe becuase we just successfully duplicated and this object will uniquely
- // own the raw descriptor.
- Ok(unsafe { SafeDescriptor::from_raw_descriptor(copy_fd) })
+ }
+}
+
+impl From<SafeDescriptor> for File {
+ fn from(s: SafeDescriptor) -> File {
+ // Safe because we own the SafeDescriptor at this point.
+ unsafe { File::from_raw_fd(s.into_raw_descriptor()) }
+ }
+}
+
+impl From<File> for SafeDescriptor {
+ fn from(f: File) -> SafeDescriptor {
+ // Safe because we own the File at this point.
+ unsafe { SafeDescriptor::from_raw_descriptor(f.into_raw_descriptor()) }
}
}
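A short sketch (illustrative only) of how the two `From` impls above compose: the underlying fd moves from `File` into `SafeDescriptor` and back without being closed or duplicated.

fn file_roundtrip(f: std::fs::File) -> std::fs::File {
    let sd = SafeDescriptor::from(f);
    std::fs::File::from(sd)
}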
diff --git a/sys_util/src/descriptor_reflection.rs b/sys_util/src/descriptor_reflection.rs
new file mode 100644
index 000000000..9b18ffc37
--- /dev/null
+++ b/sys_util/src/descriptor_reflection.rs
@@ -0,0 +1,541 @@
+// Copyright 2020 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//! Provides infrastructure for de/serializing descriptors embedded in Rust data structures.
+//!
+//! # Example
+//!
+//! ```
+//! use serde_json::to_string;
+//! use sys_util::{
+//! FileSerdeWrapper, FromRawDescriptor, SafeDescriptor, SerializeDescriptors,
+//! deserialize_with_descriptors,
+//! };
+//! use tempfile::tempfile;
+//!
+//! let tmp_f = tempfile().unwrap();
+//!
+//! // Uses a simple wrapper to serialize a File because we can't implement Serialize for File.
+//! let data = FileSerdeWrapper(tmp_f);
+//!
+//! // Wraps Serialize types to collect side channel descriptors as Serialize is called.
+//! let data_wrapper = SerializeDescriptors::new(&data);
+//!
+//! // Use the wrapper with any serializer to serialize data as normal, grabbing descriptors
+//! // as the data structures are serialized by the serializer.
+//! let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
+//!
+//! // If data_wrapper contains any side channel descriptor refs
+//! // (it contains tmp_f in this case), we can retrieve the actual descriptors
+//! // from the side channel using into_descriptors().
+//! let out_descriptors = data_wrapper.into_descriptors();
+//!
+//! // When sending out_json over some transport, also send out_descriptors.
+//!
+//! // For this example, we aren't really transporting data across a process boundary, but we
+//! // do need to convert the descriptor type.
+//! let mut safe_descriptors = out_descriptors
+//! .iter()
+//! .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
+//! .collect();
+//! std::mem::forget(data); // Prevent double drop of tmp_f.
+//!
+//! // The deserialize_with_descriptors function is used to give the descriptor deserializers access
+//! // to side channel descriptors.
+//! let res: FileSerdeWrapper =
+//! deserialize_with_descriptors(|| serde_json::from_str(&out_json), &mut safe_descriptors)
+//! .expect("failed to deserialize");
+//! ```
+
+use std::cell::{Cell, RefCell};
+use std::convert::TryInto;
+use std::fmt;
+use std::fs::File;
+use std::ops::{Deref, DerefMut};
+use std::panic::{catch_unwind, resume_unwind, AssertUnwindSafe};
+
+use serde::de::{self, Error, Visitor};
+use serde::ser;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+
+use crate::{RawDescriptor, SafeDescriptor};
+
+thread_local! {
+ static DESCRIPTOR_DST: RefCell<Option<Vec<RawDescriptor>>> = Default::default();
+}
+
+/// Initializes the thread local storage for descriptor serialization. Fails if it was already
+/// initialized without an intervening `take_descriptor_dst` on this thread.
+fn init_descriptor_dst() -> Result<(), &'static str> {
+ DESCRIPTOR_DST.with(|d| {
+ let mut descriptors = d.borrow_mut();
+ if descriptors.is_some() {
+ return Err(
+ "attempt to initialize descriptor destination that was already initialized",
+ );
+ }
+ *descriptors = Some(Default::default());
+ Ok(())
+ })
+}
+
+/// Takes the thread local storage for descriptor serialization. Fails if there wasn't a prior call
+/// to `init_descriptor_dst` on this thread.
+fn take_descriptor_dst() -> Result<Vec<RawDescriptor>, &'static str> {
+ match DESCRIPTOR_DST.with(|d| d.replace(None)) {
+ Some(d) => Ok(d),
+ None => Err("attempt to take descriptor destination before it was initialized"),
+ }
+}
+
+/// Pushes a descriptor on the thread local destination of descriptors, returning the index at
+/// which the descriptor was pushed.
+///
+/// Returns Err if the thread local destination was not already initialized.
+fn push_descriptor(rd: RawDescriptor) -> Result<usize, &'static str> {
+ DESCRIPTOR_DST.with(|d| {
+ d.borrow_mut()
+ .as_mut()
+ .ok_or("attempt to serialize descriptor without descriptor destination")
+ .map(|descriptors| {
+ let index = descriptors.len();
+ descriptors.push(rd);
+ index
+ })
+ })
+}
+
+/// Serializes a descriptor for later retrieval in a parent `SerializeDescriptors` struct.
+///
+/// If there is no parent `SerializeDescriptors` being serialized, this will return an error.
+///
+/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
+/// "...")]` attribute which will make use of this function.
+pub fn serialize_descriptor<S: Serializer>(
+ rd: &RawDescriptor,
+ se: S,
+) -> std::result::Result<S::Ok, S::Error> {
+ let index = push_descriptor(*rd).map_err(ser::Error::custom)?;
+ se.serialize_u32(
+ index
+ .try_into()
+ .map_err(|_| ser::Error::custom("attempt to serialize too many descriptors at once"))?,
+ )
+}
+
+/// Wrapper for a `Serialize` value which will capture any descriptors exported by the value when
+/// given to an ordinary `Serializer`.
+///
+/// This is the corresponding type to use for serialization before using
+/// `deserialize_with_descriptors`.
+///
+/// # Examples
+///
+/// ```
+/// use serde_json::to_string;
+/// use sys_util::{FileSerdeWrapper, SerializeDescriptors};
+/// use tempfile::tempfile;
+///
+/// let tmp_f = tempfile().unwrap();
+/// let data = FileSerdeWrapper(tmp_f);
+/// let data_wrapper = SerializeDescriptors::new(&data);
+///
+/// // Serializes `data` as normal...
+/// let out_json = serde_json::to_string(&data_wrapper).expect("failed to serialize");
+/// // If `serialize_descriptor` was called, we can capture the descriptors from here.
+/// let out_descriptors = data_wrapper.into_descriptors();
+/// ```
+pub struct SerializeDescriptors<'a, T: Serialize>(&'a T, Cell<Vec<RawDescriptor>>);
+
+impl<'a, T: Serialize> SerializeDescriptors<'a, T> {
+ pub fn new(inner: &'a T) -> Self {
+ Self(inner, Default::default())
+ }
+
+ pub fn into_descriptors(self) -> Vec<RawDescriptor> {
+ self.1.into_inner()
+ }
+}
+
+impl<'a, T: Serialize> Serialize for SerializeDescriptors<'a, T> {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ init_descriptor_dst().map_err(ser::Error::custom)?;
+
+ // catch_unwind is used to ensure that init_descriptor_dst is always balanced with a call to
+ // take_descriptor_dst afterwards.
+ let res = catch_unwind(AssertUnwindSafe(|| self.0.serialize(serializer)));
+ self.1.set(take_descriptor_dst().unwrap());
+ match res {
+ Ok(r) => r,
+ Err(e) => resume_unwind(e),
+ }
+ }
+}
+
+thread_local! {
+ static DESCRIPTOR_SRC: RefCell<Option<Vec<Option<SafeDescriptor>>>> = Default::default();
+}
+
+/// Sets the thread local storage of descriptors for deserialization. Fails if this was already
+/// called without a call to `take_descriptor_src` on this thread.
+///
+/// This is given as a collection of `Option` so that unused descriptors can be returned.
+fn set_descriptor_src(descriptors: Vec<Option<SafeDescriptor>>) -> Result<(), &'static str> {
+ DESCRIPTOR_SRC.with(|d| {
+ let mut src = d.borrow_mut();
+ if src.is_some() {
+ return Err("attempt to set descriptor source that was already set");
+ }
+ *src = Some(descriptors);
+ Ok(())
+ })
+}
+
+/// Takes the thread local storage of descriptors for deserialization. Fails if the storage was
+/// already taken or never set with `set_descriptor_src`.
+///
+/// After deserialization, entries for descriptors that were consumed come back as `None`; only
+/// descriptors that went unused are still `Some`.
+fn take_descriptor_src() -> Result<Vec<Option<SafeDescriptor>>, &'static str> {
+ DESCRIPTOR_SRC.with(|d| {
+ d.replace(None)
+ .ok_or("attempt to take descriptor source which was never set")
+ })
+}
+
+/// Takes a descriptor at the given index from the thread local source of descriptors.
+///
+/// Returns Err if the thread local source was not already initialized.
+fn take_descriptor(index: usize) -> Result<SafeDescriptor, &'static str> {
+ DESCRIPTOR_SRC.with(|d| {
+ d.borrow_mut()
+ .as_mut()
+ .ok_or("attempt to deserialize descriptor without descriptor source")?
+ .get_mut(index)
+ .ok_or("attempt to deserialize out of bounds descriptor")?
+ .take()
+ .ok_or("attempt to deserialize descriptor that was already taken")
+ })
+}
+
+/// Deserializes a descriptor provided via `deserialize_with_descriptors`.
+///
+/// If `deserialize_with_descriptors` is not in the call chain, this will return an error.
+///
+/// For convenience, it is recommended to use the `with_raw_descriptor` module in a `#[serde(with =
+/// "...")]` attribute which will make use of this function.
+pub fn deserialize_descriptor<'de, D>(de: D) -> std::result::Result<SafeDescriptor, D::Error>
+where
+ D: Deserializer<'de>,
+{
+ struct DescriptorVisitor;
+
+ impl<'de> Visitor<'de> for DescriptorVisitor {
+ type Value = u32;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("an integer which fits into a u32")
+ }
+
+ fn visit_u8<E: de::Error>(self, value: u8) -> Result<Self::Value, E> {
+ Ok(value as _)
+ }
+
+ fn visit_u16<E: de::Error>(self, value: u16) -> Result<Self::Value, E> {
+ Ok(value as _)
+ }
+
+ fn visit_u32<E: de::Error>(self, value: u32) -> Result<Self::Value, E> {
+ Ok(value)
+ }
+
+ fn visit_u64<E: de::Error>(self, value: u64) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_u128<E: de::Error>(self, value: u128) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_i8<E: de::Error>(self, value: i8) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_i16<E: de::Error>(self, value: i16) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_i32<E: de::Error>(self, value: i32) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_i64<E: de::Error>(self, value: i64) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+
+ fn visit_i128<E: de::Error>(self, value: i128) -> Result<Self::Value, E> {
+ value.try_into().map_err(E::custom)
+ }
+ }
+
+ let index = de.deserialize_u32(DescriptorVisitor)? as usize;
+ take_descriptor(index).map_err(D::Error::custom)
+}
+
+/// Allows the use of any serde deserializer within a closure while providing access to a set of
+/// descriptors for use in `deserialize_descriptor`.
+///
+/// This is the corresponding call to use for deserialization after using `SerializeDescriptors`.
+///
+/// If `deserialize_with_descriptors` is called anywhere within the given closure, it will return
+/// an error.
+pub fn deserialize_with_descriptors<F, T, E>(
+ f: F,
+ descriptors: &mut Vec<Option<SafeDescriptor>>,
+) -> Result<T, E>
+where
+ F: FnOnce() -> Result<T, E>,
+ E: de::Error,
+{
+ let swap_descriptors = std::mem::take(descriptors);
+ set_descriptor_src(swap_descriptors).map_err(E::custom)?;
+
+ // catch_unwind is used to ensure that set_descriptor_src is always balanced with a call to
+ // take_descriptor_src afterwards.
+ let res = catch_unwind(AssertUnwindSafe(f));
+
+ // unwrap is used because set_descriptor_src is always called before this, so it should never
+ // panic.
+ *descriptors = take_descriptor_src().unwrap();
+
+ match res {
+ Ok(r) => r,
+ Err(e) => resume_unwind(e),
+ }
+}
+
+/// Module that exports `serialize`/`deserialize` functions for use with the `#[serde(with = "...")]`
+/// attribute. It only works with fields of type `RawDescriptor`.
+///
+/// # Examples
+///
+/// ```
+/// use serde::{Deserialize, Serialize};
+/// use sys_util::RawDescriptor;
+///
+/// #[derive(Serialize, Deserialize)]
+/// struct RawContainer {
+/// #[serde(with = "sys_util::with_raw_descriptor")]
+/// rd: RawDescriptor,
+/// }
+/// ```
+pub mod with_raw_descriptor {
+ use crate::{IntoRawDescriptor, RawDescriptor};
+ use serde::Deserializer;
+
+ pub use super::serialize_descriptor as serialize;
+
+ pub fn deserialize<'de, D>(de: D) -> std::result::Result<RawDescriptor, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ super::deserialize_descriptor(de).map(IntoRawDescriptor::into_raw_descriptor)
+ }
+}
+
+/// Module that exports `serialize`/`deserialize` functions for use with the `#[serde(with = "...")]`
+/// attribute.
+///
+/// # Examples
+///
+/// ```
+/// use std::fs::File;
+/// use serde::{Deserialize, Serialize};
+/// use sys_util::RawDescriptor;
+///
+/// #[derive(Serialize, Deserialize)]
+/// struct FileContainer {
+/// #[serde(with = "sys_util::with_as_descriptor")]
+/// file: File,
+/// }
+/// ```
+pub mod with_as_descriptor {
+ use crate::{AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor};
+ use serde::{Deserializer, Serializer};
+
+ pub fn serialize<S: Serializer>(
+ rd: &dyn AsRawDescriptor,
+ se: S,
+ ) -> std::result::Result<S::Ok, S::Error> {
+ super::serialize_descriptor(&rd.as_raw_descriptor(), se)
+ }
+
+ pub fn deserialize<'de, D, T>(de: D) -> std::result::Result<T, D::Error>
+ where
+ D: Deserializer<'de>,
+ T: FromRawDescriptor,
+ {
+ super::deserialize_descriptor(de)
+ .map(IntoRawDescriptor::into_raw_descriptor)
+ .map(|rd| unsafe { T::from_raw_descriptor(rd) })
+ }
+}
+
+/// A simple wrapper around `File` that implements `Serialize`/`Deserialize`, which is useful when
+/// the `#[serde(with = "with_as_descriptor")]` attribute is infeasible, such as for a field of type
+/// `Option<File>`.
+#[derive(Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct FileSerdeWrapper(#[serde(with = "with_as_descriptor")] pub File);
+
+impl fmt::Debug for FileSerdeWrapper {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ self.0.fmt(f)
+ }
+}
+
+impl From<File> for FileSerdeWrapper {
+ fn from(file: File) -> Self {
+ FileSerdeWrapper(file)
+ }
+}
+
+impl From<FileSerdeWrapper> for File {
+ fn from(f: FileSerdeWrapper) -> File {
+ f.0
+ }
+}
+
+impl Deref for FileSerdeWrapper {
+ type Target = File;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for FileSerdeWrapper {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper,
+ FromRawDescriptor, RawDescriptor, SafeDescriptor, SerializeDescriptors,
+ };
+
+ use std::collections::HashMap;
+ use std::fs::File;
+ use std::mem::ManuallyDrop;
+ use std::os::unix::io::AsRawFd;
+
+ use serde::{de::DeserializeOwned, Deserialize, Serialize};
+ use tempfile::tempfile;
+
+ fn deserialize<T: DeserializeOwned>(json: &str, descriptors: &[RawDescriptor]) -> T {
+ let mut safe_descriptors = descriptors
+ .iter()
+ .map(|&v| Some(unsafe { SafeDescriptor::from_raw_descriptor(v) }))
+ .collect();
+
+ let res =
+ deserialize_with_descriptors(|| serde_json::from_str(json), &mut safe_descriptors)
+ .unwrap();
+
+ assert!(safe_descriptors.iter().all(|v| v.is_none()));
+
+ res
+ }
+
+ #[test]
+ fn raw() {
+ #[derive(Serialize, Deserialize, PartialEq, Debug)]
+ struct RawContainer {
+ #[serde(with = "with_raw_descriptor")]
+ rd: RawDescriptor,
+ }
+ // Specifically chosen to not overlap a real descriptor to avoid having to allocate any
+ // descriptors for this test.
+ let fake_rd = 5_123_457 as _;
+ let v = RawContainer { rd: fake_rd };
+ let v_serialize = SerializeDescriptors::new(&v);
+ let json = serde_json::to_string(&v_serialize).unwrap();
+ let descriptors = v_serialize.into_descriptors();
+ let res = deserialize(&json, &descriptors);
+ assert_eq!(v, res);
+ }
+
+ #[test]
+ fn file() {
+ #[derive(Serialize, Deserialize)]
+ struct FileContainer {
+ #[serde(with = "with_as_descriptor")]
+ file: File,
+ }
+
+ let v = FileContainer {
+ file: tempfile().unwrap(),
+ };
+ let v_serialize = SerializeDescriptors::new(&v);
+ let json = serde_json::to_string(&v_serialize).unwrap();
+ let descriptors = v_serialize.into_descriptors();
+ let v = ManuallyDrop::new(v);
+ let res: FileContainer = deserialize(&json, &descriptors);
+ assert_eq!(v.file.as_raw_fd(), res.file.as_raw_fd());
+ }
+
+ #[test]
+ fn option() {
+ #[derive(Serialize, Deserialize)]
+ struct TestOption {
+ a: Option<FileSerdeWrapper>,
+ b: Option<FileSerdeWrapper>,
+ }
+
+ let v = TestOption {
+ a: None,
+ b: Some(tempfile().unwrap().into()),
+ };
+ let v_serialize = SerializeDescriptors::new(&v);
+ let json = serde_json::to_string(&v_serialize).unwrap();
+ let descriptors = v_serialize.into_descriptors();
+ let v = ManuallyDrop::new(v);
+ let res: TestOption = deserialize(&json, &descriptors);
+ assert!(res.a.is_none());
+ assert!(res.b.is_some());
+ assert_eq!(
+ v.b.as_ref().unwrap().as_raw_fd(),
+ res.b.unwrap().as_raw_fd()
+ );
+ }
+
+ #[test]
+ fn map() {
+ let mut v: HashMap<String, FileSerdeWrapper> = HashMap::new();
+ v.insert("a".into(), tempfile().unwrap().into());
+ v.insert("b".into(), tempfile().unwrap().into());
+ v.insert("c".into(), tempfile().unwrap().into());
+ let v_serialize = SerializeDescriptors::new(&v);
+ let json = serde_json::to_string(&v_serialize).unwrap();
+ let descriptors = v_serialize.into_descriptors();
+ // Prevent the files in `v` from dropping while allowing the HashMap itself to drop. It is
+ // done this way to prevent a double close of the files (which should reside in `res`)
+ // without triggering the leak sanitizer on `v`'s HashMap heap memory.
+ let v: HashMap<_, _> = v
+ .into_iter()
+ .map(|(k, v)| (k, ManuallyDrop::new(v)))
+ .collect();
+ let res: HashMap<String, FileSerdeWrapper> = deserialize(&json, &descriptors);
+
+ assert_eq!(v.len(), res.len());
+ for (k, v) in v.iter() {
+ assert_eq!(res.get(k).unwrap().as_raw_fd(), v.as_raw_fd());
+ }
+ }
+}
diff --git a/sys_util/src/errno.rs b/sys_util/src/errno.rs
index 0475442fb..63af14a04 100644
--- a/sys_util/src/errno.rs
+++ b/sys_util/src/errno.rs
@@ -6,9 +6,12 @@ use std::fmt::{self, Display};
use std::io;
use std::result;
+use serde::{Deserialize, Serialize};
+
/// An error number, retrieved from errno (man 3 errno), set by a libc
/// function that returned an error.
-#[derive(Clone, Copy, Debug, PartialEq)]
+#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq)]
+#[serde(transparent)]
pub struct Error(i32);
pub type Result<T> = result::Result<T, Error>;
diff --git a/sys_util/src/eventfd.rs b/sys_util/src/eventfd.rs
index 3897118e6..29fa34469 100644
--- a/sys_util/src/eventfd.rs
+++ b/sys_util/src/eventfd.rs
@@ -8,18 +8,20 @@ use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use std::ptr;
use std::time::Duration;
-use libc::{c_void, dup, eventfd, read, write, POLLIN};
+use libc::{c_void, eventfd, read, write, POLLIN};
+use serde::{Deserialize, Serialize};
use crate::{
- errno_result, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor, RawDescriptor, Result,
- SafeDescriptor,
+ duration_to_timespec, errno_result, AsRawDescriptor, FromRawDescriptor, IntoRawDescriptor,
+ RawDescriptor, Result, SafeDescriptor,
};
/// A safe wrapper around a Linux eventfd (man 2 eventfd).
///
/// An eventfd is useful because it is sendable across processes and can be used for signaling in
/// and out of the KVM API. They can also be polled like any other file descriptor.
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(transparent)]
pub struct EventFd {
event_handle: SafeDescriptor,
}
@@ -94,12 +96,7 @@ impl EventFd {
events: POLLIN,
revents: 0,
};
- // Safe because we are zero-initializing a struct with only primitive member fields.
- let mut timeoutspec: libc::timespec = unsafe { mem::zeroed() };
- timeoutspec.tv_sec = timeout.as_secs() as libc::time_t;
- // nsec always fits in i32 because subsec_nanos is defined to be less than one billion.
- let nsec = timeout.subsec_nanos() as i32;
- timeoutspec.tv_nsec = libc::c_long::from(nsec);
+ let timeoutspec: libc::timespec = duration_to_timespec(timeout);
// Safe because this only modifies |pfd| and we check the return value
let ret = unsafe {
libc::ppoll(
@@ -137,16 +134,9 @@ impl EventFd {
/// Clones this EventFd, internally creating a new file descriptor. The new EventFd will share
/// the same underlying count within the kernel.
pub fn try_clone(&self) -> Result<EventFd> {
- // This is safe because we made this fd and properly check that it returns without error.
- let ret = unsafe { dup(self.as_raw_descriptor()) };
- if ret < 0 {
- return errno_result();
- }
- // This is safe because we checked ret for success and know the kernel gave us an fd that we
- // own.
- Ok(EventFd {
- event_handle: unsafe { SafeDescriptor::from_raw_descriptor(ret) },
- })
+ self.event_handle
+ .try_clone()
+ .map(|event_handle| EventFd { event_handle })
}
}
diff --git a/sys_util/src/fork.rs b/sys_util/src/fork.rs
index a1e9b132d..7f4a25640 100644
--- a/sys_util/src/fork.rs
+++ b/sys_util/src/fork.rs
@@ -8,8 +8,7 @@ use std::path::Path;
use std::process;
use std::result;
-use libc::{c_long, pid_t, syscall, CLONE_NEWPID, CLONE_NEWUSER, SIGCHLD};
-use syscall_defines::linux::LinuxSyscall::SYS_clone;
+use libc::{c_long, pid_t, syscall, SYS_clone, CLONE_NEWPID, CLONE_NEWUSER, SIGCHLD};
use crate::errno_result;
diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs
index 87af85292..43bdcc4d1 100644
--- a/sys_util/src/lib.rs
+++ b/sys_util/src/lib.rs
@@ -22,6 +22,7 @@ pub mod syslog;
mod capabilities;
mod clock;
mod descriptor;
+mod descriptor_reflection;
mod errno;
mod eventfd;
mod external_mapping;
@@ -33,8 +34,11 @@ pub mod net;
mod passwd;
mod poll;
mod priority;
+pub mod rand;
mod raw_fd;
pub mod sched;
+pub mod scoped_path;
+pub mod scoped_signal_handler;
mod seek_hole;
mod shm;
pub mod signal;
@@ -43,6 +47,7 @@ mod sock_ctrl_msg;
mod struct_util;
mod terminal;
mod timerfd;
+pub mod vsock;
mod write_zeroes;
pub use crate::alloc::LayoutAllocation;
@@ -61,6 +66,7 @@ pub use crate::poll::*;
pub use crate::priority::*;
pub use crate::raw_fd::*;
pub use crate::sched::*;
+pub use crate::scoped_signal_handler::*;
pub use crate::shm::*;
pub use crate::signal::*;
pub use crate::signalfd::*;
@@ -68,6 +74,10 @@ pub use crate::sock_ctrl_msg::*;
pub use crate::struct_util::*;
pub use crate::terminal::*;
pub use crate::timerfd::*;
+pub use descriptor_reflection::{
+ deserialize_with_descriptors, with_as_descriptor, with_raw_descriptor, FileSerdeWrapper,
+ SerializeDescriptors,
+};
pub use poll_token_derive::*;
pub use crate::external_mapping::Error as ExternalMappingError;
@@ -84,16 +94,21 @@ pub use crate::write_zeroes::{PunchHole, WriteZeroes, WriteZeroesAt};
use std::cell::Cell;
use std::ffi::CStr;
use std::fs::{remove_file, File};
+use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixDatagram;
use std::ptr;
+use std::time::Duration;
use libc::{
- c_int, c_long, fcntl, gid_t, kill, pid_t, pipe2, syscall, sysconf, uid_t, waitpid, F_GETFL,
+ c_int, c_long, fcntl, pipe2, syscall, sysconf, waitpid, SYS_getpid, SYS_gettid, F_GETFL,
F_SETFL, O_CLOEXEC, SIGKILL, WNOHANG, _SC_IOV_MAX, _SC_PAGESIZE,
};
-use syscall_defines::linux::LinuxSyscall::SYS_getpid;
+/// Re-export libc types that are part of the API.
+pub type Pid = libc::pid_t;
+pub type Uid = libc::uid_t;
+pub type Gid = libc::gid_t;
/// Used to mark types as !Sync.
pub type UnsyncMarker = std::marker::PhantomData<Cell<usize>>;
@@ -121,28 +136,58 @@ pub fn round_up_to_page_size(v: usize) -> usize {
/// This bypasses `libc`'s caching `getpid(2)` wrapper which can be invalid if a raw clone was used
/// elsewhere.
#[inline(always)]
-pub fn getpid() -> pid_t {
+pub fn getpid() -> Pid {
// Safe because this syscall can never fail and we give it a valid syscall number.
- unsafe { syscall(SYS_getpid as c_long) as pid_t }
+ unsafe { syscall(SYS_getpid as c_long) as Pid }
+}
+
+/// Safe wrapper for the gettid Linux syscall.
+pub fn gettid() -> Pid {
+    // Calling the gettid() syscall is always safe.
+ unsafe { syscall(SYS_gettid as c_long) as Pid }
+}
+
+/// Safe wrapper for `getsid(2)`.
+pub fn getsid(pid: Option<Pid>) -> Result<Pid> {
+    // Calling the getsid() syscall is always safe.
+ let ret = unsafe { libc::getsid(pid.unwrap_or(0)) } as Pid;
+
+ if ret < 0 {
+ errno_result()
+ } else {
+ Ok(ret)
+ }
+}
+
+/// Wrapper for `setsid(2)`.
+pub fn setsid() -> Result<Pid> {
+ // Safe because the return code is checked.
+ let ret = unsafe { libc::setsid() as Pid };
+
+ if ret < 0 {
+ errno_result()
+ } else {
+ Ok(ret)
+ }
}
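Usage sketch (illustrative only): start a new session with the wrappers above and confirm it via `getsid`.

fn new_session() -> Result<Pid> {
    let sid = setsid()?;
    // The caller is now the session leader, so getsid(None) reports the new session id.
    assert_eq!(getsid(None)?, sid);
    Ok(sid)
}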
/// Safe wrapper for `geteuid(2)`.
#[inline(always)]
-pub fn geteuid() -> uid_t {
+pub fn geteuid() -> Uid {
// trivially safe
unsafe { libc::geteuid() }
}
/// Safe wrapper for `getegid(2)`.
#[inline(always)]
-pub fn getegid() -> gid_t {
+pub fn getegid() -> Gid {
// trivially safe
unsafe { libc::getegid() }
}
/// Safe wrapper for chown(2).
#[inline(always)]
-pub fn chown(path: &CStr, uid: uid_t, gid: gid_t) -> Result<()> {
+pub fn chown(path: &CStr, uid: Uid, gid: Gid) -> Result<()> {
// Safe since we pass in a valid string pointer and check the return value.
let ret = unsafe { libc::chown(path.as_ptr(), uid, gid) };
@@ -256,7 +301,7 @@ pub fn fallocate(
/// }
/// }
/// ```
-pub fn reap_child() -> Result<pid_t> {
+pub fn reap_child() -> Result<Pid> {
// Safe because we pass in no memory, prevent blocking with WNOHANG, and check for error.
let ret = unsafe { waitpid(-1, ptr::null_mut(), WNOHANG) };
if ret == -1 {
@@ -271,13 +316,9 @@ pub fn reap_child() -> Result<pid_t> {
/// On success, this kills all processes in the current process group, including the current
/// process, meaning this will not return. This is equivalent to a call to `kill(0, SIGKILL)`.
pub fn kill_process_group() -> Result<()> {
- let ret = unsafe { kill(0, SIGKILL) };
- if ret == -1 {
- errno_result()
- } else {
- // Kill succeeded, so this process never reaches here.
- unreachable!();
- }
+ unsafe { kill(0, SIGKILL) }?;
+ // Kill succeeded, so this process never reaches here.
+ unreachable!();
}
/// Spawns a pipe pair where the first pipe is the read end and the second pipe is the write end.
@@ -435,6 +476,23 @@ pub fn clear_fd_flags(fd: RawFd, clear_flags: c_int) -> Result<()> {
set_fd_flags(fd, start_flags & !clear_flags)
}
+/// Return a timespec filled with the specified Duration `duration`.
+pub fn duration_to_timespec(duration: Duration) -> libc::timespec {
+ // Safe because we are zero-initializing a struct with only primitive member fields.
+ let mut ts: libc::timespec = unsafe { mem::zeroed() };
+
+ ts.tv_sec = duration.as_secs() as libc::time_t;
+ // nsec always fits in i32 because subsec_nanos is defined to be less than one billion.
+ let nsec = duration.subsec_nanos() as i32;
+ ts.tv_nsec = libc::c_long::from(nsec);
+ ts
+}
+
+/// Return the maximum Duration that can be used with libc::timespec.
+pub fn max_timeout() -> Duration {
+ Duration::new(libc::time_t::max_value() as u64, 999999999)
+}
+
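A quick sketch of the conversion (illustrative only): 1.5 seconds splits into whole seconds and nanoseconds.

fn timespec_example() {
    let ts = duration_to_timespec(std::time::Duration::from_millis(1500));
    assert_eq!(ts.tv_sec, 1);
    assert_eq!(ts.tv_nsec, 500_000_000);
}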
#[cfg(test)]
mod tests {
use std::io::Write;
diff --git a/sys_util/src/linux/syslog.rs b/sys_util/src/linux/syslog.rs
index 179e2bf53..1ec876304 100644
--- a/sys_util/src/linux/syslog.rs
+++ b/sys_util/src/linux/syslog.rs
@@ -148,7 +148,7 @@ fn send_buf(socket: &UnixDatagram, buf: &[u8]) {
const SEND_RETRY: usize = 2;
for _ in 0..SEND_RETRY {
- match socket.send(&buf[..]) {
+ match socket.send(buf) {
Ok(_) => break,
Err(e) => match e.kind() {
ErrorKind::ConnectionRefused
diff --git a/sys_util/src/mmap.rs b/sys_util/src/mmap.rs
index e7f1239d3..28ad07f57 100644
--- a/sys_util/src/mmap.rs
+++ b/sys_util/src/mmap.rs
@@ -59,7 +59,7 @@ impl Display for Error {
"requested memory range spans past the end of the region: offset={} count={} region_size={}",
offset, count, region_size,
),
- SystemCallFailed(e) => write!(f, "mmap system call failed: {}", e),
+ SystemCallFailed(e) => write!(f, "mmap related system call failed: {}", e),
ReadToMemory(e) => write!(f, "failed to read from file to memory: {}", e),
RemoveMappingIsUnsupported => write!(f, "`remove_mapping` is unsupported"),
WriteFromMemory(e) => write!(f, "failed to write from memory to file: {}", e),
@@ -108,9 +108,9 @@ impl From<c_int> for Protection {
}
}
-impl Into<c_int> for Protection {
- fn into(self) -> c_int {
- self.0
+impl From<Protection> for c_int {
+ fn from(p: Protection) -> c_int {
+ p.0
}
}
@@ -411,6 +411,32 @@ impl MemoryMapping {
})
}
+ /// Madvise the kernel to use Huge Pages for this mapping.
+ pub fn use_hugepages(&self) -> Result<()> {
+ const SZ_2M: usize = 2 * 1024 * 1024;
+
+ // THP uses 2M pages, so use THP only on mappings that are at least
+ // 2M in size.
+ if self.size() < SZ_2M {
+ return Ok(());
+ }
+
+ // This is safe because we call madvise with a valid address and size, and we check the
+ // return value.
+ let ret = unsafe {
+ libc::madvise(
+ self.as_ptr() as *mut libc::c_void,
+ self.size(),
+ libc::MADV_HUGEPAGE,
+ )
+ };
+ if ret == -1 {
+ Err(Error::SystemCallFailed(errno::Error::last()))
+ } else {
+ Ok(())
+ }
+ }
+
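A sketch of the intended use (assumes `MemoryMapping::new`, provided elsewhere in this module): create an anonymous mapping and opt it into transparent huge pages.

fn map_with_thp(size: usize) -> Result<MemoryMapping> {
    let mm = MemoryMapping::new(size)?;
    // No-op for mappings under 2 MiB; otherwise advises the kernel to use THP for the range.
    mm.use_hugepages()?;
    Ok(mm)
}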
/// Calls msync with MS_SYNC on the mapping.
pub fn msync(&self) -> Result<()> {
// This is safe since we use the exact address and length of a known
@@ -918,8 +944,9 @@ impl Drop for MemoryMappingArena {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::Descriptor;
use data_model::{VolatileMemory, VolatileMemoryError};
- use std::os::unix::io::FromRawFd;
+ use tempfile::tempfile;
#[test]
fn basic_map() {
@@ -939,7 +966,7 @@ mod tests {
#[test]
fn map_invalid_fd() {
- let fd = unsafe { std::fs::File::from_raw_fd(-1) };
+ let fd = Descriptor(-1);
let res = MemoryMapping::from_fd(&fd, 1024).unwrap_err();
if let Error::SystemCallFailed(e) = res {
assert_eq!(e.errno(), libc::EBADF);
@@ -999,7 +1026,7 @@ mod tests {
#[test]
fn from_fd_offset_invalid() {
- let fd = unsafe { std::fs::File::from_raw_fd(-1) };
+ let fd = tempfile().unwrap();
let res = MemoryMapping::from_fd_offset(&fd, 4096, (libc::off_t::max_value() as u64) + 1)
.unwrap_err();
match res {
diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs
index 8aaf04845..b59fb0f21 100644
--- a/sys_util/src/net.rs
+++ b/sys_util/src/net.rs
@@ -5,22 +5,282 @@
use std::ffi::OsString;
use std::fs::remove_file;
use std::io;
-use std::mem;
+use std::mem::{self, size_of};
+use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
use std::ops::Deref;
use std::os::unix::{
ffi::{OsStrExt, OsStringExt},
- io::{AsRawFd, FromRawFd, RawFd},
+ io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
};
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null_mut;
use std::time::Duration;
-use libc::{recvfrom, MSG_PEEK, MSG_TRUNC};
+use libc::{
+ c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
+ socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
+};
+use serde::{Deserialize, Serialize};
use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
use crate::{AsRawDescriptor, RawDescriptor};
+/// Assists in handling both IP version 4 and IP version 6.
+#[derive(Debug, Copy, Clone, PartialEq, Eq)]
+pub enum InetVersion {
+ V4,
+ V6,
+}
+
+impl InetVersion {
+ pub fn from_sockaddr(s: &SocketAddr) -> Self {
+ match s {
+ SocketAddr::V4(_) => InetVersion::V4,
+ SocketAddr::V6(_) => InetVersion::V6,
+ }
+ }
+}
+
+impl From<InetVersion> for sa_family_t {
+ fn from(v: InetVersion) -> sa_family_t {
+ match v {
+ InetVersion::V4 => AF_INET as sa_family_t,
+ InetVersion::V6 => AF_INET6 as sa_family_t,
+ }
+ }
+}
+
+fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
+ sockaddr_in {
+ sin_family: AF_INET as sa_family_t,
+ sin_port: s.port().to_be(),
+ sin_addr: in_addr {
+ s_addr: u32::from_ne_bytes(s.ip().octets()),
+ },
+ sin_zero: [0; 8],
+ }
+}
+
+fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
+ sockaddr_in6 {
+ sin6_family: AF_INET6 as sa_family_t,
+ sin6_port: s.port().to_be(),
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr {
+ s6_addr: s.ip().octets(),
+ },
+ sin6_scope_id: 0,
+ }
+}
+
+/// A TCP socket.
+///
+/// Do not use this type unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
+/// TcpListener.
+#[derive(Debug)]
+pub struct TcpSocket {
+ inet_version: InetVersion,
+ fd: RawFd,
+}
+
+impl TcpSocket {
+ pub fn new(inet_version: InetVersion) -> io::Result<Self> {
+ let fd = unsafe {
+ libc::socket(
+ Into::<sa_family_t>::into(inet_version) as c_int,
+ SOCK_STREAM | SOCK_CLOEXEC,
+ 0,
+ )
+ };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(TcpSocket { inet_version, fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::bind(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+ size_of::<sockaddr_in6>() as socklen_t,
+ )
+ }
+ }
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
+ let sockaddr = addr
+ .to_socket_addrs()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
+ .next()
+ .unwrap();
+
+ let ret = match sockaddr {
+ SocketAddr::V4(a) => {
+ let sin = sockaddrv4_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin as *const sockaddr_in as *const sockaddr,
+ size_of::<sockaddr_in>() as socklen_t,
+ )
+ }
+ }
+ SocketAddr::V6(a) => {
+ let sin6 = sockaddrv6_to_lib_c(&a);
+ // Safe because this doesn't modify any memory and we check the return value.
+ unsafe {
+ libc::connect(
+ self.fd,
+ &sin6 as *const sockaddr_in6 as *const sockaddr,
+                        size_of::<sockaddr_in6>() as socklen_t,
+ )
+ }
+ }
+ };
+
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpStream.
+ Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<TcpListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ Err(listen_err)
+ } else {
+ // Safe because the ownership of the raw fd is released from self and taken over by the
+ // new TcpListener.
+ Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
+ }
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u16> {
+ match self.inet_version {
+ InetVersion::V4 => {
+ let mut sin = sockaddr_in {
+ sin_family: 0,
+ sin_port: 0,
+ sin_addr: in_addr { s_addr: 0 },
+ sin_zero: [0; 8],
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin as *mut sockaddr_in as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
+
+ Ok(u16::from_be(sin.sin_port))
+ }
+ }
+ InetVersion::V6 => {
+ let mut sin6 = sockaddr_in6 {
+ sin6_family: 0,
+ sin6_port: 0,
+ sin6_flowinfo: 0,
+ sin6_addr: in6_addr { s6_addr: [0; 16] },
+ sin6_scope_id: 0,
+ };
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+                    assert_eq!(addrlen as usize, size_of::<sockaddr_in6>());
+
+ Ok(u16::from_be(sin6.sin6_port))
+ }
+ }
+ }
+ }
+}
+
+impl IntoRawFd for TcpSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for TcpSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for TcpSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
+
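A usage sketch (illustrative, not part of this change): bind to an ephemeral IPv4 port, read back the port the kernel chose, then start listening.

fn listen_ephemeral() -> std::io::Result<(std::net::TcpListener, u16)> {
    let mut sock = TcpSocket::new(InetVersion::V4)?;
    sock.bind("127.0.0.1:0")?;
    // local_port() is only meaningful once bind() has succeeded.
    let port = sock.local_port()?;
    Ok((sock.listen()?, port))
}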
// Offset of sun_path in structure sockaddr_un.
fn sun_path_offset() -> usize {
// Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.
@@ -72,7 +332,9 @@ fn sockaddr_un<P: AsRef<Path>>(path: P) -> io::Result<(libc::sockaddr_un, libc::
}
/// A Unix `SOCK_SEQPACKET` socket point to given `path`
+#[derive(Serialize, Deserialize)]
pub struct UnixSeqpacket {
+ #[serde(with = "crate::with_raw_descriptor")]
fd: RawFd,
}
@@ -134,12 +396,12 @@ impl UnixSeqpacket {
/// Clone the underlying FD.
pub fn try_clone(&self) -> io::Result<Self> {
- // Calling `dup` is safe as the kernel doesn't touch any user memory it the process.
- let new_fd = unsafe { libc::dup(self.fd) };
- if new_fd < 0 {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
+ if fd < 0 {
Err(io::Error::last_os_error())
} else {
- Ok(UnixSeqpacket { fd: new_fd })
+ Ok(Self { fd })
}
}
@@ -157,12 +419,21 @@ impl UnixSeqpacket {
/// Gets the number of bytes in the next packet. This blocks as if `recv` were called,
/// respecting the blocking and timeout settings of the underlying socket.
pub fn next_packet_size(&self) -> io::Result<usize> {
+ #[cfg(not(debug_assertions))]
+ let buf = null_mut();
+        // Workaround for qemu's syscall translation, which rejects null pointers in recvfrom.
+ // This only matters for running the unit tests for a non-native architecture. See the
+ // upstream thread for the qemu fix:
+ // https://lists.nongnu.org/archive/html/qemu-devel/2021-03/msg09027.html
+ #[cfg(debug_assertions)]
+ let buf = &mut 0 as *mut _ as *mut _;
+
// This form of recvfrom doesn't modify any data because all null pointers are used. We only
// use the return value and check for errors on an FD owned by this structure.
let ret = unsafe {
recvfrom(
self.fd,
- null_mut(),
+ buf,
0,
MSG_TRUNC | MSG_PEEK,
null_mut(),
diff --git a/sys_util/src/rand.rs b/sys_util/src/rand.rs
new file mode 100644
index 000000000..7c7799397
--- /dev/null
+++ b/sys_util/src/rand.rs
@@ -0,0 +1,114 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//! Rust implementation of functionality parallel to libchrome's base/rand_util.h.
+
+use std::thread::sleep;
+use std::time::Duration;
+
+use libc::{c_uint, c_void};
+
+use crate::{
+ errno::{errno_result, Result},
+ handle_eintr_errno,
+};
+
+/// How long to wait before calling getrandom again if it does not return
+/// enough bytes.
+const POLL_INTERVAL: Duration = Duration::from_millis(50);
+
+/// Represents whether the random bytes are pulled from the source of /dev/random or
+/// /dev/urandom.
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub enum Source {
+ // This is the default and uses the same source as /dev/urandom.
+ Pseudorandom,
+    // This uses the same source as /dev/random and may return fewer bytes than requested when
+    // entropy is limited.
+ Random,
+}
+
+impl Default for Source {
+ fn default() -> Self {
+ Source::Pseudorandom
+ }
+}
+
+impl Source {
+ fn to_getrandom_flags(&self) -> c_uint {
+ match self {
+ Source::Random => libc::GRND_RANDOM,
+ Source::Pseudorandom => 0,
+ }
+ }
+}
+
+/// Fills `output` completely with random bytes from the specified `source`.
+pub fn rand_bytes(mut output: &mut [u8], source: Source) -> Result<()> {
+ if output.is_empty() {
+ return Ok(());
+ }
+
+ loop {
+ // Safe because output is mutable and the writes are limited by output.len().
+ let bytes = handle_eintr_errno!(unsafe {
+ libc::getrandom(
+ output.as_mut_ptr() as *mut c_void,
+ output.len(),
+ source.to_getrandom_flags(),
+ )
+ });
+
+ if bytes < 0 {
+ return errno_result();
+ }
+ if bytes as usize == output.len() {
+ return Ok(());
+ }
+
+ // Wait for more entropy and try again for the remaining bytes.
+ sleep(POLL_INTERVAL);
+ output = &mut output[bytes as usize..];
+ }
+}
+
+/// Allocates a vector of length `len` filled with random bytes from the
+/// specified `source`.
+pub fn rand_vec(len: usize, source: Source) -> Result<Vec<u8>> {
+ let mut rand = Vec::with_capacity(len);
+ if len == 0 {
+ return Ok(rand);
+ }
+
+ // Safe because rand will either be initialized by getrandom or dropped.
+ unsafe { rand.set_len(len) };
+ rand_bytes(rand.as_mut_slice(), source)?;
+ Ok(rand)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ const TEST_SIZE: usize = 64;
+
+ #[test]
+ fn randbytes_success() {
+ let mut rand = vec![0u8; TEST_SIZE];
+ rand_bytes(&mut rand, Source::Pseudorandom).unwrap();
+ assert_ne!(&rand, &[0u8; TEST_SIZE]);
+ }
+
+ #[test]
+ fn randvec_success() {
+ let rand = rand_vec(TEST_SIZE, Source::Pseudorandom).unwrap();
+ assert_eq!(rand.len(), TEST_SIZE);
+ assert_ne!(&rand, &[0u8; TEST_SIZE]);
+ }
+
+ #[test]
+ fn sourcerandom_success() {
+ let rand = rand_vec(TEST_SIZE, Source::Random).unwrap();
+ assert_ne!(&rand, &[0u8; TEST_SIZE]);
+ }
+}
diff --git a/sys_util/src/scoped_path.rs b/sys_util/src/scoped_path.rs
new file mode 100644
index 000000000..745137eae
--- /dev/null
+++ b/sys_util/src/scoped_path.rs
@@ -0,0 +1,138 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+use std::env::{current_exe, temp_dir};
+use std::fs::{create_dir_all, remove_dir_all};
+use std::io::Result;
+use std::ops::Deref;
+use std::path::{Path, PathBuf};
+use std::thread::panicking;
+
+use crate::{getpid, gettid};
+
+/// Returns a stable path based on the label, pid, and tid. If the label isn't provided, the
+/// file name of the current executable is used instead.
+pub fn get_temp_path(label: Option<&str>) -> PathBuf {
+ if let Some(label) = label {
+ temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid()))
+ } else {
+ get_temp_path(Some(
+ current_exe()
+ .unwrap()
+ .file_name()
+ .unwrap()
+ .to_str()
+ .unwrap(),
+ ))
+ }
+}
+
+/// Automatically deletes the path it contains when it goes out of scope, unless it is dropped
+/// while panicking in a test build.
+///
+/// This is particularly useful for creating temporary directories for use with tests.
+pub struct ScopedPath<P: AsRef<Path>>(P);
+
+impl<P: AsRef<Path>> ScopedPath<P> {
+ pub fn create(p: P) -> Result<Self> {
+ create_dir_all(p.as_ref())?;
+ Ok(ScopedPath(p))
+ }
+}
+
+impl<P: AsRef<Path>> AsRef<Path> for ScopedPath<P> {
+ fn as_ref(&self) -> &Path {
+ self.0.as_ref()
+ }
+}
+
+impl<P: AsRef<Path>> Deref for ScopedPath<P> {
+ type Target = Path;
+
+ fn deref(&self) -> &Self::Target {
+ self.0.as_ref()
+ }
+}
+
+impl<P: AsRef<Path>> Drop for ScopedPath<P> {
+ fn drop(&mut self) {
+ // Leave the files on a failed test run for debugging.
+ if panicking() && cfg!(test) {
+ eprintln!("NOTE: Not removing {}", self.display());
+ return;
+ }
+ if let Err(e) = remove_dir_all(&**self) {
+ eprintln!("Failed to remove {}: {}", self.display(), e);
+ }
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+
+ use std::panic::catch_unwind;
+
+ #[test]
+ fn gettemppath() {
+ assert_ne!("", get_temp_path(None).to_string_lossy());
+ assert!(get_temp_path(None).starts_with(temp_dir()));
+ assert_eq!(
+ get_temp_path(None),
+ get_temp_path(Some(
+ current_exe()
+ .unwrap()
+ .file_name()
+ .unwrap()
+ .to_str()
+ .unwrap()
+ ))
+ );
+ assert_ne!(
+ get_temp_path(Some("label")),
+ get_temp_path(Some(
+ current_exe()
+ .unwrap()
+ .file_name()
+ .unwrap()
+ .to_str()
+ .unwrap()
+ ))
+ );
+ }
+
+ #[test]
+ fn scopedpath_exists() {
+ let tmp_path = get_temp_path(None);
+ {
+ let scoped_path = ScopedPath::create(&tmp_path).unwrap();
+ assert!(scoped_path.exists());
+ }
+ assert!(!tmp_path.exists());
+ }
+
+ #[test]
+ fn scopedpath_notexists() {
+ let tmp_path = get_temp_path(None);
+ {
+ let _scoped_path = ScopedPath(&tmp_path);
+ }
+ assert!(!tmp_path.exists());
+ }
+
+ #[test]
+ fn scopedpath_panic() {
+ let tmp_path = get_temp_path(None);
+ assert!(catch_unwind(|| {
+ {
+ let scoped_path = ScopedPath::create(&tmp_path).unwrap();
+ assert!(scoped_path.exists());
+ panic!()
+ }
+ })
+ .is_err());
+ assert!(tmp_path.exists());
+ remove_dir_all(&tmp_path).unwrap();
+ }
+}
diff --git a/sys_util/src/scoped_signal_handler.rs b/sys_util/src/scoped_signal_handler.rs
new file mode 100644
index 000000000..76286161c
--- /dev/null
+++ b/sys_util/src/scoped_signal_handler.rs
@@ -0,0 +1,421 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//! Provides a struct for registering signal handlers that get cleared on drop.
+
+use std::convert::TryFrom;
+use std::fmt::{self, Display};
+use std::io::{Cursor, Write};
+use std::panic::catch_unwind;
+use std::result;
+
+use libc::{c_int, c_void, STDERR_FILENO};
+
+use crate::errno;
+use crate::signal::{
+ clear_signal_handler, has_default_signal_handler, register_signal_handler, wait_for_signal,
+ Signal,
+};
+
+#[derive(Debug)]
+pub enum Error {
+ /// Sigaction failed.
+ Sigaction(Signal, errno::Error),
+ /// Failed to check if signal has the default signal handler.
+ HasDefaultSignalHandler(Signal, errno::Error),
+ /// Failed to register a signal handler.
+ RegisterSignalHandler(Signal, errno::Error),
+ /// Signal already has a handler.
+ HandlerAlreadySet(Signal),
+ /// Already waiting for interrupt.
+ AlreadyWaiting,
+ /// Failed to wait for signal.
+ WaitForSignal(errno::Error),
+}
+
+impl Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::Error::*;
+
+ match self {
+ Sigaction(s, e) => write!(f, "sigaction failed for {0:?}: {1}", s, e),
+ HasDefaultSignalHandler(s, e) => {
+ write!(f, "failed to check the signal handler for {0:?}: {1}", s, e)
+ }
+ RegisterSignalHandler(s, e) => write!(
+ f,
+ "failed to register a signal handler for {0:?}: {1}",
+ s, e
+ ),
+ HandlerAlreadySet(s) => write!(f, "signal handler already set for {0:?}", s),
+ AlreadyWaiting => write!(f, "already waiting for interrupt."),
+ WaitForSignal(e) => write!(f, "wait_for_signal failed: {0}", e),
+ }
+ }
+}
+
+pub type Result<T> = result::Result<T, Error>;
+
+/// The interface used by Scoped Signal handler.
+///
+/// # Safety
+/// The implementation of handle_signal needs to be async signal-safe.
+///
+/// NOTE: panics are caught when possible because unwinding across an FFI boundary is undefined behavior.
+pub unsafe trait SignalHandler {
+ /// A function that is called to handle the passed signal.
+ fn handle_signal(signal: Signal);
+}
+
+/// Wrap the handler with an extern "C" function.
+extern "C" fn call_handler<H: SignalHandler>(signum: c_int) {
+ // Make an effort to surface an error.
+ if catch_unwind(|| H::handle_signal(Signal::try_from(signum).unwrap())).is_err() {
+ // Note the following cannot be used:
+ // eprintln! - uses std::io which has locks that may be held.
+ // format! - uses the allocator which enforces mutual exclusion.
+
+ // Get the debug representation of signum.
+ let signal: Signal;
+ let signal_debug: &dyn fmt::Debug = match Signal::try_from(signum) {
+ Ok(s) => {
+ signal = s;
+ &signal as &dyn fmt::Debug
+ }
+ Err(_) => &signum as &dyn fmt::Debug,
+ };
+
+ // Buffer the output, so a single call to write can be used.
+ // The message accounts for 29 chars, which leaves 35 for the string representation of the
+ // signal which is more than enough.
+ let mut buffer = [0u8; 64];
+ let mut cursor = Cursor::new(buffer.as_mut());
+ if writeln!(cursor, "signal handler got error for: {:?}", signal_debug).is_ok() {
+ let len = cursor.position() as usize;
+ // Safe in the sense that buffer is owned and the length is checked. This may print in
+ // the middle of an existing write, but that is considered better than dropping the
+ // error.
+ unsafe {
+ libc::write(
+ STDERR_FILENO,
+ cursor.get_ref().as_ptr() as *const c_void,
+ len,
+ )
+ };
+ } else {
+ // This should never happen, but write an error message just in case.
+ const ERROR_DROPPED: &str = "Error dropped by signal handler.";
+ let bytes = ERROR_DROPPED.as_bytes();
+ unsafe { libc::write(STDERR_FILENO, bytes.as_ptr() as *const c_void, bytes.len()) };
+ }
+ }
+}
+
+/// Represents a signal handler that is registered for a set of signals and unregistered when the
+/// struct goes out of scope. Prefer a signalfd based solution before using this.
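+///
+/// A brief usage sketch (assumes `MyHandler` is a type implementing `SignalHandler`):
+///
+/// ```ignore
+/// let handler = ScopedSignalHandler::new::<MyHandler>(&[Signal::User1]).unwrap();
+/// // SIGUSR1 is dispatched to MyHandler::handle_signal while `handler` is alive.
+/// drop(handler); // The default SIGUSR1 disposition is restored here.
+/// ```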
+pub struct ScopedSignalHandler {
+ signals: Vec<Signal>,
+}
+
+impl ScopedSignalHandler {
+ /// Attempts to register `handler` with the provided `signals`. It will fail if there is already
+ /// an existing handler on any of `signals`.
+ ///
+ /// # Safety
+ /// This is safe if H::handle_signal is async-signal safe.
+ pub fn new<H: SignalHandler>(signals: &[Signal]) -> Result<Self> {
+ let mut scoped_handler = ScopedSignalHandler {
+ signals: Vec::with_capacity(signals.len()),
+ };
+ for &signal in signals {
+ if !has_default_signal_handler((signal).into())
+ .map_err(|err| Error::HasDefaultSignalHandler(signal, err))?
+ {
+ return Err(Error::HandlerAlreadySet(signal));
+ }
+ // Requires an async-safe callback.
+ unsafe {
+ register_signal_handler((signal).into(), call_handler::<H>)
+ .map_err(|err| Error::RegisterSignalHandler(signal, err))?
+ };
+ scoped_handler.signals.push(signal);
+ }
+ Ok(scoped_handler)
+ }
+}
+
+/// Clears the signal handler for any of the associated signals.
+impl Drop for ScopedSignalHandler {
+ fn drop(&mut self) {
+ for signal in &self.signals {
+ if let Err(err) = clear_signal_handler((*signal).into()) {
+ eprintln!("Error: failed to clear signal handler: {:?}", err);
+ }
+ }
+ }
+}
+
+/// A signal handler that does nothing.
+///
+/// This is useful in cases where wait_for_signal is used since it will never trigger if the signal
+/// is blocked and the default handler may have undesired effects like terminating the process.
+pub struct EmptySignalHandler;
+/// # Safety
+/// Safe because handle_signal is async-signal safe.
+unsafe impl SignalHandler for EmptySignalHandler {
+ fn handle_signal(_: Signal) {}
+}
+
+/// Blocks until SIGINT is received, which often happens because Ctrl-C was pressed in an
+/// interactive terminal.
+///
+/// Note: if you are using a multi-threaded application you need to block SIGINT on all other
+/// threads or they may receive the signal instead of the desired thread.
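+///
+/// A brief usage sketch:
+///
+/// ```ignore
+/// println!("Press Ctrl-C to exit...");
+/// wait_for_interrupt().unwrap();
+/// ```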
+pub fn wait_for_interrupt() -> Result<()> {
+ // Register a signal handler if there is not one already so the thread is not killed.
+ let ret = ScopedSignalHandler::new::<EmptySignalHandler>(&[Signal::Interrupt]);
+ if !matches!(&ret, Ok(_) | Err(Error::HandlerAlreadySet(_))) {
+ ret?;
+ }
+
+ match wait_for_signal(&[Signal::Interrupt.into()], None) {
+ Ok(_) => Ok(()),
+ Err(err) => Err(Error::WaitForSignal(err)),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use std::fs::File;
+ use std::io::{BufRead, BufReader};
+ use std::mem::zeroed;
+ use std::ptr::{null, null_mut};
+ use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering};
+ use std::sync::{Arc, Mutex, MutexGuard, Once};
+ use std::thread::{sleep, spawn};
+ use std::time::{Duration, Instant};
+
+ use libc::sigaction;
+
+ use crate::{gettid, kill, Pid};
+
+ const TEST_SIGNAL: Signal = Signal::User1;
+ const TEST_SIGNALS: &[Signal] = &[Signal::User1, Signal::User2];
+
+ static TEST_SIGNAL_COUNTER: AtomicUsize = AtomicUsize::new(0);
+
+ /// Only allows one test case to execute at a time.
+ fn get_mutex() -> MutexGuard<'static, ()> {
+ static INIT: Once = Once::new();
+ static mut VAL: Option<Arc<Mutex<()>>> = None;
+
+ INIT.call_once(|| {
+ let val = Some(Arc::new(Mutex::new(())));
+ // Safe because the mutation is protected by the Once.
+ unsafe { VAL = val }
+ });
+
+ // Safe mutation only happens in the Once.
+ unsafe { VAL.as_ref() }.unwrap().lock().unwrap()
+ }
+
+ fn reset_counter() {
+ TEST_SIGNAL_COUNTER.swap(0, Ordering::SeqCst);
+ }
+
+ fn get_sigaction(signal: Signal) -> Result<sigaction> {
+ // Safe because sigaction is owned and expected to be initialized to zeros.
+ let mut sigact: sigaction = unsafe { zeroed() };
+
+ if unsafe { sigaction(signal.into(), null(), &mut sigact) } < 0 {
+ Err(Error::Sigaction(signal, errno::Error::last()))
+ } else {
+ Ok(sigact)
+ }
+ }
+
+ /// Safety:
+ /// This is only safe if the signal handler set in sigaction is safe.
+ unsafe fn restore_sigaction(signal: Signal, sigact: sigaction) -> Result<sigaction> {
+ if sigaction(signal.into(), &sigact, null_mut()) < 0 {
+ Err(Error::Sigaction(signal, errno::Error::last()))
+ } else {
+ Ok(sigact)
+ }
+ }
+
+ /// Safety:
+ /// Safe if the signal handler for Signal::User1 is safe.
+ unsafe fn send_test_signal() {
+ kill(gettid(), Signal::User1.into()).unwrap()
+ }
+
+ macro_rules! assert_counter_eq {
+ ($compare_to:expr) => {{
+ let expected: usize = $compare_to;
+ let got: usize = TEST_SIGNAL_COUNTER.load(Ordering::SeqCst);
+ if got != expected {
+ panic!(
+ "wrong signal counter value: got {}; expected {}",
+ got, expected
+ );
+ }
+ }};
+ }
+
+ struct TestHandler;
+
+ /// # Safety
+ /// Safe because handle_signal is async-signal safe.
+ unsafe impl SignalHandler for TestHandler {
+ fn handle_signal(signal: Signal) {
+ if TEST_SIGNAL == signal {
+ TEST_SIGNAL_COUNTER.fetch_add(1, Ordering::SeqCst);
+ }
+ }
+ }
+
+ #[test]
+ fn scopedsignalhandler_success() {
+ // Prevent other test cases from running concurrently since the signal
+ // handlers are shared for the process.
+ let _guard = get_mutex();
+
+ reset_counter();
+ assert_counter_eq!(0);
+
+ assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+ let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
+ assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+
+ // Safe because test_handler is safe.
+ unsafe { send_test_signal() };
+
+ // Give the handler time to run in case it is on a different thread.
+ for _ in 1..40 {
+ if TEST_SIGNAL_COUNTER.load(Ordering::SeqCst) > 0 {
+ break;
+ }
+ sleep(Duration::from_millis(250));
+ }
+
+ assert_counter_eq!(1);
+
+ drop(handler);
+ assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+ }
+
+ #[test]
+ fn scopedsignalhandler_handleralreadyset() {
+ // Prevent other test cases from running concurrently since the signal
+ // handlers are shared for the process.
+ let _guard = get_mutex();
+
+ reset_counter();
+ assert_counter_eq!(0);
+
+ assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+ // Safe because TestHandler is async-signal safe.
+ let handler = ScopedSignalHandler::new::<TestHandler>(&[TEST_SIGNAL]).unwrap();
+ assert!(!has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+
+ // Safe because TestHandler is async-signal safe.
+ assert!(matches!(
+ ScopedSignalHandler::new::<TestHandler>(&TEST_SIGNALS),
+ Err(Error::HandlerAlreadySet(Signal::User1))
+ ));
+
+ assert_counter_eq!(0);
+ drop(handler);
+ assert!(has_default_signal_handler(TEST_SIGNAL.into()).unwrap());
+ }
+
+ /// Stores the ID of the thread used by WaitForInterruptHandler.
+ static WAIT_FOR_INTERRUPT_THREAD_ID: AtomicI32 = AtomicI32::new(0);
+ /// Forwards SIGINT to the appropriate thread.
+ struct WaitForInterruptHandler;
+
+ /// # Safety
+ /// Safe because handle_signal is async-signal safe.
+ unsafe impl SignalHandler for WaitForInterruptHandler {
+ fn handle_signal(_: Signal) {
+ let tid = WAIT_FOR_INTERRUPT_THREAD_ID.load(Ordering::SeqCst);
+ // If the thread ID is set and the handler is executing on the wrong thread, forward the signal.
+ if tid != 0 && gettid() != tid {
+ // Safe because the handler is safe and the target thread id is expecting the signal.
+ unsafe { kill(tid, Signal::Interrupt.into()) }.unwrap();
+ }
+ }
+ }
+
+ /// Query /proc/${tid}/status for its State and check if it is either S (sleeping) or
+ /// D (disk sleep).
+ fn thread_is_sleeping(tid: Pid) -> result::Result<bool, errno::Error> {
+ const PREFIX: &str = "State:";
+ let mut status_reader = BufReader::new(File::open(format!("/proc/{}/status", tid))?);
+ let mut line = String::new();
+ loop {
+ let count = status_reader.read_line(&mut line)?;
+ if count == 0 {
+ return Err(errno::Error::new(libc::EIO));
+ }
+ if line.starts_with(PREFIX) {
+ return Ok(matches!(
+ line[PREFIX.len()..].trim_start().chars().next(),
+ Some('S') | Some('D')
+ ));
+ }
+ line.clear();
+ }
+ }
+
+ /// Wait for a thread to block in either a sleeping or disk sleep state.
+ fn wait_for_thread_to_sleep(tid: Pid, timeout: Duration) -> result::Result<(), errno::Error> {
+ let start = Instant::now();
+ loop {
+ if thread_is_sleeping(tid)? {
+ return Ok(());
+ }
+ if start.elapsed() > timeout {
+ return Err(errno::Error::new(libc::EAGAIN));
+ }
+ sleep(Duration::from_millis(50));
+ }
+ }
+
+ #[test]
+ fn waitforinterrupt_success() {
+ // Prevent other test cases from running concurrently since the signal
+ // handlers are shared for the process.
+ let _guard = get_mutex();
+
+ let to_restore = get_sigaction(Signal::Interrupt).unwrap();
+ clear_signal_handler(Signal::Interrupt.into()).unwrap();
+ // Safe because TestHandler is async-signal safe.
+ let handler =
+ ScopedSignalHandler::new::<WaitForInterruptHandler>(&[Signal::Interrupt]).unwrap();
+
+ let tid = gettid();
+ WAIT_FOR_INTERRUPT_THREAD_ID.store(tid, Ordering::SeqCst);
+
+ let join_handle = spawn(move || -> result::Result<(), errno::Error> {
+ // Wait until the thread is ready to receive the signal.
+ wait_for_thread_to_sleep(tid, Duration::from_secs(10)).unwrap();
+
+ // Safe because the SIGINT handler is safe.
+ unsafe { kill(tid, Signal::Interrupt.into()) }
+ });
+ let wait_ret = wait_for_interrupt();
+ let join_ret = join_handle.join();
+
+ drop(handler);
+ // Safe because we are restoring the previous SIGINT handler.
+ unsafe { restore_sigaction(Signal::Interrupt, to_restore) }.unwrap();
+
+ wait_ret.unwrap();
+ join_ret.unwrap().unwrap();
+ }
+}
diff --git a/sys_util/src/shm.rs b/sys_util/src/shm.rs
index d9a583990..927734aef 100644
--- a/sys_util/src/shm.rs
+++ b/sys_util/src/shm.rs
@@ -5,19 +5,21 @@
use std::ffi::{CStr, CString};
use std::fs::{read_link, File};
use std::io::{self, Read, Seek, SeekFrom, Write};
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
use libc::{
- self, c_char, c_int, c_long, c_uint, close, fcntl, ftruncate64, off64_t, syscall, EINVAL,
- F_ADD_SEALS, F_GET_SEALS, F_SEAL_GROW, F_SEAL_SEAL, F_SEAL_SHRINK, F_SEAL_WRITE,
- MFD_ALLOW_SEALING,
+ self, c_char, c_int, c_long, c_uint, close, fcntl, ftruncate64, off64_t, syscall,
+ SYS_memfd_create, EINVAL, F_ADD_SEALS, F_GET_SEALS, F_SEAL_GROW, F_SEAL_SEAL, F_SEAL_SHRINK,
+ F_SEAL_WRITE, MFD_ALLOW_SEALING,
};
-use syscall_defines::linux::LinuxSyscall::SYS_memfd_create;
+use serde::{Deserialize, Serialize};
use crate::{errno, errno_result, Result};
/// A shared memory file descriptor and its size.
+#[derive(Serialize, Deserialize)]
pub struct SharedMemory {
+ #[serde(with = "crate::with_as_descriptor")]
fd: File,
size: u64,
}
@@ -267,9 +269,15 @@ impl AsRawFd for &SharedMemory {
}
}
-impl Into<File> for SharedMemory {
- fn into(self) -> File {
- self.fd
+impl IntoRawFd for SharedMemory {
+ fn into_raw_fd(self) -> RawFd {
+ self.fd.into_raw_fd()
+ }
+}
+
+impl From<SharedMemory> for File {
+ fn from(s: SharedMemory) -> File {
+ s.fd
}
}
diff --git a/sys_util/src/signal.rs b/sys_util/src/signal.rs
index 8f5bc6a92..1191dccb5 100644
--- a/sys_util/src/signal.rs
+++ b/sys_util/src/signal.rs
@@ -4,20 +4,27 @@
use libc::{
c_int, pthread_kill, pthread_sigmask, pthread_t, sigaction, sigaddset, sigemptyset, siginfo_t,
- sigismember, sigpending, sigset_t, sigtimedwait, timespec, EAGAIN, EINTR, EINVAL, SA_RESTART,
- SIG_BLOCK, SIG_UNBLOCK,
+ sigismember, sigpending, sigset_t, sigtimedwait, sigwait, timespec, waitpid, EAGAIN, EINTR,
+ EINVAL, SA_RESTART, SIG_BLOCK, SIG_DFL, SIG_UNBLOCK, WNOHANG,
};
use std::cmp::Ordering;
+use std::convert::TryFrom;
use std::fmt::{self, Display};
use std::io;
use std::mem;
use std::os::unix::thread::JoinHandleExt;
+use std::process::Child;
use std::ptr::{null, null_mut};
use std::result;
use std::thread::JoinHandle;
+use std::time::{Duration, Instant};
-use crate::{errno, errno_result};
+use crate::{duration_to_timespec, errno, errno_result, getsid, Pid, Result};
+use std::ops::{Deref, DerefMut};
+
+const POLL_RATE: Duration = Duration::from_millis(50);
+const DEFAULT_KILL_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Debug)]
pub enum Error {
@@ -39,6 +46,20 @@ pub enum Error {
ClearGetPending(errno::Error),
/// Failed to check if given signal is in the set of pending signals.
ClearCheckPending(errno::Error),
+ /// Failed to send signal to pid.
+ Kill(errno::Error),
+ /// Failed to get session id.
+ GetSid(errno::Error),
+ /// Failed to wait for signal.
+ WaitForSignal(errno::Error),
+ /// Failed to wait for pid.
+ WaitPid(errno::Error),
+ /// Timeout reached.
+ TimedOut,
+ /// Failed to convert signum to Signal.
+ UnrecognizedSignum(i32),
+ /// Converted signum greater than SIGRTMAX.
+ RtSignumGreaterThanMax(Signal),
}
impl Display for Error {
@@ -67,7 +88,184 @@ impl Display for Error {
"failed to check whether given signal is in the pending set: {}",
e,
),
+ Kill(e) => write!(f, "failed to send signal: {}", e),
+ GetSid(e) => write!(f, "failed to get session id: {}", e),
+ WaitForSignal(e) => write!(f, "failed to wait for signal: {}", e),
+ WaitPid(e) => write!(f, "failed to wait for process: {}", e),
+ TimedOut => write!(f, "timeout reached."),
+ UnrecognizedSignum(signum) => write!(f, "unrecognized signal number: {}", signum),
+ RtSignumGreaterThanMax(signal) => {
+ write!(f, "got RT signal greater than max: {:?}", signal)
+ }
+ }
+ }
+}
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+#[repr(i32)]
+pub enum Signal {
+ Abort = libc::SIGABRT,
+ Alarm = libc::SIGALRM,
+ Bus = libc::SIGBUS,
+ Child = libc::SIGCHLD,
+ Continue = libc::SIGCONT,
+ ExceededFileSize = libc::SIGXFSZ,
+ FloatingPointException = libc::SIGFPE,
+ HangUp = libc::SIGHUP,
+ IllegalInstruction = libc::SIGILL,
+ Interrupt = libc::SIGINT,
+ Io = libc::SIGIO,
+ Kill = libc::SIGKILL,
+ Pipe = libc::SIGPIPE,
+ Power = libc::SIGPWR,
+ Profile = libc::SIGPROF,
+ Quit = libc::SIGQUIT,
+ SegmentationViolation = libc::SIGSEGV,
+ StackFault = libc::SIGSTKFLT,
+ Stop = libc::SIGSTOP,
+ Sys = libc::SIGSYS,
+ Trap = libc::SIGTRAP,
+ Terminate = libc::SIGTERM,
+ TtyIn = libc::SIGTTIN,
+ TtyOut = libc::SIGTTOU,
+ TtyStop = libc::SIGTSTP,
+ Urgent = libc::SIGURG,
+ User1 = libc::SIGUSR1,
+ User2 = libc::SIGUSR2,
+ VtAlarm = libc::SIGVTALRM,
+ Winch = libc::SIGWINCH,
+ Xcpu = libc::SIGXCPU,
+ // Rt signal numbers are adjusted in the conversion to integer.
+ Rt0 = libc::SIGSYS + 1,
+ Rt1,
+ Rt2,
+ Rt3,
+ Rt4,
+ Rt5,
+ Rt6,
+ Rt7,
+ // Only 8 are guaranteed by POSIX, Linux has 32, but only 29 or 30 are usable.
+ Rt8,
+ Rt9,
+ Rt10,
+ Rt11,
+ Rt12,
+ Rt13,
+ Rt14,
+ Rt15,
+ Rt16,
+ Rt17,
+ Rt18,
+ Rt19,
+ Rt20,
+ Rt21,
+ Rt22,
+ Rt23,
+ Rt24,
+ Rt25,
+ Rt26,
+ Rt27,
+ Rt28,
+ Rt29,
+ Rt30,
+ Rt31,
+}
+
+impl From<Signal> for c_int {
+ fn from(signal: Signal) -> c_int {
+ let num = signal as libc::c_int;
+ if num >= Signal::Rt0 as libc::c_int {
+ return num - (Signal::Rt0 as libc::c_int) + SIGRTMIN();
}
+ num
+ }
+}
+
+impl TryFrom<c_int> for Signal {
+ type Error = Error;
+
+ fn try_from(value: c_int) -> result::Result<Self, Self::Error> {
+ use Signal::*;
+
+ Ok(match value {
+ libc::SIGABRT => Abort,
+ libc::SIGALRM => Alarm,
+ libc::SIGBUS => Bus,
+ libc::SIGCHLD => Child,
+ libc::SIGCONT => Continue,
+ libc::SIGXFSZ => ExceededFileSize,
+ libc::SIGFPE => FloatingPointException,
+ libc::SIGHUP => HangUp,
+ libc::SIGILL => IllegalInstruction,
+ libc::SIGINT => Interrupt,
+ libc::SIGIO => Io,
+ libc::SIGKILL => Kill,
+ libc::SIGPIPE => Pipe,
+ libc::SIGPWR => Power,
+ libc::SIGPROF => Profile,
+ libc::SIGQUIT => Quit,
+ libc::SIGSEGV => SegmentationViolation,
+ libc::SIGSTKFLT => StackFault,
+ libc::SIGSTOP => Stop,
+ libc::SIGSYS => Sys,
+ libc::SIGTRAP => Trap,
+ libc::SIGTERM => Terminate,
+ libc::SIGTTIN => TtyIn,
+ libc::SIGTTOU => TtyOut,
+ libc::SIGTSTP => TtyStop,
+ libc::SIGURG => Urgent,
+ libc::SIGUSR1 => User1,
+ libc::SIGUSR2 => User2,
+ libc::SIGVTALRM => VtAlarm,
+ libc::SIGWINCH => Winch,
+ libc::SIGXCPU => Xcpu,
+ _ => {
+ if value < SIGRTMIN() {
+ return Err(Error::UnrecognizedSignum(value));
+ }
+ let signal = match value - SIGRTMIN() {
+ 0 => Rt0,
+ 1 => Rt1,
+ 2 => Rt2,
+ 3 => Rt3,
+ 4 => Rt4,
+ 5 => Rt5,
+ 6 => Rt6,
+ 7 => Rt7,
+ 8 => Rt8,
+ 9 => Rt9,
+ 10 => Rt10,
+ 11 => Rt11,
+ 12 => Rt12,
+ 13 => Rt13,
+ 14 => Rt14,
+ 15 => Rt15,
+ 16 => Rt16,
+ 17 => Rt17,
+ 18 => Rt18,
+ 19 => Rt19,
+ 20 => Rt20,
+ 21 => Rt21,
+ 22 => Rt22,
+ 23 => Rt23,
+ 24 => Rt24,
+ 25 => Rt25,
+ 26 => Rt26,
+ 27 => Rt27,
+ 28 => Rt28,
+ 29 => Rt29,
+ 30 => Rt30,
+ 31 => Rt31,
+ _ => {
+ return Err(Error::UnrecognizedSignum(value));
+ }
+ };
+ if value > SIGRTMAX() {
+ return Err(Error::RtSignumGreaterThanMax(signal));
+ }
+ signal
+ }
+ })
}
}
@@ -101,7 +299,10 @@ fn valid_rt_signal_num(num: c_int) -> bool {
///
/// This is considered unsafe because the given handler will be called asynchronously, interrupting
/// whatever the thread was doing and therefore must only do async-signal-safe operations.
-pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> errno::Result<()> {
+pub unsafe fn register_signal_handler(
+ num: c_int,
+ handler: extern "C" fn(c_int),
+) -> errno::Result<()> {
let mut sigact: sigaction = mem::zeroed();
sigact.sa_flags = SA_RESTART;
sigact.sa_sigaction = handler as *const () as usize;
@@ -114,6 +315,36 @@ pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> e
Ok(())
}
+/// Resets the signal handler of signum `num` back to the default.
+pub fn clear_signal_handler(num: c_int) -> errno::Result<()> {
+ // Safe because sigaction is owned and expected to be initialized to zeros.
+ let mut sigact: sigaction = unsafe { mem::zeroed() };
+ sigact.sa_flags = SA_RESTART;
+ sigact.sa_sigaction = SIG_DFL;
+
+ // Safe because sigact is owned, and this is restoring the default signal handler.
+ let ret = unsafe { sigaction(num, &sigact, null_mut()) };
+ if ret < 0 {
+ return errno_result();
+ }
+
+ Ok(())
+}
+
+/// Returns true if the signal handler for signum `num` is the default.
+pub fn has_default_signal_handler(num: c_int) -> errno::Result<bool> {
+ // Safe because sigaction is owned and expected to be initialized to zeros.
+ let mut sigact: sigaction = unsafe { mem::zeroed() };
+
+ // Safe because sigact is owned, and this is just querying the existing state.
+ let ret = unsafe { sigaction(num, null(), &mut sigact) };
+ if ret < 0 {
+ return errno_result();
+ }
+
+ Ok(sigact.sa_sigaction == SIG_DFL)
+}
+
/// Registers `handler` as the signal handler for the real-time signal with signum `num`.
///
/// The value of `num` must be within [`SIGRTMIN`, `SIGRTMAX`] range.
@@ -124,7 +355,7 @@ pub unsafe fn register_signal_handler(num: c_int, handler: extern "C" fn()) -> e
/// whatever the thread was doing and therefore must only do async-signal-safe operations.
pub unsafe fn register_rt_signal_handler(
num: c_int,
- handler: extern "C" fn(),
+ handler: extern "C" fn(c_int),
) -> errno::Result<()> {
if !valid_rt_signal_num(num) {
return Err(errno::Error::new(EINVAL));
@@ -157,6 +388,34 @@ pub fn create_sigset(signals: &[c_int]) -> errno::Result<sigset_t> {
Ok(sigset)
}
+/// Wait for a signal before continuing. The signal number of the consumed signal is returned on
+/// success. EAGAIN means the timeout was reached.
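+///
+/// A brief usage sketch (waits up to one second for SIGUSR1, which should be blocked on this
+/// thread so that sigtimedwait can consume it):
+///
+/// ```ignore
+/// match wait_for_signal(&[libc::SIGUSR1], Some(Duration::from_secs(1))) {
+///     Ok(signum) => println!("consumed signal {}", signum),
+///     Err(e) if e.errno() == libc::EAGAIN => println!("timed out"),
+///     Err(e) => panic!("failed to wait for signal: {}", e),
+/// }
+/// ```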
+pub fn wait_for_signal(signals: &[c_int], timeout: Option<Duration>) -> errno::Result<c_int> {
+ let sigset = create_sigset(signals)?;
+
+ match timeout {
+ Some(timeout) => {
+ let ts = duration_to_timespec(timeout);
+ // Safe - return value is checked.
+ let ret = handle_eintr_errno!(unsafe { sigtimedwait(&sigset, null_mut(), &ts) });
+ if ret < 0 {
+ errno_result()
+ } else {
+ Ok(ret)
+ }
+ }
+ None => {
+ let mut ret: c_int = 0;
+ let err = handle_eintr_rc!(unsafe { sigwait(&sigset, &mut ret as *mut c_int) });
+ if err != 0 {
+ Err(errno::Error::new(err))
+ } else {
+ Ok(ret)
+ }
+ }
+ }
+}
+
/// Retrieves the signal mask of the current thread as a vector of c_ints.
pub fn get_blocked_signals() -> SignalResult<Vec<c_int>> {
let mut mask = Vec::new();
@@ -266,6 +525,20 @@ pub fn clear_signal(num: c_int) -> SignalResult<()> {
Ok(())
}
+/// # Safety
+/// This is marked unsafe because it allows signals to be sent to arbitrary PIDs. Sending some
+/// signals may lead to undefined behavior. Also, any signalled child processes must still be
+/// reaped (waited on) to avoid leaking zombie processes.
+pub unsafe fn kill(pid: Pid, signum: c_int) -> Result<()> {
+ let ret = libc::kill(pid, signum);
+
+ if ret != 0 {
+ errno_result()
+ } else {
+ Ok(())
+ }
+}
+
/// Trait for threads that can be signalled via `pthread_kill`.
///
/// Note that this is only useful for signals between SIGRTMIN and SIGRTMAX because these are
@@ -300,3 +573,156 @@ unsafe impl<T> Killable for JoinHandle<T> {
self.as_pthread_t()
}
}
+
+/// Treat some errnos as Ok(()).
+macro_rules! ok_if {
+ ($result:expr, $($errno:pat)|+) => {{
+ let res = $result;
+ match res {
+ Ok(_) => Ok(()),
+ Err(err) => {
+ if matches!(err.errno(), $($errno)|+) {
+ Ok(())
+ } else {
+ Err(err)
+ }
+ }
+ }
+ }}
+}
+
+/// Terminates and reaps a child process. If the child process is a group leader, its children will
+/// be terminated and reaped as well. After the given timeout, the child process and any relevant
+/// children are killed (i.e. sent SIGKILL).
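+///
+/// A brief usage sketch (spawns a long-running child and gives it five seconds to exit cleanly):
+///
+/// ```ignore
+/// let mut child = std::process::Command::new("sleep").arg("1000").spawn().unwrap();
+/// // SIGTERM is sent first; SIGKILL follows if the timeout elapses.
+/// kill_tree(&mut child, Duration::from_secs(5)).unwrap();
+/// ```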
+pub fn kill_tree(child: &mut Child, terminate_timeout: Duration) -> SignalResult<()> {
+ let target = {
+ let pid = child.id() as Pid;
+ if getsid(Some(pid)).map_err(Error::GetSid)? == pid {
+ -pid
+ } else {
+ pid
+ }
+ };
+
+ // Safe because target is a child process (or group) and behavior of SIGTERM is defined.
+ ok_if!(unsafe { kill(target, libc::SIGTERM) }, libc::ESRCH).map_err(Error::Kill)?;
+
+ // Reap the direct child first in case it waits for its descendants; afterward, reap any
+ // remaining group members.
+ let start = Instant::now();
+ let mut child_running = true;
+ loop {
+ // Wait for the direct child to exit before reaping any process group members.
+ if child_running {
+ if child
+ .try_wait()
+ .map_err(|e| Error::WaitPid(errno::Error::from(e)))?
+ .is_some()
+ {
+ child_running = false;
+ // Skip the timeout check because waitpid(..., WNOHANG) will not block.
+ continue;
+ }
+ } else {
+ // Safe because target is a child process (or group), WNOHANG is used, and the return
+ // value is checked.
+ let ret = unsafe { waitpid(target, null_mut(), WNOHANG) };
+ match ret {
+ -1 => {
+ let err = errno::Error::last();
+ if err.errno() == libc::ECHILD {
+ // No group members to wait on.
+ break;
+ }
+ return Err(Error::WaitPid(err));
+ }
+ 0 => {}
+ // If a process was reaped, skip the timeout check in case there are more.
+ _ => continue,
+ };
+ }
+
+ // Check for a timeout.
+ let elapsed = start.elapsed();
+ if elapsed > terminate_timeout {
+ // Safe because target is a child process (or group) and behavior of SIGKILL is defined.
+ ok_if!(unsafe { kill(target, libc::SIGKILL) }, libc::ESRCH).map_err(Error::Kill)?;
+ return Err(Error::TimedOut);
+ }
+
+ // Wait for a SIGCHLD or until either the remaining time or a poll interval elapses.
+ ok_if!(
+ wait_for_signal(
+ &[libc::SIGCHLD],
+ Some(POLL_RATE.min(terminate_timeout - elapsed))
+ ),
+ libc::EAGAIN | libc::EINTR
+ )
+ .map_err(Error::WaitForSignal)?
+ }
+
+ Ok(())
+}
+
+/// Wraps a Child process, and calls kill_tree for its process group to clean
+/// it up when dropped.
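+///
+/// A brief usage sketch:
+///
+/// ```ignore
+/// let child = std::process::Command::new("sleep").arg("1000").spawn().unwrap();
+/// let child = KillOnDrop::from(child);
+/// // `child` derefs to std::process::Child; kill_tree() runs when it is dropped.
+/// ```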
+pub struct KillOnDrop {
+ process: Child,
+ timeout: Duration,
+}
+
+impl KillOnDrop {
+ /// Get the timeout. See timeout_mut() for more details.
+ pub fn timeout(&self) -> Duration {
+ self.timeout
+ }
+
+ /// Change the timeout for how long child processes are waited for before
+ /// the process group is forcibly killed.
+ pub fn timeout_mut(&mut self) -> &mut Duration {
+ &mut self.timeout
+ }
+}
+
+impl From<Child> for KillOnDrop {
+ fn from(process: Child) -> Self {
+ KillOnDrop {
+ process,
+ timeout: DEFAULT_KILL_TIMEOUT,
+ }
+ }
+}
+
+impl AsRef<Child> for KillOnDrop {
+ fn as_ref(&self) -> &Child {
+ &self.process
+ }
+}
+
+impl AsMut<Child> for KillOnDrop {
+ fn as_mut(&mut self) -> &mut Child {
+ &mut self.process
+ }
+}
+
+impl Deref for KillOnDrop {
+ type Target = Child;
+
+ fn deref(&self) -> &Self::Target {
+ &self.process
+ }
+}
+
+impl DerefMut for KillOnDrop {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.process
+ }
+}
+
+impl Drop for KillOnDrop {
+ fn drop(&mut self) {
+ if let Err(err) = kill_tree(&mut self.process, self.timeout) {
+ eprintln!("failed to kill child process group: {}", err);
+ }
+ }
+}
diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs
index a9da8bac4..4bdfdc71b 100644
--- a/sys_util/src/sock_ctrl_msg.rs
+++ b/sys_util/src/sock_ctrl_msg.rs
@@ -105,11 +105,11 @@ impl CmsgBuffer {
}
}
-fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
+fn raw_sendmsg<D: AsIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
- let iovec = IntoIobuf::as_iobufs(out_data);
+ let iovec = AsIobuf::as_iobuf_slice(out_data);
let mut msg = msghdr {
msg_name: null_mut(),
@@ -155,19 +155,26 @@ fn raw_sendmsg<D: IntoIobuf>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Re
}
fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
- let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
- let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
-
- let mut iovec = iovec {
+ let iovec = iovec {
iov_base: in_data.as_mut_ptr() as *mut c_void,
iov_len: in_data.len(),
};
+ raw_recvmsg_iovecs(fd, &mut [iovec], in_fds)
+}
+
+fn raw_recvmsg_iovecs(
+ fd: RawFd,
+ iovecs: &mut [iovec],
+ in_fds: &mut [RawFd],
+) -> Result<(usize, usize)> {
+ let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
+ let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
let mut msg = msghdr {
msg_name: null_mut(),
msg_namelen: 0,
- msg_iov: &mut iovec as *mut iovec,
- msg_iovlen: 1,
+ msg_iov: iovecs.as_mut_ptr() as *mut iovec,
+ msg_iovlen: iovecs.len(),
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
@@ -232,7 +239,7 @@ pub trait ScmSocket {
///
/// * `buf` - A buffer of data to send on the `socket`.
/// * `fd` - A file descriptors to be sent.
- fn send_with_fd<D: IntoIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> {
+ fn send_with_fd<D: AsIobuf>(&self, buf: &[D], fd: RawFd) -> Result<usize> {
self.send_with_fds(buf, &[fd])
}
@@ -244,10 +251,35 @@ pub trait ScmSocket {
///
/// * `buf` - A buffer of data to send on the `socket`.
/// * `fds` - A list of file descriptors to be sent.
- fn send_with_fds<D: IntoIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> {
+ fn send_with_fds<D: AsIobuf>(&self, buf: &[D], fd: &[RawFd]) -> Result<usize> {
raw_sendmsg(self.socket_fd(), buf, fd)
}
+ /// Sends the given data and file descriptor over the socket.
+ ///
+ /// On success, returns the number of bytes sent.
+ ///
+ /// # Arguments
+ ///
+ /// * `bufs` - A slice of slices of data to send on the `socket`.
+ /// * `fd` - A file descriptor to be sent.
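+ ///
+ /// A brief usage sketch (assumes `sock` implements `ScmSocket` and `fd` is a `RawFd` to pass):
+ ///
+ /// ```ignore
+ /// let bufs: [&[u8]; 2] = [b"header", b"payload"];
+ /// let bytes_sent = sock.send_bufs_with_fd(&bufs, fd)?;
+ /// ```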
+ fn send_bufs_with_fd(&self, bufs: &[&[u8]], fd: RawFd) -> Result<usize> {
+ self.send_bufs_with_fds(bufs, &[fd])
+ }
+
+ /// Sends the given data and file descriptors over the socket.
+ ///
+ /// On success, returns the number of bytes sent.
+ ///
+ /// # Arguments
+ ///
+ /// * `bufs` - A slice of slices of data to send on the `socket`.
+ /// * `fds` - A list of file descriptors to be sent.
+ fn send_bufs_with_fds(&self, bufs: &[&[u8]], fd: &[RawFd]) -> Result<usize> {
+ let slices: Vec<IoSlice> = bufs.iter().map(|&b| IoSlice::new(b)).collect();
+ raw_sendmsg(self.socket_fd(), &slices, fd)
+ }
+
/// Receives data and potentially a file descriptor from the socket.
///
/// On success, returns the number of bytes and an optional file descriptor.
@@ -284,6 +316,27 @@ pub trait ScmSocket {
fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
raw_recvmsg(self.socket_fd(), buf, fds)
}
+
+ /// Receives data and file descriptors from the socket.
+ ///
+ /// On success, returns the number of bytes and file descriptors received as a tuple
+ /// `(bytes count, files count)`.
+ ///
+ /// # Arguments
+ ///
+ /// * `iovecs` - A slice of iovecs to store received data.
+ /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
+ /// number of valid file descriptors is indicated by the second element of the
+ /// returned tuple. The caller owns these file descriptors, but they will not be
+ /// closed on drop like a `File`-like type would be. It is recommended that each valid
+ /// file descriptor gets wrapped in a drop type that closes it after this returns.
+ fn recv_iovecs_with_fds(
+ &self,
+ iovecs: &mut [iovec],
+ fds: &mut [RawFd],
+ ) -> Result<(usize, usize)> {
+ raw_recvmsg_iovecs(self.socket_fd(), iovecs, fds)
+ }
}
impl ScmSocket for UnixDatagram {
@@ -309,25 +362,26 @@ impl ScmSocket for UnixSeqpacket {
///
/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size
/// being accurate.
-pub unsafe trait IntoIobuf: Sized {
+pub unsafe trait AsIobuf: Sized {
/// Returns a `iovec` that describes a contiguous region of memory.
- fn into_iobuf(&self) -> iovec;
+ fn as_iobuf(&self) -> iovec;
/// Returns a slice of `iovec`s that each describe a contiguous region of memory.
- fn as_iobufs(bufs: &[Self]) -> &[iovec];
+ #[allow(clippy::wrong_self_convention)]
+ fn as_iobuf_slice(bufs: &[Self]) -> &[iovec];
}
// Safe because there are no other mutable references to the memory described by `IoSlice` and it is
// guaranteed to be ABI-compatible with `iovec`.
-unsafe impl<'a> IntoIobuf for IoSlice<'a> {
- fn into_iobuf(&self) -> iovec {
+unsafe impl<'a> AsIobuf for IoSlice<'a> {
+ fn as_iobuf(&self) -> iovec {
iovec {
iov_base: self.as_ptr() as *mut c_void,
iov_len: self.len(),
}
}
- fn as_iobufs(bufs: &[Self]) -> &[iovec] {
+ fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
// Safe because `IoSlice` is guaranteed to be ABI-compatible with `iovec`.
unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
}
@@ -335,15 +389,15 @@ unsafe impl<'a> IntoIobuf for IoSlice<'a> {
// Safe because there are no other references to the memory described by `IoSliceMut` and it is
// guaranteed to be ABI-compatible with `iovec`.
-unsafe impl<'a> IntoIobuf for IoSliceMut<'a> {
- fn into_iobuf(&self) -> iovec {
+unsafe impl<'a> AsIobuf for IoSliceMut<'a> {
+ fn as_iobuf(&self) -> iovec {
iovec {
iov_base: self.as_ptr() as *mut c_void,
iov_len: self.len(),
}
}
- fn as_iobufs(bufs: &[Self]) -> &[iovec] {
+ fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
// Safe because `IoSliceMut` is guaranteed to be ABI-compatible with `iovec`.
unsafe { slice::from_raw_parts(bufs.as_ptr() as *const iovec, bufs.len()) }
}
@@ -351,12 +405,12 @@ unsafe impl<'a> IntoIobuf for IoSliceMut<'a> {
// Safe because volatile slices are only ever accessed with other volatile interfaces and the
// pointer and size are guaranteed to be accurate.
-unsafe impl<'a> IntoIobuf for VolatileSlice<'a> {
- fn into_iobuf(&self) -> iovec {
+unsafe impl<'a> AsIobuf for VolatileSlice<'a> {
+ fn as_iobuf(&self) -> iovec {
*self.as_iobuf()
}
- fn as_iobufs(bufs: &[Self]) -> &[iovec] {
+ fn as_iobuf_slice(bufs: &[Self]) -> &[iovec] {
VolatileSlice::as_iobufs(bufs)
}
}
@@ -416,7 +470,8 @@ mod tests {
fn send_recv_no_fd() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
- let ioslice = IoSlice::new([1u8, 1, 2, 21, 34, 55].as_ref());
+ let send_buf = [1u8, 1, 2, 21, 34, 55];
+ let ioslice = IoSlice::new(&send_buf);
let write_count = s1
.send_with_fds(&[ioslice], &[])
.expect("failed to send data");
@@ -432,6 +487,19 @@ mod tests {
assert_eq!(read_count, 6);
assert_eq!(file_count, 0);
assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
+
+ let write_count = s1
+ .send_bufs_with_fds(&[&send_buf[..]], &[])
+ .expect("failed to send data");
+
+ assert_eq!(write_count, 6);
+ let (read_count, file_count) = s2
+ .recv_with_fds(&mut buf[..], &mut files)
+ .expect("failed to recv data");
+
+ assert_eq!(read_count, 6);
+ assert_eq!(file_count, 0);
+ assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
}
#[test]
diff --git a/sys_util/src/vsock.rs b/sys_util/src/vsock.rs
new file mode 100644
index 000000000..a51dfdbfc
--- /dev/null
+++ b/sys_util/src/vsock.rs
@@ -0,0 +1,495 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//! Support for virtual sockets.
+use std::fmt;
+use std::io;
+use std::mem::{self, size_of};
+use std::num::ParseIntError;
+use std::os::raw::{c_uchar, c_uint, c_ushort};
+use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
+use std::result;
+use std::str::FromStr;
+
+use libc::{
+ self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
+ VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
+};
+
+// The domain for vsock sockets.
+const AF_VSOCK: sa_family_t = 40;
+
+// Vsock loopback address.
+const VMADDR_CID_LOCAL: c_uint = 1;
+
+/// Vsock equivalent of binding on port 0. Binds to a random port.
+pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value();
+
+// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
+// from linux/vm_sockets.h.
+const PADDING: usize = size_of::<sockaddr>()
+ - size_of::<sa_family_t>()
+ - size_of::<c_ushort>()
+ - (2 * size_of::<c_uint>());
+
+#[repr(C)]
+#[derive(Default)]
+struct sockaddr_vm {
+ svm_family: sa_family_t,
+ svm_reserved1: c_ushort,
+ svm_port: c_uint,
+ svm_cid: c_uint,
+ svm_zero: [c_uchar; PADDING],
+}
+
+#[derive(Debug)]
+pub struct AddrParseError;
+
+impl fmt::Display for AddrParseError {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "failed to parse vsock address")
+ }
+}
+
+/// The vsock equivalent of an IP address.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub enum VsockCid {
+ /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
+ Any,
+ /// An address that refers to the bare-metal machine that serves as the hypervisor.
+ Hypervisor,
+ /// The loopback address.
+ Local,
+ /// The parent machine. It may not be the hypervisor for nested VMs.
+ Host,
+ /// An assigned CID that serves as the address for VSOCK.
+ Cid(c_uint),
+}
+
+impl fmt::Display for VsockCid {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ match &self {
+ VsockCid::Any => write!(fmt, "Any"),
+ VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
+ VsockCid::Local => write!(fmt, "Local"),
+ VsockCid::Host => write!(fmt, "Host"),
+ VsockCid::Cid(c) => write!(fmt, "'{}'", c),
+ }
+ }
+}
+
+impl From<c_uint> for VsockCid {
+ fn from(c: c_uint) -> Self {
+ match c {
+ VMADDR_CID_ANY => VsockCid::Any,
+ VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
+ VMADDR_CID_LOCAL => VsockCid::Local,
+ VMADDR_CID_HOST => VsockCid::Host,
+ _ => VsockCid::Cid(c),
+ }
+ }
+}
+
+impl FromStr for VsockCid {
+ type Err = ParseIntError;
+
+ fn from_str(s: &str) -> Result<Self, Self::Err> {
+ let c: c_uint = s.parse()?;
+ Ok(c.into())
+ }
+}
+
+impl From<VsockCid> for c_uint {
+ fn from(cid: VsockCid) -> c_uint {
+ match cid {
+ VsockCid::Any => VMADDR_CID_ANY,
+ VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
+ VsockCid::Local => VMADDR_CID_LOCAL,
+ VsockCid::Host => VMADDR_CID_HOST,
+ VsockCid::Cid(c) => c,
+ }
+ }
+}
+
+/// An address associated with a virtual socket.
+#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
+pub struct SocketAddr {
+ pub cid: VsockCid,
+ pub port: c_uint,
+}
+
+pub trait ToSocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
+}
+
+impl ToSocketAddr for SocketAddr {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ Ok(*self)
+ }
+}
+
+impl ToSocketAddr for str {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ self.parse()
+ }
+}
+
+impl ToSocketAddr for (VsockCid, c_uint) {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ let (cid, port) = *self;
+ Ok(SocketAddr { cid, port })
+ }
+}
+
+impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
+ fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
+ (**self).to_socket_addr()
+ }
+}
+
+impl FromStr for SocketAddr {
+ type Err = AddrParseError;
+
+ /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
+ /// "vsock:cid:port".
+ fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
+ let components: Vec<&str> = s.split(':').collect();
+ if components.len() != 3 || components[0] != "vsock" {
+ return Err(AddrParseError);
+ }
+
+ Ok(SocketAddr {
+ cid: components[1].parse().map_err(|_| AddrParseError)?,
+ port: components[2].parse().map_err(|_| AddrParseError)?,
+ })
+ }
+}
+
+impl fmt::Display for SocketAddr {
+ fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+ write!(fmt, "{}:{}", self.cid, self.port)
+ }
+}
+
+/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
+/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
+unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
+ let flags = libc::fcntl(fd, F_GETFL, 0);
+ if flags < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ let flags = if nonblocking {
+ flags | O_NONBLOCK
+ } else {
+ flags & !O_NONBLOCK
+ };
+
+ let ret = libc::fcntl(fd, F_SETFL, flags);
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(())
+}
+
+/// A virtual socket.
+///
+/// Do not use this struct unless you need to change socket options or query the
+/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
+/// VsockListener.
+#[derive(Debug)]
+pub struct VsockSocket {
+ fd: RawFd,
+}
+
+impl VsockSocket {
+ pub fn new() -> io::Result<Self> {
+ let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
+ if fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(VsockSocket { fd })
+ }
+ }
+
+ pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ // The compiler should optimize this out since these are both compile-time constants.
+ assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
+
+ let svm = sockaddr_vm {
+ svm_family: AF_VSOCK,
+ svm_cid: sockaddr.cid.into(),
+ svm_port: sockaddr.port,
+ ..Default::default()
+ };
+
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::bind(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let bind_err = io::Error::last_os_error();
+ Err(bind_err)
+ } else {
+ Ok(())
+ }
+ }
+
+ pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
+ let sockaddr = addr
+ .to_socket_addr()
+ .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
+
+ let svm = sockaddr_vm {
+ svm_family: AF_VSOCK,
+ svm_cid: sockaddr.cid.into(),
+ svm_port: sockaddr.port,
+ ..Default::default()
+ };
+
+ // Safe because this just connects a vsock socket, and the return value is checked.
+ let ret = unsafe {
+ libc::connect(
+ self.fd,
+ &svm as *const sockaddr_vm as *const sockaddr,
+ size_of::<sockaddr_vm>() as socklen_t,
+ )
+ };
+ if ret < 0 {
+ let connect_err = io::Error::last_os_error();
+ Err(connect_err)
+ } else {
+ Ok(VsockStream { sock: self })
+ }
+ }
+
+ pub fn listen(self) -> io::Result<VsockListener> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe { libc::listen(self.fd, 1) };
+ if ret < 0 {
+ let listen_err = io::Error::last_os_error();
+ return Err(listen_err);
+ }
+ Ok(VsockListener { sock: self })
+ }
+
+ /// Returns the port that this socket is bound to. This can only succeed after bind is called.
+ pub fn local_port(&self) -> io::Result<u32> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because we give a valid pointer for addrlen and check the length.
+ let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
+ let ret = unsafe {
+ // Get the socket address that was actually bound.
+ libc::getsockname(
+ self.fd,
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut addrlen as *mut socklen_t,
+ )
+ };
+ if ret < 0 {
+ let getsockname_err = io::Error::last_os_error();
+ Err(getsockname_err)
+ } else {
+ // If this doesn't match, it's not safe to get the port out of the sockaddr.
+ assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
+
+ Ok(svm.svm_port)
+ }
+ }
+
+ pub fn try_clone(&self) -> io::Result<Self> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
+ if dup_fd < 0 {
+ Err(io::Error::last_os_error())
+ } else {
+ Ok(Self { fd: dup_fd })
+ }
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ // Safe because the fd is valid and owned by this stream.
+ unsafe { set_nonblocking(self.fd, nonblocking) }
+ }
+}
+
+impl IntoRawFd for VsockSocket {
+ fn into_raw_fd(self) -> RawFd {
+ let fd = self.fd;
+ mem::forget(self);
+ fd
+ }
+}
+
+impl AsRawFd for VsockSocket {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd
+ }
+}
+
+impl Drop for VsockSocket {
+ fn drop(&mut self) {
+ // Safe because this doesn't modify any memory and we are the only
+ // owner of the file descriptor.
+ unsafe { libc::close(self.fd) };
+ }
+}
+
+/// A virtual stream socket.
+#[derive(Debug)]
+pub struct VsockStream {
+ sock: VsockSocket,
+}
+
+impl VsockStream {
+ pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
+ let sock = VsockSocket::new()?;
+ sock.connect(addr)
+ }
+
+ /// Returns the port that this stream is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ pub fn try_clone(&self) -> io::Result<VsockStream> {
+ self.sock.try_clone().map(|f| VsockStream { sock: f })
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl io::Read for VsockStream {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ // Safe because this will only modify the contents of |buf| and we check the return value.
+ let ret = unsafe {
+ libc::read(
+ self.sock.as_raw_fd(),
+ buf as *mut [u8] as *mut c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+}
+
+impl io::Write for VsockStream {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ // Safe because this doesn't modify any memory and we check the return value.
+ let ret = unsafe {
+ libc::write(
+ self.sock.as_raw_fd(),
+ buf as *const [u8] as *const c_void,
+ buf.len() as size_t,
+ )
+ };
+ if ret < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ Ok(ret as usize)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ // No buffered data so nothing to do.
+ Ok(())
+ }
+}
+
+impl AsRawFd for VsockStream {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+impl IntoRawFd for VsockStream {
+ fn into_raw_fd(self) -> RawFd {
+ self.sock.into_raw_fd()
+ }
+}
+
+/// Represents a virtual socket server.
+#[derive(Debug)]
+pub struct VsockListener {
+ sock: VsockSocket,
+}
+
+impl VsockListener {
+ /// Creates a new `VsockListener` bound to the specified port on the current virtual socket
+ /// endpoint.
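+ ///
+ /// A brief usage sketch (binds to a random port and accepts a single connection):
+ ///
+ /// ```ignore
+ /// let listener = VsockListener::bind((VsockCid::Any, VMADDR_PORT_ANY)).unwrap();
+ /// println!("listening on vsock port {}", listener.local_port().unwrap());
+ /// let (_stream, peer) = listener.accept().unwrap();
+ /// println!("accepted connection from {}", peer);
+ /// ```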
+ pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
+ let mut sock = VsockSocket::new()?;
+ sock.bind(addr)?;
+ sock.listen()
+ }
+
+ /// Returns the port that this listener is bound to.
+ pub fn local_port(&self) -> io::Result<u32> {
+ self.sock.local_port()
+ }
+
+ /// Accepts a new incoming connection on this listener. Blocks the calling thread until a
+ /// new connection is established. When established, returns the corresponding `VsockStream`
+ /// and the remote peer's address.
+ pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
+ let mut svm: sockaddr_vm = Default::default();
+
+ // Safe because this will only modify |svm| and we check the return value.
+ let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
+ let fd = unsafe {
+ libc::accept4(
+ self.sock.as_raw_fd(),
+ &mut svm as *mut sockaddr_vm as *mut sockaddr,
+ &mut socklen as *mut socklen_t,
+ libc::SOCK_CLOEXEC,
+ )
+ };
+ if fd < 0 {
+ return Err(io::Error::last_os_error());
+ }
+
+ if svm.svm_family != AF_VSOCK {
+ return Err(io::Error::new(
+ io::ErrorKind::InvalidData,
+ format!("unexpected address family: {}", svm.svm_family),
+ ));
+ }
+
+ Ok((
+ VsockStream {
+ sock: VsockSocket { fd },
+ },
+ SocketAddr {
+ cid: svm.svm_cid.into(),
+ port: svm.svm_port,
+ },
+ ))
+ }
+
+ pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
+ self.sock.set_nonblocking(nonblocking)
+ }
+}
+
+impl AsRawFd for VsockListener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
diff --git a/usb_util/src/device.rs b/usb_util/src/device.rs
index b30d54139..47cad2130 100644
--- a/usb_util/src/device.rs
+++ b/usb_util/src/device.rs
@@ -7,7 +7,7 @@ use crate::{
ControlRequestDataPhaseTransferDirection, ControlRequestRecipient, ControlRequestType,
DeviceDescriptor, DeviceDescriptorTree, Error, Result, StandardControlRequest,
};
-use base::{handle_eintr_errno, IoctlNr};
+use base::{handle_eintr_errno, AsRawDescriptor, IoctlNr, RawDescriptor};
use data_model::vec_with_array_field;
use libc::{EAGAIN, ENODEV, ENOENT};
use std::convert::TryInto;
@@ -320,6 +320,12 @@ impl Device {
}
}
+impl AsRawDescriptor for Device {
+ fn as_raw_descriptor(&self) -> RawDescriptor {
+ self.fd.as_raw_descriptor()
+ }
+}
+
impl Transfer {
fn urb(&self) -> &usb_sys::usbdevfs_urb {
// self.urb is a Vec created with `vec_with_array_field`; the first entry is
diff --git a/vfio_sys/src/lib.rs b/vfio_sys/src/lib.rs
index 667feb553..b85233866 100644
--- a/vfio_sys/src/lib.rs
+++ b/vfio_sys/src/lib.rs
@@ -8,7 +8,9 @@
use base::ioctl_io_nr;
+pub mod plat;
pub mod vfio;
+pub use crate::plat::*;
pub use crate::vfio::*;
ioctl_io_nr!(VFIO_GET_API_VERSION, VFIO_TYPE, VFIO_BASE);
@@ -37,3 +39,9 @@ ioctl_io_nr!(VFIO_IOMMU_MAP_DMA, VFIO_TYPE, VFIO_BASE + 13);
ioctl_io_nr!(VFIO_IOMMU_UNMAP_DMA, VFIO_TYPE, VFIO_BASE + 14);
ioctl_io_nr!(VFIO_IOMMU_ENABLE, VFIO_TYPE, VFIO_BASE + 15);
ioctl_io_nr!(VFIO_IOMMU_DISABLE, VFIO_TYPE, VFIO_BASE + 16);
+
+ioctl_io_nr!(
+ PLAT_IRQ_FORWARD_SET,
+ PLAT_IRQ_FORWARD_TYPE,
+ PLAT_IRQ_FORWARD_BASE
+);
diff --git a/vfio_sys/src/plat.rs b/vfio_sys/src/plat.rs
new file mode 100644
index 000000000..7a2e633d1
--- /dev/null
+++ b/vfio_sys/src/plat.rs
@@ -0,0 +1,133 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+/*
+ * automatically generated by rust-bindgen 0.56.0
+ * bindgen --constified-enum '*' --with-derive-default --no-doc-comments --no-layout-tests
+ */
+
+#[repr(C)]
+#[derive(Default)]
+pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
+impl<T> __IncompleteArrayField<T> {
+ #[inline]
+ pub const fn new() -> Self {
+ __IncompleteArrayField(::std::marker::PhantomData, [])
+ }
+ #[inline]
+ pub fn as_ptr(&self) -> *const T {
+ self as *const _ as *const T
+ }
+ #[inline]
+ pub fn as_mut_ptr(&mut self) -> *mut T {
+ self as *mut _ as *mut T
+ }
+ #[inline]
+ pub unsafe fn as_slice(&self, len: usize) -> &[T] {
+ ::std::slice::from_raw_parts(self.as_ptr(), len)
+ }
+ #[inline]
+ pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
+ ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
+ }
+}
+impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> {
+ fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
+ fmt.write_str("__IncompleteArrayField")
+ }
+}
+pub const _IOC_NRBITS: u32 = 8;
+pub const _IOC_TYPEBITS: u32 = 8;
+pub const _IOC_SIZEBITS: u32 = 14;
+pub const _IOC_DIRBITS: u32 = 2;
+pub const _IOC_NRMASK: u32 = 255;
+pub const _IOC_TYPEMASK: u32 = 255;
+pub const _IOC_SIZEMASK: u32 = 16383;
+pub const _IOC_DIRMASK: u32 = 3;
+pub const _IOC_NRSHIFT: u32 = 0;
+pub const _IOC_TYPESHIFT: u32 = 8;
+pub const _IOC_SIZESHIFT: u32 = 16;
+pub const _IOC_DIRSHIFT: u32 = 30;
+pub const _IOC_NONE: u32 = 0;
+pub const _IOC_WRITE: u32 = 1;
+pub const _IOC_READ: u32 = 2;
+pub const IOC_IN: u32 = 1073741824;
+pub const IOC_OUT: u32 = 2147483648;
+pub const IOC_INOUT: u32 = 3221225472;
+pub const IOCSIZE_MASK: u32 = 1073676288;
+pub const IOCSIZE_SHIFT: u32 = 16;
+pub const __BITS_PER_LONG: u32 = 64;
+pub const __FD_SETSIZE: u32 = 1024;
+pub const PLAT_IRQ_FORWARD_API_VERSION: u32 = 0;
+pub const PLAT_IRQ_FORWARD_TYPE: u32 = 59;
+pub const PLAT_IRQ_FORWARD_BASE: u32 = 100;
+pub const PLAT_IRQ_FORWARD_SET_LEVEL_TRIGGER_EVENTFD: u32 = 1;
+pub const PLAT_IRQ_FORWARD_SET_LEVEL_UNMASK_EVENTFD: u32 = 2;
+pub const PLAT_IRQ_FORWARD_SET_EDGE_TRIGGER: u32 = 4;
+pub type __s8 = ::std::os::raw::c_schar;
+pub type __u8 = ::std::os::raw::c_uchar;
+pub type __s16 = ::std::os::raw::c_short;
+pub type __u16 = ::std::os::raw::c_ushort;
+pub type __s32 = ::std::os::raw::c_int;
+pub type __u32 = ::std::os::raw::c_uint;
+pub type __s64 = ::std::os::raw::c_longlong;
+pub type __u64 = ::std::os::raw::c_ulonglong;
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct __kernel_fd_set {
+ pub fds_bits: [::std::os::raw::c_ulong; 16usize],
+}
+pub type __kernel_sighandler_t =
+ ::std::option::Option<unsafe extern "C" fn(arg1: ::std::os::raw::c_int)>;
+pub type __kernel_key_t = ::std::os::raw::c_int;
+pub type __kernel_mqd_t = ::std::os::raw::c_int;
+pub type __kernel_old_uid_t = ::std::os::raw::c_ushort;
+pub type __kernel_old_gid_t = ::std::os::raw::c_ushort;
+pub type __kernel_old_dev_t = ::std::os::raw::c_ulong;
+pub type __kernel_long_t = ::std::os::raw::c_long;
+pub type __kernel_ulong_t = ::std::os::raw::c_ulong;
+pub type __kernel_ino_t = __kernel_ulong_t;
+pub type __kernel_mode_t = ::std::os::raw::c_uint;
+pub type __kernel_pid_t = ::std::os::raw::c_int;
+pub type __kernel_ipc_pid_t = ::std::os::raw::c_int;
+pub type __kernel_uid_t = ::std::os::raw::c_uint;
+pub type __kernel_gid_t = ::std::os::raw::c_uint;
+pub type __kernel_suseconds_t = __kernel_long_t;
+pub type __kernel_daddr_t = ::std::os::raw::c_int;
+pub type __kernel_uid32_t = ::std::os::raw::c_uint;
+pub type __kernel_gid32_t = ::std::os::raw::c_uint;
+pub type __kernel_size_t = __kernel_ulong_t;
+pub type __kernel_ssize_t = __kernel_long_t;
+pub type __kernel_ptrdiff_t = __kernel_long_t;
+#[repr(C)]
+#[derive(Debug, Default, Copy, Clone)]
+pub struct __kernel_fsid_t {
+ pub val: [::std::os::raw::c_int; 2usize],
+}
+pub type __kernel_off_t = __kernel_long_t;
+pub type __kernel_loff_t = ::std::os::raw::c_longlong;
+pub type __kernel_time_t = __kernel_long_t;
+pub type __kernel_clock_t = __kernel_long_t;
+pub type __kernel_timer_t = ::std::os::raw::c_int;
+pub type __kernel_clockid_t = ::std::os::raw::c_int;
+pub type __kernel_caddr_t = *mut ::std::os::raw::c_char;
+pub type __kernel_uid16_t = ::std::os::raw::c_ushort;
+pub type __kernel_gid16_t = ::std::os::raw::c_ushort;
+pub type __le16 = __u16;
+pub type __be16 = __u16;
+pub type __le32 = __u32;
+pub type __be32 = __u32;
+pub type __le64 = __u64;
+pub type __be64 = __u64;
+pub type __sum16 = __u16;
+pub type __wsum = __u32;
+#[repr(C)]
+#[derive(Debug, Default)]
+pub struct plat_irq_forward_set {
+ pub argsz: __u32,
+ pub action_flags: __u32,
+ pub irq_number_host: __u32,
+ pub count: __u32,
+ pub eventfd: __IncompleteArrayField<__u8>,
+}
diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs
index 1af9af6a9..618c5ac47 100644
--- a/vhost/src/lib.rs
+++ b/vhost/src/lib.rs
@@ -132,7 +132,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
let _ = self
.mem()
- .with_regions::<_, ()>(|index, guest_addr, size, host_addr, _| {
+ .with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| {
vhost_regions[index] = virtio_sys::vhost_memory_region {
guest_phys_addr: guest_addr.offset() as u64,
memory_size: size as u64,
@@ -341,7 +341,7 @@ mod tests {
use crate::net::fakes::FakeNet;
use net_util::fakes::FakeTap;
- use std::result;
+ use std::{path::PathBuf, result};
use vm_memory::{GuestAddress, GuestMemory, GuestMemoryError};
fn create_guest_memory() -> result::Result<GuestMemory, GuestMemoryError> {
@@ -361,7 +361,7 @@ mod tests {
fn create_fake_vhost_net() -> FakeNet<FakeTap> {
let gm = create_guest_memory().unwrap();
- FakeNet::<FakeTap>::new(&gm).unwrap()
+ FakeNet::<FakeTap>::new(&PathBuf::from(""), &gm).unwrap()
}
#[test]
diff --git a/vhost/src/net.rs b/vhost/src/net.rs
index df6de95ce..d0e57e710 100644
--- a/vhost/src/net.rs
+++ b/vhost/src/net.rs
@@ -3,17 +3,18 @@
// found in the LICENSE file.
use net_util::TapT;
-use std::fs::{File, OpenOptions};
use std::marker::PhantomData;
use std::os::unix::fs::OpenOptionsExt;
+use std::{
+ fs::{File, OpenOptions},
+ path::PathBuf,
+};
use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor};
use vm_memory::GuestMemory;
use super::{ioctl_result, Error, Result, Vhost};
-static DEVICE: &str = "/dev/vhost-net";
-
/// Handle to run VHOST_NET ioctls.
///
/// This provides a simple wrapper around a VHOST_NET file descriptor and
@@ -28,7 +29,7 @@ pub struct Net<T> {
pub trait NetT<T: TapT>: Vhost + AsRawDescriptor + Send + Sized {
/// Create a new NetT instance
- fn new(mem: &GuestMemory) -> Result<Self>;
+ fn new(vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<Self>;
/// Set the tap file descriptor that will serve as the VHOST_NET backend.
/// This will start the vhost worker for the given queue.
@@ -47,13 +48,13 @@ where
///
/// # Arguments
/// * `mem` - Guest memory mapping.
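+ /// * `vhost_net_device_path` - Path to the vhost-net device, typically /dev/vhost-net.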
- fn new(mem: &GuestMemory) -> Result<Net<T>> {
+ fn new(vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<Net<T>> {
Ok(Net::<T> {
descriptor: OpenOptions::new()
.read(true)
.write(true)
.custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
- .open(DEVICE)
+ .open(vhost_net_device_path)
.map_err(Error::VhostOpen)?,
mem: mem.clone(),
phantom: PhantomData,
@@ -117,7 +118,7 @@ pub mod fakes {
where
T: TapT,
{
- fn new(mem: &GuestMemory) -> Result<FakeNet<T>> {
+ fn new(_vhost_net_device_path: &PathBuf, mem: &GuestMemory) -> Result<FakeNet<T>> {
Ok(FakeNet::<T> {
descriptor: OpenOptions::new()
.read(true)
diff --git a/vhost/src/vsock.rs b/vhost/src/vsock.rs
index 95fbb72e5..fb9795137 100644
--- a/vhost/src/vsock.rs
+++ b/vhost/src/vsock.rs
@@ -2,8 +2,11 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
-use std::fs::{File, OpenOptions};
use std::os::unix::fs::OpenOptionsExt;
+use std::{
+ fs::{File, OpenOptions},
+ path::PathBuf,
+};
use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor};
use virtio_sys::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
@@ -11,8 +14,6 @@ use vm_memory::GuestMemory;
use super::{ioctl_result, Error, Result, Vhost};
-static DEVICE: &str = "/dev/vhost-vsock";
-
/// Handle for running VHOST_VSOCK ioctls.
pub struct Vsock {
descriptor: File,
@@ -21,13 +22,13 @@ pub struct Vsock {
impl Vsock {
/// Open a handle to a new VHOST_VSOCK instance.
- pub fn new(mem: &GuestMemory) -> Result<Vsock> {
+ pub fn new(vhost_vsock_device_path: &PathBuf, mem: &GuestMemory) -> Result<Vsock> {
Ok(Vsock {
descriptor: OpenOptions::new()
.read(true)
.write(true)
.custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
- .open(DEVICE)
+ .open(vhost_vsock_device_path)
.map_err(Error::VhostOpen)?,
mem: mem.clone(),
})
diff --git a/vm_control/Android.bp b/vm_control/Android.bp
index 85cc738b0..18325c0b3 100644
--- a/vm_control/Android.bp
+++ b/vm_control/Android.bp
@@ -1,5 +1,4 @@
-// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb.
-// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this
+// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace.
package {
// See: http://go/android-license-faq
@@ -17,17 +16,6 @@ rust_library {
crate_name: "vm_control",
srcs: ["src/lib.rs"],
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- ],
- },
- },
rustlibs: [
"libbase_rust",
"libdata_model",
@@ -49,17 +37,6 @@ rust_defaults {
test_suites: ["general-tests"],
auto_gen_config: true,
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- ],
- rustlibs: [
- "libgdbstub",
- ],
- },
- },
rustlibs: [
"libbase_rust",
"libdata_model",
@@ -76,6 +53,7 @@ rust_defaults {
rust_test_host {
name: "vm_control_host_test_src_lib",
defaults: ["vm_control_defaults"],
+ shared_libs: ["libgfxstream_backend"],
test_options: {
unit_test: true,
},
@@ -109,11 +87,8 @@ rust_test {
// ../tempfile/src/lib.rs
// ../vm_memory/src/lib.rs
// async-task-4.0.3 "default,std"
-// async-trait-0.1.48
-// autocfg-1.0.1
+// async-trait-0.1.45
// base-0.1.0
-// cfg-if-0.1.10
-// cfg-if-1.0.0
// downcast-rs-1.2.0 "default,std"
// futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std"
// futures-channel-0.3.13 "alloc,futures-sink,sink,std"
@@ -124,12 +99,8 @@ rust_test {
// futures-sink-0.3.13 "alloc,std"
// futures-task-0.3.13 "alloc,std"
// futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std"
-// gdbstub-0.4.4 "alloc,default,std"
-// libc-0.2.88 "default,std"
-// log-0.4.14
-// managed-0.8.0 "alloc"
+// libc-0.2.87 "default,std"
// memchr-2.3.4 "default,std"
-// num-traits-0.2.14
// paste-1.0.4
// pin-project-lite-0.2.6
// pin-utils-0.1.0
@@ -137,10 +108,10 @@ rust_test {
// proc-macro-nested-0.1.7
// proc-macro2-1.0.24 "default,proc-macro"
// quote-1.0.9 "default,proc-macro"
-// serde-1.0.124 "default,derive,serde_derive,std"
-// serde_derive-1.0.124 "default"
+// serde-1.0.123 "default,derive,serde_derive,std"
+// serde_derive-1.0.123 "default"
// slab-0.4.2
-// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
+// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
// thiserror-1.0.24
// thiserror-impl-1.0.24
// unicode-xid-0.2.1 "default"
diff --git a/vm_control/Cargo.toml b/vm_control/Cargo.toml
index 3c59d9874..27b1066b1 100644
--- a/vm_control/Cargo.toml
+++ b/vm_control/Cargo.toml
@@ -8,13 +8,13 @@ edition = "2018"
gdb = ["gdbstub"]
[dependencies]
+base = { path = "../base" }
data_model = { path = "../data_model" }
gdbstub = { version = "0.4.0", optional = true }
hypervisor = { path = "../hypervisor" }
libc = "*"
-msg_socket = { path = "../msg_socket" }
resources = { path = "../resources" }
rutabaga_gfx = { path = "../rutabaga_gfx"}
+serde = { version = "1", features = [ "derive" ] }
sync = { path = "../sync" }
-base = { path = "../base" }
vm_memory = { path = "../vm_memory" }
diff --git a/vm_control/src/client.rs b/vm_control/src/client.rs
new file mode 100644
index 000000000..3fff3e32f
--- /dev/null
+++ b/vm_control/src/client.rs
@@ -0,0 +1,211 @@
+// Copyright 2021 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+use crate::*;
+use base::{info, net::UnixSeqpacket, validate_raw_descriptor, RawDescriptor, Tube};
+
+use std::fs::OpenOptions;
+use std::num::ParseIntError;
+use std::path::{Path, PathBuf};
+
+enum ModifyBatError {
+ BatControlErr(BatControlResult),
+}
+
+impl fmt::Display for ModifyBatError {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ use self::ModifyBatError::*;
+
+ match self {
+ BatControlErr(e) => write!(f, "{}", e),
+ }
+ }
+}
+
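+/// Errors that can occur while handling a USB attach/detach/list control command.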
+pub enum ModifyUsbError {
+ ArgMissing(&'static str),
+ ArgParse(&'static str, String),
+ ArgParseInt(&'static str, String, ParseIntError),
+ FailedDescriptorValidate(base::Error),
+ PathDoesNotExist(PathBuf),
+ SocketFailed,
+ UnexpectedResponse(VmResponse),
+ UnknownCommand(String),
+ UsbControl(UsbControlResult),
+}
+
+impl std::fmt::Display for ModifyUsbError {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ use self::ModifyUsbError::*;
+
+ match self {
+ ArgMissing(a) => write!(f, "argument missing: {}", a),
+ ArgParse(name, value) => {
+ write!(f, "failed to parse argument {} value `{}`", name, value)
+ }
+ ArgParseInt(name, value, e) => write!(
+ f,
+ "failed to parse integer argument {} value `{}`: {}",
+ name, value, e
+ ),
+ FailedDescriptorValidate(e) => write!(f, "failed to validate file descriptor: {}", e),
+ PathDoesNotExist(p) => write!(f, "path `{}` does not exist", p.display()),
+ SocketFailed => write!(f, "socket failed"),
+ UnexpectedResponse(r) => write!(f, "unexpected response: {}", r),
+ UnknownCommand(c) => write!(f, "unknown command: `{}`", c),
+ UsbControl(e) => write!(f, "{}", e),
+ }
+ }
+}
+
+pub type ModifyUsbResult<T> = std::result::Result<T, ModifyUsbError>;
+
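+/// Interprets the last component of `path` as a raw descriptor number, parses it, and
+/// validates that it refers to an open descriptor.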
+fn raw_descriptor_from_path(path: &Path) -> ModifyUsbResult<RawDescriptor> {
+ if !path.exists() {
+ return Err(ModifyUsbError::PathDoesNotExist(path.to_owned()));
+ }
+ let raw_descriptor = path
+ .file_name()
+ .and_then(|fd_osstr| fd_osstr.to_str())
+ .map_or(
+ Err(ModifyUsbError::ArgParse(
+ "USB_DEVICE_PATH",
+ path.to_string_lossy().into_owned(),
+ )),
+ |fd_str| {
+ fd_str.parse::<libc::c_int>().map_err(|e| {
+ ModifyUsbError::ArgParseInt("USB_DEVICE_PATH", fd_str.to_owned(), e)
+ })
+ },
+ )?;
+ validate_raw_descriptor(raw_descriptor).map_err(ModifyUsbError::FailedDescriptorValidate)
+}
+
+pub type VmsRequestResult = std::result::Result<(), ()>;
+
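+/// Sends `request` to the control socket at `socket_path` and logs the response.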
+pub fn vms_request(request: &VmRequest, socket_path: &Path) -> VmsRequestResult {
+ let response = handle_request(request, socket_path)?;
+ info!("request response was {}", response);
+ Ok(())
+}
+
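+/// Opens the USB device at `dev_path` (special-casing already-open /proc/self/fd entries) and
+/// asks the VM to attach it on the given bus/addr with the given vendor and product IDs.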
+pub fn do_usb_attach(
+ socket_path: &Path,
+ bus: u8,
+ addr: u8,
+ vid: u16,
+ pid: u16,
+ dev_path: &Path,
+) -> ModifyUsbResult<UsbControlResult> {
+ let usb_file: File = if dev_path.parent() == Some(Path::new("/proc/self/fd")) {
+ // Special case '/proc/self/fd/*' paths. The FD is already open, just use it.
+ // Safe because raw_descriptor_from_path() validates the descriptor before returning it.
+ unsafe { File::from_raw_descriptor(raw_descriptor_from_path(&dev_path)?) }
+ } else {
+ OpenOptions::new()
+ .read(true)
+ .write(true)
+ .open(&dev_path)
+ .map_err(|_| ModifyUsbError::UsbControl(UsbControlResult::FailedToOpenDevice))?
+ };
+
+ let request = VmRequest::UsbCommand(UsbControlCommand::AttachDevice {
+ bus,
+ addr,
+ vid,
+ pid,
+ file: usb_file,
+ });
+ let response =
+ handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?;
+ match response {
+ VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
+ r => Err(ModifyUsbError::UnexpectedResponse(r)),
+ }
+}
+
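+/// Asks the VM to detach the USB device currently attached to `port`.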
+pub fn do_usb_detach(socket_path: &Path, port: u8) -> ModifyUsbResult<UsbControlResult> {
+ let request = VmRequest::UsbCommand(UsbControlCommand::DetachDevice { port });
+ let response =
+ handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?;
+ match response {
+ VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
+ r => Err(ModifyUsbError::UnexpectedResponse(r)),
+ }
+}
+
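+/// Asks the VM to list the devices attached to each of the `USB_CONTROL_MAX_PORTS` ports.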
+pub fn do_usb_list(socket_path: &Path) -> ModifyUsbResult<UsbControlResult> {
+ let mut ports: [u8; USB_CONTROL_MAX_PORTS] = Default::default();
+ for (index, port) in ports.iter_mut().enumerate() {
+ *port = index as u8
+ }
+ let request = VmRequest::UsbCommand(UsbControlCommand::ListDevice { ports });
+ let response =
+ handle_request(&request, socket_path).map_err(|_| ModifyUsbError::SocketFailed)?;
+ match response {
+ VmResponse::UsbResponse(usb_resp) => Ok(usb_resp),
+ r => Err(ModifyUsbError::UnexpectedResponse(r)),
+ }
+}
+
+pub type DoModifyBatteryResult = std::result::Result<(), ()>;
+
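+/// Parses `battery_type`, `property`, and `target` into a battery control command and forwards
+/// it to the VM over the control socket.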
+pub fn do_modify_battery(
+ socket_path: &Path,
+ battery_type: &str,
+ property: &str,
+ target: &str,
+) -> DoModifyBatteryResult {
+ let response = match battery_type.parse::<BatteryType>() {
+ Ok(type_) => match BatControlCommand::new(property.to_string(), target.to_string()) {
+ Ok(cmd) => {
+ let request = VmRequest::BatCommand(type_, cmd);
+ Ok(handle_request(&request, socket_path)?)
+ }
+ Err(e) => Err(ModifyBatError::BatControlErr(e)),
+ },
+ Err(e) => Err(ModifyBatError::BatControlErr(e)),
+ };
+
+ match response {
+ Ok(response) => {
+ println!("{}", response);
+ Ok(())
+ }
+ Err(e) => {
+ println!("error {}", e);
+ Err(())
+ }
+ }
+}
+
+pub type HandleRequestResult = std::result::Result<VmResponse, ()>;
+
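+/// Connects to the control socket at `socket_path`, sends `request` over a `Tube`, and returns
+/// the `VmResponse` received in reply.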
+pub fn handle_request(request: &VmRequest, socket_path: &Path) -> HandleRequestResult {
+ match UnixSeqpacket::connect(&socket_path) {
+ Ok(s) => {
+ let socket = Tube::new(s);
+ if let Err(e) = socket.send(request) {
+ error!(
+ "failed to send request to socket at '{:?}': {}",
+ socket_path, e
+ );
+ return Err(());
+ }
+ match socket.recv() {
+ Ok(response) => Ok(response),
+ Err(e) => {
+ error!(
+ "failed to send request to socket at '{:?}': {}",
+ socket_path, e
+ );
+ Err(())
+ }
+ }
+ }
+ Err(e) => {
+ error!("failed to connect to socket at '{:?}': {}", socket_path, e);
+ Err(())
+ }
+ }
+}
diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs
index ce236284e..3222edf09 100644
--- a/vm_control/src/lib.rs
+++ b/vm_control/src/lib.rs
@@ -13,37 +13,41 @@
#[cfg(all(target_arch = "x86_64", feature = "gdb"))]
pub mod gdb;
+pub mod client;
+
use std::fmt::{self, Display};
use std::fs::File;
-use std::mem::ManuallyDrop;
use std::os::raw::c_int;
use std::result::Result as StdResult;
use std::str::FromStr;
use std::sync::Arc;
use libc::{EINVAL, EIO, ENODEV};
+use serde::{Deserialize, Serialize};
use base::{
- error, AsRawDescriptor, Error as SysError, Event, ExternalMapping, Fd, FromRawDescriptor,
- IntoRawDescriptor, MappedRegion, MemoryMappingArena, MemoryMappingBuilder, MmapError,
- Protection, RawDescriptor, Result, SafeDescriptor,
+ error, with_as_descriptor, AsRawDescriptor, Error as SysError, Event, ExternalMapping, Fd,
+ FromRawDescriptor, IntoRawDescriptor, MappedRegion, MemoryMappingArena, MemoryMappingBuilder,
+ MemoryMappingBuilderUnix, MmapError, Protection, Result, SafeDescriptor, SharedMemory, Tube,
};
use hypervisor::{IrqRoute, IrqSource, Vm};
-use msg_socket::{MsgError, MsgOnSocket, MsgReceiver, MsgResult, MsgSender, MsgSocket};
use resources::{Alloc, MmioType, SystemAllocator};
-use rutabaga_gfx::{DrmFormat, ImageAllocationInfo, RutabagaGralloc, RutabagaGrallocFlags};
+use rutabaga_gfx::{
+ DrmFormat, ImageAllocationInfo, RutabagaGralloc, RutabagaGrallocFlags, RutabagaHandle,
+ VulkanInfo,
+};
use sync::Mutex;
use vm_memory::GuestAddress;
/// Struct that describes the offset and stride of a plane located in GPU memory.
-#[derive(Clone, Copy, Debug, PartialEq, Default, MsgOnSocket)]
+#[derive(Clone, Copy, Debug, PartialEq, Default, Serialize, Deserialize)]
pub struct GpuMemoryPlaneDesc {
pub stride: u32,
pub offset: u32,
}
/// Struct that describes a GPU memory allocation that consists of up to 3 planes.
-#[derive(Clone, Copy, Debug, Default, MsgOnSocket)]
+#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)]
pub struct GpuMemoryDesc {
pub planes: [GpuMemoryPlaneDesc; 3],
}
@@ -60,66 +64,6 @@ pub enum VcpuControl {
RunState(VmRunMode),
}
-/// A file descriptor either borrowed or owned by this.
-#[derive(Debug)]
-pub enum MaybeOwnedDescriptor {
- /// Owned by this enum variant, and will be destructed automatically if not moved out.
- Owned(SafeDescriptor),
- /// A file descriptor borrwed by this enum.
- Borrowed(RawDescriptor),
-}
-
-impl AsRawDescriptor for MaybeOwnedDescriptor {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- match self {
- MaybeOwnedDescriptor::Owned(f) => f.as_raw_descriptor(),
- MaybeOwnedDescriptor::Borrowed(descriptor) => *descriptor,
- }
- }
-}
-
-impl AsRawDescriptor for &MaybeOwnedDescriptor {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- match self {
- MaybeOwnedDescriptor::Owned(f) => f.as_raw_descriptor(),
- MaybeOwnedDescriptor::Borrowed(descriptor) => *descriptor,
- }
- }
-}
-
-// When sent, it could be owned or borrowed. On the receiver end, it always owned.
-impl MsgOnSocket for MaybeOwnedDescriptor {
- fn uses_descriptor() -> bool {
- true
- }
- fn fixed_size() -> Option<usize> {
- Some(0)
- }
- fn descriptor_count(&self) -> usize {
- 1usize
- }
- unsafe fn read_from_buffer(
- buffer: &[u8],
- descriptors: &[RawDescriptor],
- ) -> MsgResult<(Self, usize)> {
- let (file, size) = File::read_from_buffer(buffer, descriptors)?;
- let safe_descriptor = SafeDescriptor::from_raw_descriptor(file.into_raw_descriptor());
- Ok((MaybeOwnedDescriptor::Owned(safe_descriptor), size))
- }
- fn write_to_buffer(
- &self,
- _buffer: &mut [u8],
- descriptors: &mut [RawDescriptor],
- ) -> MsgResult<usize> {
- if descriptors.is_empty() {
- return Err(MsgError::WrongDescriptorBufferSize);
- }
-
- descriptors[0] = self.as_raw_descriptor();
- Ok(1)
- }
-}
-
/// Mode of execution for the VM.
#[derive(Debug, Clone, PartialEq)]
pub enum VmRunMode {
@@ -159,7 +103,7 @@ impl Default for VmRunMode {
/// require adding a big dependency for a single const.
pub const USB_CONTROL_MAX_PORTS: usize = 16;
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BalloonControlCommand {
/// Set the size of the VM's balloon.
Adjust {
@@ -169,7 +113,7 @@ pub enum BalloonControlCommand {
}
// BalloonStats holds stats returned from the stats_queue.
-#[derive(Default, MsgOnSocket, Debug)]
+#[derive(Default, Serialize, Deserialize, Debug)]
pub struct BalloonStats {
pub swap_in: Option<u64>,
pub swap_out: Option<u64>,
@@ -220,7 +164,7 @@ impl Display for BalloonStats {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BalloonControlResult {
Stats {
stats: BalloonStats,
@@ -228,7 +172,7 @@ pub enum BalloonControlResult {
},
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum DiskControlCommand {
/// Resize a disk to `new_size` in bytes.
Resize { new_size: u64 },
@@ -244,20 +188,21 @@ impl Display for DiskControlCommand {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum DiskControlResult {
Ok,
Err(SysError),
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum UsbControlCommand {
AttachDevice {
bus: u8,
addr: u8,
vid: u16,
pid: u16,
- descriptor: Option<MaybeOwnedDescriptor>,
+ #[serde(with = "with_as_descriptor")]
+ file: File,
},
DetachDevice {
port: u8,
@@ -267,7 +212,7 @@ pub enum UsbControlCommand {
},
}
-#[derive(MsgOnSocket, Copy, Clone, Debug, Default)]
+#[derive(Serialize, Deserialize, Copy, Clone, Debug, Default)]
pub struct UsbControlAttachedDevice {
pub port: u8,
pub vendor_id: u16,
@@ -275,12 +220,12 @@ pub struct UsbControlAttachedDevice {
}
impl UsbControlAttachedDevice {
- fn valid(self) -> bool {
+ pub fn valid(self) -> bool {
self.port != 0
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum UsbControlResult {
Ok { port: u8 },
NoAvailablePort,
@@ -295,7 +240,7 @@ impl Display for UsbControlResult {
use self::UsbControlResult::*;
match self {
- Ok { port } => write!(f, "ok {}", port),
+ UsbControlResult::Ok { port } => write!(f, "ok {}", port),
NoAvailablePort => write!(f, "no_available_port"),
NoSuchDevice => write!(f, "no_such_device"),
NoSuchPort => write!(f, "no_such_port"),
@@ -311,16 +256,27 @@ impl Display for UsbControlResult {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize)]
pub enum VmMemoryRequest {
/// Register shared memory represented by the given descriptor into guest address space.
/// The response variant is `VmResponse::RegisterMemory`.
- RegisterMemory(MaybeOwnedDescriptor, usize),
+ RegisterMemory(SharedMemory),
/// Similar to `VmMemoryRequest::RegisterMemory`, but doesn't allocate new address space.
/// Useful for cases where the address space is already allocated (PCI regions).
- RegisterFdAtPciBarOffset(Alloc, MaybeOwnedDescriptor, usize, u64),
+ RegisterFdAtPciBarOffset(Alloc, SafeDescriptor, usize, u64),
/// Similar to RegisterFdAtPciBarOffset, but is for buffers in the current address space.
RegisterHostPointerAtPciBarOffset(Alloc, u64),
+ /// Similar to `RegisterFdAtPciBarOffset`, but uses Vulkano to map the resource instead of
+ /// the mmap system call.
+ RegisterVulkanMemoryAtPciBarOffset {
+ alloc: Alloc,
+ descriptor: SafeDescriptor,
+ handle_type: u32,
+ memory_idx: u32,
+ physical_device_idx: u32,
+ offset: u64,
+ size: u64,
+ },
/// Unregister the given memory slot that was previously registered with `RegisterMemory*`.
UnregisterMemory(MemSlot),
/// Allocate GPU buffer of a given size/format and register the memory into guest address space.
@@ -332,7 +288,7 @@ pub enum VmMemoryRequest {
},
/// Register mmaped memory into the hypervisor's EPT.
RegisterMmapMemory {
- descriptor: MaybeOwnedDescriptor,
+ descriptor: SafeDescriptor,
size: usize,
offset: u64,
gpa: u64,
@@ -350,16 +306,16 @@ impl VmMemoryRequest {
/// `VmMemoryResponse` with the intended purpose of sending the response back over the socket
/// that received this `VmMemoryResponse`.
pub fn execute(
- &self,
+ self,
vm: &mut impl Vm,
sys_allocator: &mut SystemAllocator,
map_request: Arc<Mutex<Option<ExternalMapping>>>,
gralloc: &mut RutabagaGralloc,
) -> VmMemoryResponse {
use self::VmMemoryRequest::*;
- match *self {
- RegisterMemory(ref descriptor, size) => {
- match register_memory(vm, sys_allocator, descriptor, size, None) {
+ match self {
+ RegisterMemory(ref shm) => {
+ match register_memory(vm, sys_allocator, shm, shm.size() as usize, None) {
Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot },
Err(e) => VmMemoryResponse::Err(e),
}
@@ -381,7 +337,39 @@ impl VmMemoryRequest {
.ok_or_else(|| VmMemoryResponse::Err(SysError::new(EINVAL)))
.unwrap();
- match register_memory_hva(vm, sys_allocator, Box::new(mem), (alloc, offset)) {
+ match register_host_pointer(vm, sys_allocator, Box::new(mem), (alloc, offset)) {
+ Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot },
+ Err(e) => VmMemoryResponse::Err(e),
+ }
+ }
+ RegisterVulkanMemoryAtPciBarOffset {
+ alloc,
+ descriptor,
+ handle_type,
+ memory_idx,
+ physical_device_idx,
+ offset,
+ size,
+ } => {
+ let mapped_region = match gralloc.import_and_map(
+ RutabagaHandle {
+ os_handle: descriptor,
+ handle_type,
+ },
+ VulkanInfo {
+ memory_idx,
+ physical_device_idx,
+ },
+ size,
+ ) {
+ Ok(mapped_region) => mapped_region,
+ Err(e) => {
+ error!("gralloc failed to import and map: {}", e);
+ return VmMemoryResponse::Err(SysError::new(EINVAL));
+ }
+ };
+
+ match register_host_pointer(vm, sys_allocator, mapped_region, (alloc, offset)) {
Ok((pfn, slot)) => VmMemoryResponse::RegisterMemory { pfn, slot },
Err(e) => VmMemoryResponse::Err(e),
}
@@ -438,11 +426,11 @@ impl VmMemoryRequest {
Ok((pfn, slot)) => VmMemoryResponse::AllocateAndRegisterGpuMemory {
// Safe because ownership is transferred to SafeDescriptor via
// into_raw_descriptor
- descriptor: MaybeOwnedDescriptor::Owned(unsafe {
+ descriptor: unsafe {
SafeDescriptor::from_raw_descriptor(
handle.os_handle.into_raw_descriptor(),
)
- }),
+ },
pfn,
slot,
desc,
@@ -473,7 +461,7 @@ impl VmMemoryRequest {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmMemoryResponse {
/// The request to register memory into guest address space was successfully done at page frame
/// number `pfn` and memory slot number `slot`.
@@ -484,7 +472,7 @@ pub enum VmMemoryResponse {
/// The request to allocate and register GPU memory into guest address space was successfully
/// done at page frame number `pfn` and memory slot number `slot` for buffer with `desc`.
AllocateAndRegisterGpuMemory {
- descriptor: MaybeOwnedDescriptor,
+ descriptor: SafeDescriptor,
pfn: u64,
slot: MemSlot,
desc: GpuMemoryDesc,
@@ -493,10 +481,10 @@ pub enum VmMemoryResponse {
Err(SysError),
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmIrqRequest {
/// Allocate one gsi, and associate gsi to irqfd with register_irqfd()
- AllocateOneMsi { irqfd: MaybeOwnedDescriptor },
+ AllocateOneMsi { irqfd: Event },
/// Add one msi route entry into the IRQ chip.
AddMsiRoute {
gsi: u32,
@@ -530,19 +518,7 @@ impl VmIrqRequest {
match *self {
AllocateOneMsi { ref irqfd } => {
if let Some(irq_num) = sys_allocator.allocate_irq() {
- // Because of the limitation of `MaybeOwnedDescriptor` not fitting into
- // `register_irqfd` which expects an `&Event`, we use the unsafe `from_raw_fd`
- // to assume that the descriptor given is an `Event`, and we ignore the
- // ownership question using `ManuallyDrop`. This is safe because `ManuallyDrop`
- // prevents any Drop implementation from triggering on `irqfd` which already has
- // an owner, and the `Event` methods are never called. The underlying descriptor
- // is merely passed to the kernel which doesn't care about ownership and deals
- // with incorrect FDs, in the case of bugs on our part.
- let evt = unsafe {
- ManuallyDrop::new(Event::from_raw_descriptor(irqfd.as_raw_descriptor()))
- };
-
- match set_up_irq(IrqSetup::Event(irq_num, &evt)) {
+ match set_up_irq(IrqSetup::Event(irq_num, &irqfd)) {
Ok(_) => VmIrqResponse::AllocateOneMsi { gsi: irq_num },
Err(e) => VmIrqResponse::Err(e),
}
@@ -571,14 +547,14 @@ impl VmIrqRequest {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmIrqResponse {
AllocateOneMsi { gsi: u32 },
Ok,
Err(SysError),
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmMsyncRequest {
/// Flush the content of a memory mapping to its backing file.
/// `slot` selects the arena (as returned by `Vm::add_mmap_arena`).
@@ -591,7 +567,7 @@ pub enum VmMsyncRequest {
},
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmMsyncResponse {
Ok,
Err(SysError),
@@ -617,7 +593,7 @@ impl VmMsyncRequest {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BatControlResult {
Ok,
NoBatDevice,
@@ -644,7 +620,7 @@ impl Display for BatControlResult {
}
}
-#[derive(MsgOnSocket, Copy, Clone, Debug, PartialEq)]
+#[derive(Serialize, Deserialize, Copy, Clone, Debug, PartialEq)]
pub enum BatteryType {
Goldfish,
}
@@ -666,7 +642,7 @@ impl FromStr for BatteryType {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BatProperty {
Status,
Health,
@@ -690,7 +666,7 @@ impl FromStr for BatProperty {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BatStatus {
Unknown,
Charging,
@@ -733,7 +709,7 @@ impl From<BatStatus> for u32 {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BatHealth {
Unknown,
Good,
@@ -773,7 +749,7 @@ impl From<BatHealth> for u32 {
}
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum BatControlCommand {
SetStatus(BatStatus),
SetHealth(BatHealth),
@@ -810,10 +786,10 @@ impl BatControlCommand {
/// Used for VM to control battery properties.
pub struct BatControl {
pub type_: BatteryType,
- pub control_socket: BatControlRequestSocket,
+ pub control_tube: Tube,
}
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum FsMappingRequest {
/// Create an anonymous memory mapping that spans the entire region described by `Alloc`.
AllocateSharedMemoryRegion(Alloc),
@@ -823,7 +799,7 @@ pub enum FsMappingRequest {
/// `AllocateSharedMemoryRegion` request.
slot: u32,
/// The file descriptor that should be mapped.
- fd: MaybeOwnedDescriptor,
+ fd: SafeDescriptor,
/// The size of the mapping.
size: usize,
/// The offset into the file from where the mapping should start.
@@ -918,37 +894,10 @@ impl FsMappingRequest {
}
}
}
-
-pub type BalloonControlRequestSocket = MsgSocket<BalloonControlCommand, BalloonControlResult>;
-pub type BalloonControlResponseSocket = MsgSocket<BalloonControlResult, BalloonControlCommand>;
-
-pub type BatControlRequestSocket = MsgSocket<BatControlCommand, BatControlResult>;
-pub type BatControlResponseSocket = MsgSocket<BatControlResult, BatControlCommand>;
-
-pub type DiskControlRequestSocket = MsgSocket<DiskControlCommand, DiskControlResult>;
-pub type DiskControlResponseSocket = MsgSocket<DiskControlResult, DiskControlCommand>;
-
-pub type FsMappingRequestSocket = MsgSocket<FsMappingRequest, VmResponse>;
-pub type FsMappingResponseSocket = MsgSocket<VmResponse, FsMappingRequest>;
-
-pub type UsbControlSocket = MsgSocket<UsbControlCommand, UsbControlResult>;
-
-pub type VmMemoryControlRequestSocket = MsgSocket<VmMemoryRequest, VmMemoryResponse>;
-pub type VmMemoryControlResponseSocket = MsgSocket<VmMemoryResponse, VmMemoryRequest>;
-
-pub type VmIrqRequestSocket = MsgSocket<VmIrqRequest, VmIrqResponse>;
-pub type VmIrqResponseSocket = MsgSocket<VmIrqResponse, VmIrqRequest>;
-
-pub type VmMsyncRequestSocket = MsgSocket<VmMsyncRequest, VmMsyncResponse>;
-pub type VmMsyncResponseSocket = MsgSocket<VmMsyncResponse, VmMsyncRequest>;
-
-pub type VmControlRequestSocket = MsgSocket<VmRequest, VmResponse>;
-pub type VmControlResponseSocket = MsgSocket<VmResponse, VmRequest>;
-
/// A request to the main process to perform some operation on the VM.
///
/// Unless otherwise noted, each request should expect a `VmResponse::Ok` to be received on success.
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmRequest {
/// Break the VM's run loop and exit.
Exit,
@@ -1005,7 +954,7 @@ fn register_memory(
Ok((addr >> 12, slot))
}
-fn register_memory_hva(
+fn register_host_pointer(
vm: &mut impl Vm,
allocator: &mut SystemAllocator,
mem: Box<dyn MappedRegion>,
@@ -1029,9 +978,9 @@ impl VmRequest {
pub fn execute(
&self,
run_mode: &mut Option<VmRunMode>,
- balloon_host_socket: &BalloonControlRequestSocket,
- disk_host_sockets: &[DiskControlRequestSocket],
- usb_control_socket: &UsbControlSocket,
+ balloon_host_tube: &Tube,
+ disk_host_tubes: &[Tube],
+ usb_control_tube: &Tube,
bat_control: &mut Option<BatControl>,
) -> VmResponse {
match *self {
@@ -1048,14 +997,14 @@ impl VmRequest {
VmResponse::Ok
}
VmRequest::BalloonCommand(BalloonControlCommand::Adjust { num_bytes }) => {
- match balloon_host_socket.send(&BalloonControlCommand::Adjust { num_bytes }) {
+ match balloon_host_tube.send(&BalloonControlCommand::Adjust { num_bytes }) {
Ok(_) => VmResponse::Ok,
Err(_) => VmResponse::Err(SysError::last()),
}
}
VmRequest::BalloonCommand(BalloonControlCommand::Stats) => {
- match balloon_host_socket.send(&BalloonControlCommand::Stats {}) {
- Ok(_) => match balloon_host_socket.recv() {
+ match balloon_host_tube.send(&BalloonControlCommand::Stats {}) {
+ Ok(_) => match balloon_host_tube.recv() {
Ok(BalloonControlResult::Stats {
stats,
balloon_actual,
@@ -1076,7 +1025,7 @@ impl VmRequest {
ref command,
} => {
// Forward the request to the block device process via its control socket.
- if let Some(sock) = disk_host_sockets.get(disk_index) {
+ if let Some(sock) = disk_host_tubes.get(disk_index) {
if let Err(e) = sock.send(command) {
error!("disk socket send failed: {}", e);
VmResponse::Err(SysError::new(EINVAL))
@@ -1095,12 +1044,12 @@ impl VmRequest {
}
}
VmRequest::UsbCommand(ref cmd) => {
- let res = usb_control_socket.send(cmd);
+ let res = usb_control_tube.send(cmd);
if let Err(e) = res {
error!("fail to send command to usb control socket: {}", e);
return VmResponse::Err(SysError::new(EIO));
}
- match usb_control_socket.recv() {
+ match usb_control_tube.recv() {
Ok(response) => VmResponse::UsbResponse(response),
Err(e) => {
error!("fail to recv command from usb control socket: {}", e);
@@ -1116,13 +1065,13 @@ impl VmRequest {
return VmResponse::Err(SysError::new(EINVAL));
}
- let res = battery.control_socket.send(cmd);
+ let res = battery.control_tube.send(cmd);
if let Err(e) = res {
error!("fail to send command to bat control socket: {}", e);
return VmResponse::Err(SysError::new(EIO));
}
- match battery.control_socket.recv() {
+ match battery.control_tube.recv() {
Ok(response) => VmResponse::BatResponse(response),
Err(e) => {
error!("fail to recv command from bat control socket: {}", e);
@@ -1140,7 +1089,7 @@ impl VmRequest {
/// Indication of success or failure of a `VmRequest`.
///
/// Success is usually indicated `VmResponse::Ok` unless there is data associated with the response.
-#[derive(MsgOnSocket, Debug)]
+#[derive(Serialize, Deserialize, Debug)]
pub enum VmResponse {
/// Indicates the request was executed successfully.
Ok,
@@ -1152,7 +1101,7 @@ pub enum VmResponse {
/// The request to allocate and register GPU memory into guest address space was successfully
/// done at page frame number `pfn` and memory slot number `slot` for buffer with `desc`.
AllocateAndRegisterGpuMemory {
- descriptor: MaybeOwnedDescriptor,
+ descriptor: SafeDescriptor,
pfn: u64,
slot: u32,
desc: GpuMemoryDesc,
@@ -1185,7 +1134,7 @@ impl Display for VmResponse {
"gpu memory allocated and registered to page frame number {:#x} and memory slot {}",
pfn, slot
),
- BalloonStats {
+ VmResponse::BalloonStats {
stats,
balloon_actual,
} => write!(
diff --git a/vm_memory/Cargo.toml b/vm_memory/Cargo.toml
index df599aacc..631ad9b84 100644
--- a/vm_memory/Cargo.toml
+++ b/vm_memory/Cargo.toml
@@ -10,6 +10,6 @@ cros_async = { path = "../cros_async" } # provided by ebuild
data_model = { path = "../data_model" } # provided by ebuild
libc = "*"
base = { path = "../base" } # provided by ebuild
-syscall_defines = { path = "../syscall_defines" } # provided by ebuild
+bitflags = "1"
[workspace]
diff --git a/vm_memory/src/guest_memory.rs b/vm_memory/src/guest_memory.rs
index b160968b8..62ba134e7 100644
--- a/vm_memory/src/guest_memory.rs
+++ b/vm_memory/src/guest_memory.rs
@@ -14,17 +14,21 @@ use std::sync::Arc;
use crate::guest_address::GuestAddress;
use base::{pagesize, Error as SysError};
use base::{
- AsRawDescriptor, MappedRegion, MemfdSeals, MemoryMapping, MemoryMappingBuilder,
- MemoryMappingUnix, MmapError, RawDescriptor, SharedMemory, SharedMemoryUnix,
+ AsRawDescriptor, AsRawDescriptors, MappedRegion, MemfdSeals, MemoryMapping,
+ MemoryMappingBuilder, MemoryMappingUnix, MmapError, RawDescriptor, SharedMemory,
+ SharedMemoryUnix,
};
use cros_async::{mem, BackingMemory};
use data_model::volatile_memory::*;
use data_model::DataInit;
+use bitflags::bitflags;
+
#[derive(Debug)]
pub enum Error {
DescriptorChainOverflow,
InvalidGuestAddress(GuestAddress),
+ InvalidOffset(u64),
MemoryAccess(GuestAddress, MmapError),
MemoryMappingFailed(MmapError),
MemoryRegionOverlap,
@@ -51,6 +55,7 @@ impl Display for Error {
"the combined length of all the buffers in a DescriptorChain is too large"
),
InvalidGuestAddress(addr) => write!(f, "invalid guest address {}", addr),
+ InvalidOffset(addr) => write!(f, "invalid offset {}", addr),
MemoryAccess(addr, e) => {
write!(f, "invalid guest memory access at addr={}: {}", addr, e)
}
@@ -82,10 +87,17 @@ impl Display for Error {
}
}
+bitflags! {
+ pub struct MemoryPolicy: u32 {
+ const USE_HUGEPAGES = 1;
+ }
+}
+
struct MemoryRegion {
mapping: MemoryMapping,
guest_base: GuestAddress,
- memfd_offset: u64,
+ shm_offset: u64,
+ shm: Arc<SharedMemory>,
}
impl MemoryRegion {
@@ -108,24 +120,20 @@ impl MemoryRegion {
#[derive(Clone)]
pub struct GuestMemory {
regions: Arc<[MemoryRegion]>,
- shm: Arc<SharedMemory>,
-}
-
-impl AsRawDescriptor for GuestMemory {
- fn as_raw_descriptor(&self) -> RawDescriptor {
- self.shm.as_raw_descriptor()
- }
}
-impl AsRef<SharedMemory> for GuestMemory {
- fn as_ref(&self) -> &SharedMemory {
- &self.shm
+impl AsRawDescriptors for GuestMemory {
+ fn as_raw_descriptors(&self) -> Vec<RawDescriptor> {
+ self.regions
+ .iter()
+ .map(|r| r.shm.as_raw_descriptor())
+ .collect()
}
}
impl GuestMemory {
/// Creates backing shm for GuestMemory regions
- fn create_memfd(ranges: &[(GuestAddress, u64)]) -> Result<SharedMemory> {
+ fn create_shm(ranges: &[(GuestAddress, u64)]) -> Result<SharedMemory> {
let mut aligned_size = 0;
let pg_size = pagesize();
for range in ranges {
@@ -154,7 +162,7 @@ impl GuestMemory {
pub fn new(ranges: &[(GuestAddress, u64)]) -> Result<GuestMemory> {
// Create shm
- let shm = GuestMemory::create_memfd(ranges)?;
+ let shm = Arc::new(GuestMemory::create_shm(ranges)?);
// Create memory regions
let mut regions = Vec::<MemoryRegion>::new();
let mut offset = 0;
@@ -173,14 +181,15 @@ impl GuestMemory {
let size =
usize::try_from(range.1).map_err(|_| Error::MemoryRegionTooLarge(range.1))?;
let mapping = MemoryMappingBuilder::new(size)
- .from_descriptor(&shm)
+ .from_shared_memory(shm.as_ref())
.offset(offset)
.build()
.map_err(Error::MemoryMappingFailed)?;
regions.push(MemoryRegion {
mapping,
guest_base: range.0,
- memfd_offset: offset,
+ shm_offset: offset,
+ shm: Arc::clone(&shm),
});
offset += size as u64;
@@ -188,7 +197,6 @@ impl GuestMemory {
Ok(GuestMemory {
regions: Arc::from(regions),
- shm: Arc::new(shm),
})
}
@@ -252,13 +260,27 @@ impl GuestMemory {
/// Madvise away the address range in the host that is associated with the given guest range.
pub fn remove_range(&self, addr: GuestAddress, count: u64) -> Result<()> {
- self.do_in_region(addr, move |mapping, offset| {
+ self.do_in_region(addr, move |mapping, offset, _| {
mapping
.remove_range(offset, count as usize)
.map_err(|e| Error::MemoryAccess(addr, e))
})
}
+ /// Handles guest memory policy hints/advice (e.g. backing regions with hugepages).
+ pub fn set_memory_policy(&self, mem_policy: MemoryPolicy) {
+ if mem_policy.contains(MemoryPolicy::USE_HUGEPAGES) {
+ for region in self.regions.iter() {
+ if let Err(err) = region.mapping.use_hugepages() {
+ println!("Failed to enable HUGEPAGE for mapping: {}", err);
+ }
+ }
+ }
+ }
+
/// Perform the specified action on each region's addresses.
///
/// Callback is called with arguments:
@@ -266,10 +288,11 @@ impl GuestMemory {
/// * guest_addr : GuestAddress
/// * size: usize
/// * host_addr: usize
- /// * memfd_offset: usize
+ /// * shm: SharedMemory backing for the given region
+ /// * shm_offset: usize
pub fn with_regions<F, E>(&self, mut cb: F) -> result::Result<(), E>
where
- F: FnMut(usize, GuestAddress, usize, usize, u64) -> result::Result<(), E>,
+ F: FnMut(usize, GuestAddress, usize, usize, &SharedMemory, u64) -> result::Result<(), E>,
{
for (index, region) in self.regions.iter().enumerate() {
cb(
@@ -277,7 +300,8 @@ impl GuestMemory {
region.start(),
region.mapping.size(),
region.mapping.as_ptr() as usize,
- region.memfd_offset,
+ region.shm.as_ref(),
+ region.shm_offset,
)?;
}
Ok(())
@@ -303,7 +327,7 @@ impl GuestMemory {
/// # }
/// ```
pub fn write_at_addr(&self, buf: &[u8], guest_addr: GuestAddress) -> Result<usize> {
- self.do_in_region(guest_addr, move |mapping, offset| {
+ self.do_in_region(guest_addr, move |mapping, offset, _| {
mapping
.write_slice(buf, offset)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -362,7 +386,7 @@ impl GuestMemory {
/// # }
/// ```
pub fn read_at_addr(&self, buf: &mut [u8], guest_addr: GuestAddress) -> Result<usize> {
- self.do_in_region(guest_addr, move |mapping, offset| {
+ self.do_in_region(guest_addr, move |mapping, offset, _| {
mapping
.read_slice(buf, offset)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -422,7 +446,7 @@ impl GuestMemory {
/// # }
/// ```
pub fn read_obj_from_addr<T: DataInit>(&self, guest_addr: GuestAddress) -> Result<T> {
- self.do_in_region(guest_addr, |mapping, offset| {
+ self.do_in_region(guest_addr, |mapping, offset, _| {
mapping
.read_obj(offset)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -446,7 +470,7 @@ impl GuestMemory {
/// # }
/// ```
pub fn write_obj_at_addr<T: DataInit>(&self, val: T, guest_addr: GuestAddress) -> Result<()> {
- self.do_in_region(guest_addr, move |mapping, offset| {
+ self.do_in_region(guest_addr, move |mapping, offset, _| {
mapping
.write_obj(val, offset)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -543,7 +567,7 @@ impl GuestMemory {
src: &dyn AsRawDescriptor,
count: usize,
) -> Result<()> {
- self.do_in_region(guest_addr, move |mapping, offset| {
+ self.do_in_region(guest_addr, move |mapping, offset, _| {
mapping
.read_to_memory(offset, src, count)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -581,7 +605,7 @@ impl GuestMemory {
dst: &dyn AsRawDescriptor,
count: usize,
) -> Result<()> {
- self.do_in_region(guest_addr, move |mapping, offset| {
+ self.do_in_region(guest_addr, move |mapping, offset, _| {
mapping
.write_from_memory(offset, dst, count)
.map_err(|e| Error::MemoryAccess(guest_addr, e))
@@ -609,16 +633,42 @@ impl GuestMemory {
/// # }
/// ```
pub fn get_host_address(&self, guest_addr: GuestAddress) -> Result<*const u8> {
- self.do_in_region(guest_addr, |mapping, offset| {
+ self.do_in_region(guest_addr, |mapping, offset, _| {
// This is safe; `do_in_region` already checks that offset is in
// bounds.
Ok(unsafe { mapping.as_ptr().add(offset) } as *const u8)
})
}
+ /// Returns a reference to the SharedMemory region that backs the given address.
+ pub fn shm_region(&self, guest_addr: GuestAddress) -> Result<&SharedMemory> {
+ self.regions
+ .iter()
+ .find(|region| region.contains(guest_addr))
+ .ok_or(Error::InvalidGuestAddress(guest_addr))
+ .map(|region| region.shm.as_ref())
+ }
+
+ /// Returns the region that contains the memory at `offset` from the base of guest memory.
+ pub fn offset_region(&self, offset: u64) -> Result<&SharedMemory> {
+ self.shm_region(
+ self.checked_offset(self.regions[0].guest_base, offset)
+ .ok_or(Error::InvalidOffset(offset))?,
+ )
+ }
+
+ /// Loops over all guest memory regions of `self`, and performs the callback function `F` in
+ /// the target region that contains `guest_addr`. The callback function `F` takes in:
+ ///
+ /// (i) the memory mapping associated with the target region.
+ /// (ii) the relative offset from the start of the target region to `guest_addr`.
+ /// (iii) the offset of the target region within its backing shared memory (`shm_offset`).
+ ///
+ /// If no target region is found, an error is returned. The callback function `F` may return
+ /// an Ok(`T`) on success or a `GuestMemoryError` on failure.
pub fn do_in_region<F, T>(&self, guest_addr: GuestAddress, cb: F) -> Result<T>
where
- F: FnOnce(&MemoryMapping, usize) -> Result<T>,
+ F: FnOnce(&MemoryMapping, usize, u64) -> Result<T>,
{
self.regions
.iter()
@@ -628,11 +678,12 @@ impl GuestMemory {
cb(
&region.mapping,
guest_addr.offset_from(region.start()) as usize,
+ region.shm_offset,
)
})
}
- /// Convert a GuestAddress into an offset within self.shm.
+ /// Convert a GuestAddress into an offset within the associated shm region.
///
/// Due to potential gaps within GuestMemory, it is helpful to know the
/// offset within the shm where a given address is found. This offset
@@ -660,7 +711,7 @@ impl GuestMemory {
.iter()
.find(|region| region.contains(guest_addr))
.ok_or(Error::InvalidGuestAddress(guest_addr))
- .map(|region| region.memfd_offset + guest_addr.offset_from(region.start()))
+ .map(|region| region.shm_offset + guest_addr.offset_from(region.start()))
}
}
@@ -800,7 +851,7 @@ mod tests {
// Get the base address of the mapping for a GuestAddress.
fn get_mapping(mem: &GuestMemory, addr: GuestAddress) -> Result<*const u8> {
- mem.do_in_region(addr, |mapping, _| Ok(mapping.as_ptr() as *const u8))
+ mem.do_in_region(addr, |mapping, _, _| Ok(mapping.as_ptr() as *const u8))
}
#[test]
@@ -823,7 +874,7 @@ mod tests {
}
#[test]
- fn memfd_offset() {
+ fn shm_offset() {
if !kernel_has_memfd() {
return;
}
@@ -839,10 +890,10 @@ mod tests {
gm.write_obj_at_addr(0x0420u16, GuestAddress(0x10000))
.unwrap();
- let _ = gm.with_regions::<_, ()>(|index, _, size, _, memfd_offset| {
+ let _ = gm.with_regions::<_, ()>(|index, _, size, _, shm, shm_offset| {
let mmap = MemoryMappingBuilder::new(size)
- .from_descriptor(gm.as_ref())
- .offset(memfd_offset)
+ .from_shared_memory(shm)
+ .offset(shm_offset)
.build()
.unwrap();
diff --git a/x86_64/Android.bp b/x86_64/Android.bp
index e7dc658ed..8adb07a85 100644
--- a/x86_64/Android.bp
+++ b/x86_64/Android.bp
@@ -1,5 +1,4 @@
-// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace --features=gdb.
-// NOTE: The --features=gdb should be applied only to the host (not the device) and there are inline changes to achieve this
+// This file is generated by cargo2android.py --run --device --tests --dependencies --global_defaults=crosvm_defaults --add_workspace.
package {
// See: http://go/android-license-faq
@@ -18,19 +17,6 @@ rust_library {
crate_name: "x86_64",
srcs: ["src/lib.rs"],
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- "msg_socket",
- ],
- rustlibs: [
- "libgdbstub",
- "libmsg_socket",
- ],
- },
- },
rustlibs: [
"libacpi_tables",
"libarch",
@@ -68,18 +54,6 @@ rust_defaults {
test_suites: ["general-tests"],
auto_gen_config: true,
edition: "2018",
- target: {
- linux_glibc_x86_64: {
- features: [
- "gdb",
- "gdbstub",
- "msg_socket",
- ],
- rustlibs: [
- "libgdbstub",
- ],
- },
- },
rustlibs: [
"libacpi_tables",
"libarch",
@@ -133,7 +107,7 @@ rust_test {
// ../../vm_tools/p9/src/lib.rs
// ../../vm_tools/p9/wire_format_derive/wire_format_derive.rs
// ../acpi_tables/src/lib.rs
-// ../arch/src/lib.rs "gdb,gdbstub"
+// ../arch/src/lib.rs
// ../assertions/src/lib.rs
// ../base/src/lib.rs
// ../bit_field/bit_field_derive/bit_field_derive.rs
@@ -172,11 +146,10 @@ rust_test {
// ../vm_control/src/lib.rs
// ../vm_memory/src/lib.rs
// async-task-4.0.3 "default,std"
-// async-trait-0.1.48
+// async-trait-0.1.45
// autocfg-1.0.1
// base-0.1.0
// bitflags-1.2.1 "default"
-// cfg-if-0.1.10
// cfg-if-1.0.0
// downcast-rs-1.2.0 "default,std"
// futures-0.3.13 "alloc,async-await,default,executor,futures-executor,std"
@@ -188,15 +161,12 @@ rust_test {
// futures-sink-0.3.13 "alloc,std"
// futures-task-0.3.13 "alloc,std"
// futures-util-0.3.13 "alloc,async-await,async-await-macro,channel,futures-channel,futures-io,futures-macro,futures-sink,io,memchr,proc-macro-hack,proc-macro-nested,sink,slab,std"
-// gdbstub-0.4.4 "alloc,default,std"
// getrandom-0.2.2 "std"
// intrusive-collections-0.9.0 "alloc,default"
-// libc-0.2.88 "default,std"
+// libc-0.2.87 "default,std"
// log-0.4.14
-// managed-0.8.0 "alloc"
// memchr-2.3.4 "default,std"
// memoffset-0.5.6 "default"
-// num-traits-0.2.14
// paste-1.0.4
// pin-project-lite-0.2.6
// pin-utils-0.1.0
@@ -212,10 +182,10 @@ rust_test {
// rand_core-0.6.2 "alloc,getrandom,std"
// remain-0.2.2
// remove_dir_all-0.5.3
-// serde-1.0.124 "default,derive,serde_derive,std"
-// serde_derive-1.0.124 "default"
+// serde-1.0.123 "default,derive,serde_derive,std"
+// serde_derive-1.0.123 "default"
// slab-0.4.2
-// syn-1.0.63 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
+// syn-1.0.61 "clone-impls,default,derive,full,parsing,printing,proc-macro,quote,visit-mut"
// tempfile-3.2.0
// thiserror-1.0.24
// thiserror-impl-1.0.24
diff --git a/x86_64/Cargo.toml b/x86_64/Cargo.toml
index fcfb0f15f..07fac9dc9 100644
--- a/x86_64/Cargo.toml
+++ b/x86_64/Cargo.toml
@@ -5,7 +5,7 @@ authors = ["The Chromium OS Authors"]
edition = "2018"
[features]
-gdb = ["gdbstub", "msg_socket", "arch/gdb"]
+gdb = ["gdbstub", "arch/gdb"]
[dependencies]
arch = { path = "../arch" }
@@ -18,7 +18,6 @@ kernel_cmdline = { path = "../kernel_cmdline" }
kernel_loader = { path = "../kernel_loader" }
libc = "*"
minijail = { path = "../../minijail/rust/minijail" } # ignored by ebuild
-msg_socket = { path = "../msg_socket", optional = true }
remain = "*"
resources = { path = "../resources" }
sync = { path = "../sync" }
@@ -26,6 +25,3 @@ base = { path = "../base" }
acpi_tables = {path = "../acpi_tables" }
vm_control = { path = "../vm_control" }
vm_memory = { path = "../vm_memory" }
-
-[dev-dependencies]
-msg_socket = { path = "../msg_socket"}
diff --git a/x86_64/src/cpuid.rs b/x86_64/src/cpuid.rs
index d965a2f51..6b31d6012 100644
--- a/x86_64/src/cpuid.rs
+++ b/x86_64/src/cpuid.rs
@@ -50,7 +50,7 @@ fn filter_cpuid(
cpuid: &mut hypervisor::CpuId,
irq_chip: &dyn IrqChipX86_64,
no_smt: bool,
-) -> Result<()> {
+) {
let entries = &mut cpuid.cpu_id_entries;
for entry in entries {
@@ -142,8 +142,6 @@ fn filter_cpuid(
_ => (),
}
}
-
- Ok(())
}
/// Sets up the cpuid entries for the given vcpu. Can fail if there are too many CPUs specified or
@@ -167,7 +165,7 @@ pub fn setup_cpuid(
.get_supported_cpuid()
.map_err(Error::GetSupportedCpusFailed)?;
- filter_cpuid(vcpu_id, nrcpus, &mut cpuid, irq_chip, no_smt)?;
+ filter_cpuid(vcpu_id, nrcpus, &mut cpuid, irq_chip, no_smt);
vcpu.set_cpuid(&cpuid)
.map_err(Error::SetSupportedCpusFailed)
@@ -211,7 +209,7 @@ mod tests {
edx: 0,
..Default::default()
});
- assert_eq!(Ok(()), filter_cpuid(1, 2, &mut cpuid, &irq_chip, false));
+ filter_cpuid(1, 2, &mut cpuid, &irq_chip, false);
let entries = &mut cpuid.cpu_id_entries;
assert_eq!(entries[0].function, 0);
diff --git a/x86_64/src/lib.rs b/x86_64/src/lib.rs
index da1052d78..2de945ae8 100644
--- a/x86_64/src/lib.rs
+++ b/x86_64/src/lib.rs
@@ -349,13 +349,23 @@ fn arch_memory_regions(size: u64, bios_size: Option<u64>) -> Vec<(GuestAddress,
impl arch::LinuxArch for X8664arch {
type Error = Error;
- fn build_vm<V, Vcpu, I, FD, FV, FI, E1, E2, E3>(
+ fn guest_memory_layout(
+ components: &VmComponents,
+ ) -> std::result::Result<Vec<(GuestAddress, u64)>, Self::Error> {
+ let bios_size = match &components.vm_image {
+ VmImage::Bios(bios_file) => Some(bios_file.metadata().map_err(Error::LoadBios)?.len()),
+ VmImage::Kernel(_) => None,
+ };
+ Ok(arch_memory_regions(components.memory_size, bios_size))
+ }
+
+ fn build_vm<V, Vcpu, I, FD, FI, E1, E2>(
mut components: VmComponents,
serial_parameters: &BTreeMap<(SerialHardware, u8), SerialParameters>,
serial_jail: Option<Minijail>,
battery: (&Option<BatteryType>, Option<Minijail>),
+ mut vm: V,
create_devices: FD,
- create_vm: FV,
create_irq_chip: FI,
) -> std::result::Result<RunnableLinuxVm<V, Vcpu, I>, Self::Error>
where
@@ -368,28 +378,18 @@ impl arch::LinuxArch for X8664arch {
&mut SystemAllocator,
&Event,
) -> std::result::Result<Vec<(Box<dyn PciDevice>, Option<Minijail>)>, E1>,
- FV: FnOnce(GuestMemory) -> std::result::Result<V, E2>,
- FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E3>,
+ FI: FnOnce(&V, /* vcpu_count: */ usize) -> std::result::Result<I, E2>,
E1: StdError + 'static,
E2: StdError + 'static,
- E3: StdError + 'static,
{
if components.protected_vm != ProtectionType::Unprotected {
return Err(Error::UnsupportedProtectionType);
}
- let bios_size = match components.vm_image {
- VmImage::Bios(ref mut bios_file) => {
- Some(bios_file.metadata().map_err(Error::LoadBios)?.len())
- }
- VmImage::Kernel(_) => None,
- };
- let has_bios = bios_size.is_some();
- let mem = Self::setup_memory(components.memory_size, bios_size)?;
+ let mem = vm.get_memory().clone();
let mut resources = Self::get_resource_allocator(&mem);
let vcpu_count = components.vcpu_count;
- let mut vm = create_vm(mem.clone()).map_err(|e| Error::CreateVm(Box::new(e)))?;
let mut irq_chip =
create_irq_chip(&vm, vcpu_count).map_err(|e| Error::CreateIrqChip(Box::new(e)))?;
@@ -464,7 +464,8 @@ impl arch::LinuxArch for X8664arch {
// Note that this puts the mptable at 0x9FC00 in guest physical memory.
mptable::setup_mptable(&mem, vcpu_count as u8, pci_irqs).map_err(Error::SetupMptable)?;
- smbios::setup_smbios(&mem).map_err(Error::SetupSmbios)?;
+ smbios::setup_smbios(&mem, components.dmi_path).map_err(Error::SetupSmbios)?;
+
// TODO (tjeznach) Write RSDP to bootconfig before writing to memory
acpi::create_acpi_tables(&mem, vcpu_count as u8, X86_64_SCI_IRQ, acpi_dev_resource);
@@ -523,7 +524,7 @@ impl arch::LinuxArch for X8664arch {
vcpu_affinity: components.vcpu_affinity,
no_smt: components.no_smt,
irq_chip,
- has_bios,
+ has_bios: matches!(components.vm_image, VmImage::Bios(_)),
io_bus,
mmio_bus,
pid_debug_label_map,
@@ -929,15 +930,6 @@ impl X8664arch {
Ok(())
}
- /// This creates a GuestMemory object for this VM
- ///
- /// * `mem_size` - Desired physical memory size in bytes for this VM
- fn setup_memory(mem_size: u64, bios_size: Option<u64>) -> Result<GuestMemory> {
- let arch_mem_regions = arch_memory_regions(mem_size, bios_size);
- let mem = GuestMemory::new(&arch_mem_regions).map_err(Error::SetupGuestMemory)?;
- Ok(mem)
- }
-
/// This returns the start address of high mmio
///
/// # Arguments
@@ -1092,7 +1084,7 @@ impl X8664arch {
let bat_control = if let Some(battery_type) = battery.0 {
match battery_type {
BatteryType::Goldfish => {
- let control_socket = arch::add_goldfish_battery(
+ let control_tube = arch::add_goldfish_battery(
&mut amls,
battery.1,
mmio_bus,
@@ -1103,7 +1095,7 @@ impl X8664arch {
.map_err(Error::CreateBatDevices)?;
Some(BatControl {
type_: BatteryType::Goldfish,
- control_socket,
+ control_tube,
})
}
}
diff --git a/x86_64/src/smbios.rs b/x86_64/src/smbios.rs
index 1d6622ed0..6fd8774e1 100644
--- a/x86_64/src/smbios.rs
+++ b/x86_64/src/smbios.rs
@@ -7,6 +7,10 @@ use std::mem;
use std::result;
use std::slice;
+use std::fs::OpenOptions;
+use std::io::prelude::*;
+use std::path::{Path, PathBuf};
+
use data_model::DataInit;
use vm_memory::{GuestAddress, GuestMemory};
@@ -22,6 +26,12 @@ pub enum Error {
WriteSmbiosEp,
/// Failure to write additional data to memory
WriteData,
+ /// Failure while reading SMBIOS data file
+ IoFailed,
+ /// Incorrect or not readable host SMBIOS data
+ InvalidInput,
+ /// Invalid table entry point checksum
+ InvalidChecksum,
}
impl std::error::Error for Error {}
@@ -36,6 +46,9 @@ impl Display for Error {
Clear => "Failure while zeroing out the memory for the SMBIOS table",
WriteSmbiosEp => "Failure to write SMBIOS entrypoint structure",
WriteData => "Failure to write additional data to memory",
+ IoFailed => "Failure while reading SMBIOS data file",
+ InvalidInput => "Failure to read host SMBIOS data",
+ InvalidChecksum => "Failure to verify host SMBIOS entry checksum",
};
write!(f, "SMBIOS error: {}", description)
@@ -46,10 +59,14 @@ pub type Result<T> = result::Result<T, Error>;
const SMBIOS_START: u64 = 0xf0000; // First possible location per the spec.
+// Constants sourced from SMBIOS Spec 2.3.1.
+const SM2_MAGIC_IDENT: &[u8; 4usize] = b"_SM_";
+
// Constants sourced from SMBIOS Spec 3.2.0.
const SM3_MAGIC_IDENT: &[u8; 5usize] = b"_SM3_";
const BIOS_INFORMATION: u8 = 0;
const SYSTEM_INFORMATION: u8 = 1;
+const END_OF_TABLE: u8 = 127;
const PCI_SUPPORTED: u64 = 1 << 7;
const IS_VIRTUAL_MACHINE: u8 = 1 << 4;
@@ -65,6 +82,47 @@ fn compute_checksum<T: Copy>(v: &T) -> u8 {
#[repr(packed)]
#[derive(Default, Copy)]
+pub struct Smbios23Intermediate {
+ pub signature: [u8; 5usize],
+ pub checksum: u8,
+ pub length: u16,
+ pub address: u32,
+ pub count: u16,
+ pub revision: u8,
+}
+
+unsafe impl data_model::DataInit for Smbios23Intermediate {}
+
+impl Clone for Smbios23Intermediate {
+ fn clone(&self) -> Self {
+ *self
+ }
+}
+
+#[repr(packed)]
+#[derive(Default, Copy)]
+pub struct Smbios23Entrypoint {
+ pub signature: [u8; 4usize],
+ pub checksum: u8,
+ pub length: u8,
+ pub majorver: u8,
+ pub minorver: u8,
+ pub max_size: u16,
+ pub revision: u8,
+ pub reserved: [u8; 5usize],
+ pub dmi: Smbios23Intermediate,
+}
+
+unsafe impl data_model::DataInit for Smbios23Entrypoint {}
+
+impl Clone for Smbios23Entrypoint {
+ fn clone(&self) -> Self {
+ *self
+ }
+}
+
+#[repr(packed)]
+#[derive(Default, Copy)]
pub struct Smbios30Entrypoint {
pub signature: [u8; 5usize],
pub checksum: u8,
@@ -154,7 +212,83 @@ fn write_string(mem: &GuestMemory, val: &str, mut curptr: GuestAddress) -> Resul
Ok(curptr)
}
-pub fn setup_smbios(mem: &GuestMemory) -> Result<()> {
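+/// Copies the host-provided SMBIOS entry point (`smbios_entry_point`) and `DMI` table found
+/// under `path` into guest memory, supporting both the 3.0 and 2.3 entry point formats.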
+fn setup_smbios_from_file(mem: &GuestMemory, path: &Path) -> Result<()> {
+ let mut sme_path = PathBuf::from(path);
+ sme_path.push("smbios_entry_point");
+ let mut sme = Vec::new();
+ OpenOptions::new()
+ .read(true)
+ .open(&sme_path)
+ .map_err(|_| Error::IoFailed)?
+ .read_to_end(&mut sme)
+ .map_err(|_| Error::IoFailed)?;
+
+ let mut dmi_path = PathBuf::from(path);
+ dmi_path.push("DMI");
+ let mut dmi = Vec::new();
+ OpenOptions::new()
+ .read(true)
+ .open(&dmi_path)
+ .map_err(|_| Error::IoFailed)?
+ .read_to_end(&mut dmi)
+ .map_err(|_| Error::IoFailed)?;
+
+ // Try SMBIOS 3.0 format.
+ if sme.len() == mem::size_of::<Smbios30Entrypoint>() && sme.starts_with(SM3_MAGIC_IDENT) {
+ let mut smbios_ep = Smbios30Entrypoint::default();
+ smbios_ep.as_mut_slice().copy_from_slice(&sme);
+
+ let physptr = GuestAddress(SMBIOS_START)
+ .checked_add(mem::size_of::<Smbios30Entrypoint>() as u64)
+ .ok_or(Error::NotEnoughMemory)?;
+
+ mem.write_at_addr(&dmi, physptr)
+ .map_err(|_| Error::NotEnoughMemory)?;
+
+ // Update EP DMI location
+ smbios_ep.physptr = physptr.offset();
+ smbios_ep.checksum = 0;
+ smbios_ep.checksum = compute_checksum(&smbios_ep);
+
+ mem.write_obj_at_addr(smbios_ep, GuestAddress(SMBIOS_START))
+ .map_err(|_| Error::WriteSmbiosEp)?;
+
+ return Ok(());
+ }
+
+ // Try SMBIOS 2.3 format.
+ if sme.len() == mem::size_of::<Smbios23Entrypoint>() && sme.starts_with(SM2_MAGIC_IDENT) {
+ let mut smbios_ep = Smbios23Entrypoint::default();
+ smbios_ep.as_mut_slice().copy_from_slice(&sme);
+
+ let physptr = GuestAddress(SMBIOS_START)
+ .checked_add(mem::size_of::<Smbios23Entrypoint>() as u64)
+ .ok_or(Error::NotEnoughMemory)?;
+
+ mem.write_at_addr(&dmi, physptr)
+ .map_err(|_| Error::NotEnoughMemory)?;
+
+ // Point the intermediate (DMI) structure at the relocated table and recompute both checksums.
+ smbios_ep.dmi.address = physptr.offset() as u32;
+ smbios_ep.dmi.checksum = 0;
+ smbios_ep.dmi.checksum = compute_checksum(&smbios_ep.dmi);
+ smbios_ep.checksum = 0;
+ smbios_ep.checksum = compute_checksum(&smbios_ep);
+
+ mem.write_obj_at_addr(smbios_ep, GuestAddress(SMBIOS_START))
+ .map_err(|_| Error::WriteSmbiosEp)?;
+
+ return Ok(());
+ }
+
+ Err(Error::InvalidInput)
+}
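+
+// For illustration only: the path is expected to name a directory containing
+// the two files read above ("smbios_entry_point" and "DMI"). On a Linux host,
+// /sys/firmware/dmi/tables exposes exactly that pair, so the host's own tables
+// could be forwarded with something like:
+//     setup_smbios(&guest_mem, Some(PathBuf::from("/sys/firmware/dmi/tables")))?;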
+
+pub fn setup_smbios(mem: &GuestMemory, dmi_path: Option<PathBuf>) -> Result<()> {
+ if let Some(dmi_path) = dmi_path {
+ return setup_smbios_from_file(mem, &dmi_path);
+ }
+
let physptr = GuestAddress(SMBIOS_START)
.checked_add(mem::size_of::<Smbios30Entrypoint>() as u64)
.ok_or(Error::NotEnoughMemory)?;
@@ -196,6 +330,18 @@ pub fn setup_smbios(mem: &GuestMemory) -> Result<()> {
}
{
+ handle += 1;
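+ // A type 127 (End-of-Table) structure marks the end of the SMBIOS table.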
+ let smbios_sysinfo = SmbiosSysInfo {
+ typ: END_OF_TABLE,
+ length: mem::size_of::<SmbiosSysInfo>() as u8,
+ handle,
+ ..Default::default()
+ };
+ curptr = write_and_incr(mem, smbios_sysinfo, curptr)?;
+ curptr = write_and_incr(mem, 0_u8, curptr)?;
+ }
+
+ {
let mut smbios_ep = Smbios30Entrypoint::default();
smbios_ep.signature = *SM3_MAGIC_IDENT;
smbios_ep.length = mem::size_of::<Smbios30Entrypoint>() as u8;
@@ -221,6 +367,11 @@ mod tests {
#[test]
fn struct_size() {
assert_eq!(
+ mem::size_of::<Smbios23Entrypoint>(),
+ 0x1fusize,
+ concat!("Size of: ", stringify!(Smbios23Entrypoint))
+ );
+ assert_eq!(
mem::size_of::<Smbios30Entrypoint>(),
0x18usize,
concat!("Size of: ", stringify!(Smbios30Entrypoint))
@@ -241,7 +392,8 @@ mod tests {
fn entrypoint_checksum() {
let mem = GuestMemory::new(&[(GuestAddress(SMBIOS_START), 4096)]).unwrap();
- setup_smbios(&mem).unwrap();
+ // No DMI path is given, so the default SMBIOS 3.0 tables are generated.
+ setup_smbios(&mem, None).unwrap();
let smbios_ep: Smbios30Entrypoint =
mem.read_obj_from_addr(GuestAddress(SMBIOS_START)).unwrap();
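The rest of entrypoint_checksum falls outside this hunk; a check of the following shape would confirm that the recomputed checksum makes the entry point sum to zero (sketch only; it assumes DataInit's as_slice as the read-only counterpart of the as_mut_slice used above):

    let sum: u8 = smbios_ep
        .as_slice()
        .iter()
        .fold(0u8, |acc, &b| acc.wrapping_add(b));
    assert_eq!(sum, 0);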
diff --git a/x86_64/src/test_integration.rs b/x86_64/src/test_integration.rs
index 6545c85a2..6ee54e1a3 100644
--- a/x86_64/src/test_integration.rs
+++ b/x86_64/src/test_integration.rs
@@ -12,13 +12,13 @@ use super::cpuid::setup_cpuid;
use super::interrupts::set_lint;
use super::regs::{setup_fpu, setup_msrs, setup_regs, setup_sregs};
use super::X8664arch;
-use super::{acpi, bootparam, mptable, smbios};
+use super::{acpi, arch_memory_regions, bootparam, mptable, smbios};
use super::{
BOOT_STACK_POINTER, END_ADDR_BEFORE_32BITS, KERNEL_64BIT_ENTRY_OFFSET, KERNEL_START_OFFSET,
X86_64_SCI_IRQ, ZERO_PAGE_OFFSET,
};
-use base::Event;
+use base::{Event, Tube};
use std::collections::BTreeMap;
use std::ffi::CString;
@@ -28,14 +28,9 @@ use sync::Mutex;
use devices::PciConfigIo;
-use vm_control::{
- DiskControlCommand, DiskControlResult, VmIrqRequest, VmIrqRequestSocket, VmIrqResponse,
- VmIrqResponseSocket, VmMemoryControlResponseSocket, VmMemoryRequest, VmMemoryResponse,
-};
-
-enum TaggedControlSocket {
- VmMemory(VmMemoryControlResponseSocket),
- VmIrq(VmIrqResponseSocket),
+enum TaggedControlTube {
+ VmMemory(Tube),
+ VmIrq(Tube),
}
#[test]
@@ -64,8 +59,8 @@ fn simple_kvm_split_irqchip_test() {
let vm = KvmVm::new(&kvm, guest_mem).expect("failed to create kvm vm");
(kvm, vm)
},
- |vm, vcpu_count, device_socket| {
- KvmSplitIrqChip::new(vm, vcpu_count, device_socket)
+ |vm, vcpu_count, device_tube| {
+ KvmSplitIrqChip::new(vm, vcpu_count, device_tube, None)
.expect("failed to create KvmSplitIrqChip")
},
);
@@ -82,7 +77,7 @@ where
Vcpu: VcpuX86_64 + 'static,
I: IrqChipX86_64 + 'static,
FV: FnOnce(GuestMemory) -> (H, V),
- FI: FnOnce(V, /* vcpu_count: */ usize, VmIrqRequestSocket) -> I,
+ FI: FnOnce(V, /* vcpu_count: */ usize, Tube) -> I,
{
/*
0x0000000000000000: 67 89 18 mov dword ptr [eax], ebx
@@ -100,38 +95,32 @@ where
let write_addr = GuestAddress(0x4000);
// guest mem is 400 pages
- let guest_mem = X8664arch::setup_memory(memory_size, None).unwrap();
- // let guest_mem = GuestMemory::new(&[(GuestAddress(0), memory_size)]).unwrap();
+ let arch_mem_regions = arch_memory_regions(memory_size, None);
+ let guest_mem = GuestMemory::new(&arch_mem_regions).unwrap();
+
let mut resources = X8664arch::get_resource_allocator(&guest_mem);
let (hyp, mut vm) = create_vm(guest_mem.clone());
- let (irqchip_socket, device_socket) =
- msg_socket::pair::<VmIrqResponse, VmIrqRequest>().expect("failed to create irq socket");
+ let (irqchip_tube, device_tube) = Tube::pair().expect("failed to create irq tube");
- let mut irq_chip = create_irq_chip(
- vm.try_clone().expect("failed to clone vm"),
- 1,
- device_socket,
- );
+ let mut irq_chip = create_irq_chip(vm.try_clone().expect("failed to clone vm"), 1, device_tube);
let mut mmio_bus = devices::Bus::new();
let exit_evt = Event::new().unwrap();
- let mut control_sockets = vec![TaggedControlSocket::VmIrq(irqchip_socket)];
+ let mut control_tubes = vec![TaggedControlTube::VmIrq(irqchip_tube)];
// Create one control tube per disk.
- let mut disk_device_sockets = Vec::new();
- let mut disk_host_sockets = Vec::new();
+ let mut disk_device_tubes = Vec::new();
+ let mut disk_host_tubes = Vec::new();
let disk_count = 0;
for _ in 0..disk_count {
- let (disk_host_socket, disk_device_socket) =
- msg_socket::pair::<DiskControlCommand, DiskControlResult>().unwrap();
- disk_host_sockets.push(disk_host_socket);
- disk_device_sockets.push(disk_device_socket);
+ let (disk_host_tube, disk_device_tube) = Tube::pair().unwrap();
+ disk_host_tubes.push(disk_host_tube);
+ disk_device_tubes.push(disk_device_tube);
}
- let (gpu_host_socket, _gpu_device_socket) =
- msg_socket::pair::<VmMemoryResponse, VmMemoryRequest>().unwrap();
+ let (gpu_host_tube, _gpu_device_tube) = Tube::pair().unwrap();
- control_sockets.push(TaggedControlSocket::VmMemory(gpu_host_socket));
+ control_tubes.push(TaggedControlTube::VmMemory(gpu_host_tube));
let devices = vec![];
@@ -212,7 +201,7 @@ where
// Note that this puts the mptable at 0x9FC00 in guest physical memory.
mptable::setup_mptable(&guest_mem, 1, pci_irqs).expect("failed to setup mptable");
- smbios::setup_smbios(&guest_mem).expect("failed to setup smbios");
+ smbios::setup_smbios(&guest_mem, None).expect("failed to setup smbios");
acpi::create_acpi_tables(&guest_mem, 1, X86_64_SCI_IRQ, acpi_dev_resource.0);
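The test changes above follow the crosvm-wide move from typed msg_socket pairs to base::Tube, where the payload type is chosen per call rather than fixed at construction. A minimal sketch of that pattern (the send/recv names and serde-based payloads are assumptions drawn from how the tubes are used here, not from code shown in this diff):

    let (host_tube, device_tube) = Tube::pair().expect("failed to create tube");
    // Any serde-serializable value can travel over the connected pair.
    host_tube.send(&0xf00du32).expect("send failed");
    let value: u32 = device_tube.recv().expect("recv failed");
    assert_eq!(value, 0xf00d);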