summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJorge E. Moreira <jemoreira@google.com>2021-04-15 18:16:05 +0000
committerAutomerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>2021-04-15 18:16:05 +0000
commit63e5d173880667d287719756463561e5493c6a5b (patch)
tree56be29b9f4f9f6308a7ce7adb8a77cdf115f1bc6
parent9dc6172cfed4ab25b52a45f3e74dc992b7fb485d (diff)
parent86868ea15c099f99ed1ddf83d870e1b025d9d53b (diff)
downloadvmm_vhost-63e5d173880667d287719756463561e5493c6a5b.tar.gz
Merge remote-tracking branch 'aosp/upstream-main' am: 86868ea15c
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/vmm_vhost/+/1676994 Change-Id: I1896033874d9f7edc287a4dbf3d3294251eed104
-rw-r--r--.cargo/config5
-rw-r--r--.gitignore6
-rw-r--r--.gitmodules3
-rw-r--r--CODEOWNERS2
-rw-r--r--Cargo.toml30
-rw-r--r--LICENSE202
-rw-r--r--LICENSE-BSD-3-Clause27
-rw-r--r--LICENSE-BSD-Chromium27
-rw-r--r--OWNERS3
-rw-r--r--PRESUBMIT.cfg6
-rw-r--r--README.md32
-rw-r--r--coverage_config_aarch64.json1
-rw-r--r--coverage_config_x86_64.json1
-rw-r--r--docs/vhost_architecture.drawio171
-rw-r--r--docs/vhost_architecture.pngbin0 -> 146074 bytes
m---------rust-vmm-ci0
-rw-r--r--src/backend.rs506
-rw-r--r--src/lib.rs162
-rw-r--r--src/vhost_kern/mod.rs283
-rw-r--r--src/vhost_kern/vhost_binding.rs406
-rw-r--r--src/vhost_kern/vsock.rs184
-rw-r--r--src/vhost_user/connection.rs858
-rw-r--r--src/vhost_user/dummy_slave.rs259
-rw-r--r--src/vhost_user/master.rs1071
-rw-r--r--src/vhost_user/master_req_handler.rs477
-rw-r--r--src/vhost_user/message.rs1042
-rw-r--r--src/vhost_user/mod.rs456
-rw-r--r--src/vhost_user/slave.rs86
-rw-r--r--src/vhost_user/slave_fs_cache.rs226
-rw-r--r--src/vhost_user/slave_req_handler.rs828
-rw-r--r--src/vsock.rs30
31 files changed, 7390 insertions, 0 deletions
diff --git a/.cargo/config b/.cargo/config
new file mode 100644
index 0000000..bf8523e
--- /dev/null
+++ b/.cargo/config
@@ -0,0 +1,5 @@
+# This workaround is needed because the linker is unable to find __addtf3,
+# __multf3 and __subtf3.
+# Related issue: https://github.com/rust-lang/compiler-builtins/issues/201
+[target.aarch64-unknown-linux-musl]
+rustflags = [ "-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..f738aa8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+/build
+/kcov_build
+/target
+.idea
+**/*.rs.bk
+Cargo.lock
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..bda97eb
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "rust-vmm-ci"]
+ path = rust-vmm-ci
+ url = https://github.com/rust-vmm/rust-vmm-ci.git
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000..7174a1b
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1,2 @@
+# Add the list of code owners here (using their GitHub username)
+* gatekeeper-PullAssigner @jiangliu @eryugey @sboeuf @slp
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..917ea25
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,30 @@
+[package]
+name = "vmm_vhost"
+version = "0.1.0"
+keywords = ["vhost", "vhost-user", "virtio", "vdpa"]
+description = "a pure rust library for vdpa, vhost and vhost-user"
+authors = ["Liu Jiang <gerry@linux.alibaba.com>"]
+repository = "https://github.com/rust-vmm/vhost"
+documentation = "https://docs.rs/vhost"
+readme = "README.md"
+license = "Apache-2.0 or BSD-3-Clause"
+edition = "2018"
+
+[features]
+default = []
+vhost-vsock = []
+vhost-kern = ["vm-memory"]
+vhost-user = []
+vhost-user-master = ["vhost-user"]
+vhost-user-slave = ["vhost-user"]
+
+[dependencies]
+bitflags = ">=1.0.1"
+libc = ">=0.2.39"
+
+sys_util = { path = "../../../platform/crosvm/sys_util" } # provided by ebuild
+tempfile = { path = "../../../platform/crosvm/tempfile" } # provided by ebuild
+vm-memory = { version = "0.2.0", optional = true }
+
+[dev-dependencies]
+vm-memory = { version = "0.2.0", features=["backend-mmap"] }
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..d645695
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/LICENSE-BSD-3-Clause b/LICENSE-BSD-3-Clause
new file mode 100644
index 0000000..1ff0cd7
--- /dev/null
+++ b/LICENSE-BSD-3-Clause
@@ -0,0 +1,27 @@
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Alibaba Inc. nor the names of its contributors
+// may be used to endorse or promote products derived from this software
+// without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/LICENSE-BSD-Chromium b/LICENSE-BSD-Chromium
new file mode 100644
index 0000000..8bafca3
--- /dev/null
+++ b/LICENSE-BSD-Chromium
@@ -0,0 +1,27 @@
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..c48e4ef
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1,3 @@
+jemoreira@google.com
+chirantan@google.com
+dgreid@google.com
diff --git a/PRESUBMIT.cfg b/PRESUBMIT.cfg
new file mode 100644
index 0000000..e52636e
--- /dev/null
+++ b/PRESUBMIT.cfg
@@ -0,0 +1,6 @@
+[Hook Overrides]
+cargo_clippy_check: true
+
+[Hook Overrides Options]
+cargo_clippy_check:
+ --project=.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..b0f4dfa
--- /dev/null
+++ b/README.md
@@ -0,0 +1,32 @@
+# vHost
+A pure rust library for vDPA, vhost and vhost-user.
+
+The `vhost` crate aims to help implementing dataplane for virtio backend drivers. It supports three different types of dataplane drivers:
+- vhost: the dataplane is implemented by linux kernel
+- vhost-user: the dataplane is implemented by dedicated vhost-user servers
+- vDPA(vhost DataPath Accelerator): the dataplane is implemented by hardwares
+
+The main relationship among Traits and Structs exported by the `vhost` crate is as below:
+
+![vhost Architecture](/docs/vhost_architecture.png)
+## Kernel-based vHost Backend Drivers
+The vhost drivers in Linux provide in-kernel virtio device emulation. Normally
+the hypervisor userspace process emulates I/O accesses from the guest.
+Vhost puts virtio emulation code into the kernel, taking hypervisor userspace
+out of the picture. This allows device emulation code to directly call into
+kernel subsystems instead of performing system calls from userspace.
+The hypervisor relies on ioctl based interfaces to control those in-kernel
+vhost drivers, such as vhost-net, vhost-scsi and vhost-vsock etc.
+
+## vHost-user Backend Drivers
+The [vhost-user protocol](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html#communication) aims to implement vhost backend drivers in
+userspace, which complements the ioctl interface used to control the vhost
+implementation in the Linux kernel. It implements the control plane needed
+to establish virtqueue sharing with a user space process on the same host.
+It uses communication over a Unix domain socket to share file descriptors in
+the ancillary data of the message.
+
+The protocol defines two sides of the communication, master and slave.
+Master is the application that shares its virtqueues, slave is the consumer
+of the virtqueues. Master and slave can be either a client (i.e. connecting)
+or server (listening) in the socket communication.
diff --git a/coverage_config_aarch64.json b/coverage_config_aarch64.json
new file mode 100644
index 0000000..67543fc
--- /dev/null
+++ b/coverage_config_aarch64.json
@@ -0,0 +1 @@
+{"coverage_score": 39.8, "exclude_path": "", "crate_features": "vhost-vsock,vhost-kern,vhost-user-master,vhost-user-slave"} \ No newline at end of file
diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json
new file mode 100644
index 0000000..2b2c164
--- /dev/null
+++ b/coverage_config_x86_64.json
@@ -0,0 +1 @@
+{"coverage_score": 81.2, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"}
diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio
new file mode 100644
index 0000000..5008d28
--- /dev/null
+++ b/docs/vhost_architecture.drawio
@@ -0,0 +1,171 @@
+<mxfile host="65bd71144e" modified="2021-02-22T05:37:26.833Z" agent="5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.53.0 Chrome/87.0.4280.141 Electron/11.2.1 Safari/537.36" etag="HWRXqybJYJqQhnlJWfmB" version="14.2.4" type="embed">
+ <diagram id="xCgrIAQPDQM0eynUYBOE" name="Page-1">
+ <mxGraphModel dx="3446" dy="1284" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
+ <root>
+ <mxCell id="0"/>
+ <mxCell id="1" parent="0"/>
+ <mxCell id="46" value="&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="1620" y="27" width="450" height="990" as="geometry"/>
+ </mxCell>
+ <mxCell id="47" value="" style="shape=hexagon;perimeter=hexagonPerimeter2;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;labelBackgroundColor=none;sketch=0;fillColor=none;fontSize=25;dashed=1;strokeWidth=6;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="790" y="237" width="1260" height="750" as="geometry"/>
+ </mxCell>
+ <mxCell id="44" value="" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="-10" y="37" width="1250" height="670" as="geometry"/>
+ </mxCell>
+ <mxCell id="2" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;MasterReqHandler&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="830" y="477" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="4" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandler&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="840" y="597" width="360" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="2">
+ <mxGeometry relative="1" as="geometry">
+ <Array as="points">
+ <mxPoint x="1280" y="792"/>
+ <mxPoint x="1280" y="502"/>
+ </Array>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="5" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveFsCacheReq&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="1715" y="767" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="7" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandlerMut&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1630" y="657" width="390" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="2" target="4">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="950" y="657" as="sourcePoint"/>
+ <mxPoint x="680" y="717" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="10" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveListener&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="472" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="11" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1712" y="387" width="210" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="14" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostUserSlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1652" y="537" width="330" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="11" target="14">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1202" y="567" as="sourcePoint"/>
+ <mxPoint x="1202" y="667" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="16" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostBackend&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#00994D;strokeColor=#009900;" vertex="1" parent="1">
+ <mxGeometry x="390" y="197" width="250" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="17" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostKernBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="18" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostVdpaBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="19" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Master&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="820" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="20" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostSoftBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="21" value="Handle virtque in VMM" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="557" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="23" value="Handle virtque in hardware" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="24" value="Handle virtque in kernel" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#0000CC;" edge="1" parent="1" source="24" target="17">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="930" y="647" as="sourcePoint"/>
+ <mxPoint x="930" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="11">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="840" y="917" as="sourcePoint"/>
+ <mxPoint x="840" y="1017" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="27" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="23" target="18">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="420" y="807" as="sourcePoint"/>
+ <mxPoint x="420" y="907" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="21" target="20">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="240" y="857" as="sourcePoint"/>
+ <mxPoint x="240" y="957" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="20" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="910" y="647" as="sourcePoint"/>
+ <mxPoint x="910" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;strokeColor=#00994D;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="18" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1000" y="177" as="sourcePoint"/>
+ <mxPoint x="530" y="227" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="32" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="17" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1010" y="127" as="sourcePoint"/>
+ <mxPoint x="1505" y="-73" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="35" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Endpoint&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="552" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="36" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Message&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="632" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="37" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;quot;jetbrains mono&amp;quot; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostUserMaster&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="980" y="257" width="230" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="38" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="19" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1030" y="527" as="sourcePoint"/>
+ <mxPoint x="515" y="257" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="39" value="Handle virtqueue in remote process" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="850" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="7">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1860" y="187" as="sourcePoint"/>
+ <mxPoint x="1860" y="267" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="43" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="37">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1430" y="187" as="sourcePoint"/>
+ <mxPoint x="2102" y="187" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="49" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;#34;jetbrains mono&amp;#34; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Trait&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="60" y="1017" width="130" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="51" value="Vhost-user protocol" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fontColor=#FF00FF;fillColor=none;strokeColor=none;" vertex="1" parent="1">
+ <mxGeometry x="1220" y="817" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="52" value="Vhost-user server" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="1680" y="57" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="53" value="VMM" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="20" y="47" width="240" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="54" value="&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Struct&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="240" y="1022" width="140" height="55" as="geometry"/>
+ </mxCell>
+ </root>
+ </mxGraphModel>
+ </diagram>
+</mxfile> \ No newline at end of file
diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png
new file mode 100644
index 0000000..4d1e2bc
--- /dev/null
+++ b/docs/vhost_architecture.png
Binary files differ
diff --git a/rust-vmm-ci b/rust-vmm-ci
new file mode 160000
+Subproject ebc701641fa57f78d03f3f5ecac617b7bf7470b
diff --git a/src/backend.rs b/src/backend.rs
new file mode 100644
index 0000000..1ae306f
--- /dev/null
+++ b/src/backend.rs
@@ -0,0 +1,506 @@
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Common traits and structs for vhost-kern and vhost-user backend drivers.
+
+use std::cell::RefCell;
+use std::os::unix::io::RawFd;
+use std::sync::RwLock;
+
+use sys_util::EventFd;
+
+use super::Result;
+
+/// Maximum number of memory regions supported.
+pub const VHOST_MAX_MEMORY_REGIONS: usize = 255;
+
+/// Vring configuration data.
+pub struct VringConfigData {
+ /// Maximum queue size supported by the driver.
+ pub queue_max_size: u16,
+ /// Actual queue size negotiated by the driver.
+ pub queue_size: u16,
+ /// Bitmask of vring flags.
+ pub flags: u32,
+ /// Descriptor table address.
+ pub desc_table_addr: u64,
+ /// Used ring buffer address.
+ pub used_ring_addr: u64,
+ /// Available ring buffer address.
+ pub avail_ring_addr: u64,
+ /// Optional address for logging.
+ pub log_addr: Option<u64>,
+}
+
+impl VringConfigData {
+ /// Check whether the log (flag, address) pair is valid.
+ pub fn is_log_addr_valid(&self) -> bool {
+ if self.flags & 0x1 != 0 && self.log_addr.is_none() {
+ return false;
+ }
+
+ true
+ }
+
+ /// Get the log address, default to zero if not available.
+ pub fn get_log_addr(&self) -> u64 {
+ if self.flags & 0x1 != 0 && self.log_addr.is_some() {
+ self.log_addr.unwrap()
+ } else {
+ 0
+ }
+ }
+}
+
+/// Memory region configuration data.
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserMemoryRegionInfo {
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub userspace_addr: u64,
+ /// Optional offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+ /// Optional file descriptor for mmap.
+ pub mmap_handle: RawFd,
+}
+
+/// An interface for setting up vhost-based backend drivers with interior mutability.
+///
+/// Vhost devices are a subset of virtio devices, which improve virtio device's performance by
+/// delegating data plane operations to dedicated IO service processes. Vhost devices use the
+/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to
+/// virtio devices.
+///
+/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the
+/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service
+/// processes, and the slow path for device configuration is still handled by the VMM process. It may
+/// also be used to control access permissions of virtio backend devices.
+pub trait VhostBackend: std::marker::Sized {
+ /// Get a bitmask of supported virtio/vhost features.
+ fn get_features(&self) -> Result<u64>;
+
+ /// Inform the vhost subsystem which features to enable.
+ /// This should be a subset of supported features from get_features().
+ ///
+ /// # Arguments
+ /// * `features` - Bitmask of features to set.
+ fn set_features(&self, features: u64) -> Result<()>;
+
+ /// Set the current process as the owner of the vhost backend.
+ /// This must be run before any other vhost commands.
+ fn set_owner(&self) -> Result<()>;
+
+ /// Used to be sent to request disabling all rings
+ /// This is no longer used.
+ fn reset_owner(&self) -> Result<()>;
+
+ /// Set the guest memory mappings for vhost to use.
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>;
+
+ /// Set base address for page modification logging.
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()>;
+
+ /// Specify an eventfd file descriptor to signal on log write.
+ fn set_log_fd(&self, fd: RawFd) -> Result<()>;
+
+ /// Set the number of descriptors in the vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set descriptor count for.
+ /// * `num` - Number of descriptors in the queue.
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>;
+
+ /// Set the addresses for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set addresses for.
+ /// * `config_data` - Configuration data for a vring.
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()>;
+
+ /// Set the first index to look for available descriptors.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `base` - Index where available descriptors start.
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()>;
+
+ /// Get the available vring base offset.
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32>;
+
+ /// Set the eventfd to trigger when buffers have been used by the host.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd to trigger.
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when buffers are
+ /// available for the host to process.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when error happens.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+}
+
+/// An interface for setting up vhost-based backend drivers.
+///
+/// Vhost devices are a subset of virtio devices, which improve virtio device's performance by
+/// delegating data plane operations to dedicated IO service processes. Vhost devices use the
+/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to
+/// virtio devices.
+///
+/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the
+/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service
+/// processes, and the slow path for device configuration is still handled by the VMM process. It may
+/// also be used to control access permissions of virtio backend devices.
+pub trait VhostBackendMut: std::marker::Sized {
+ /// Get a bitmask of supported virtio/vhost features.
+ fn get_features(&mut self) -> Result<u64>;
+
+ /// Inform the vhost subsystem which features to enable.
+ /// This should be a subset of supported features from get_features().
+ ///
+ /// # Arguments
+ /// * `features` - Bitmask of features to set.
+ fn set_features(&mut self, features: u64) -> Result<()>;
+
+ /// Set the current process as the owner of the vhost backend.
+ /// This must be run before any other vhost commands.
+ fn set_owner(&mut self) -> Result<()>;
+
+ /// Used to be sent to request disabling all rings
+ /// This is no longer used.
+ fn reset_owner(&mut self) -> Result<()>;
+
+ /// Set the guest memory mappings for vhost to use.
+ fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>;
+
+ /// Set base address for page modification logging.
+ fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()>;
+
+ /// Specify an eventfd file descriptor to signal on log write.
+ fn set_log_fd(&mut self, fd: RawFd) -> Result<()>;
+
+ /// Set the number of descriptors in the vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set descriptor count for.
+ /// * `num` - Number of descriptors in the queue.
+ fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()>;
+
+ /// Set the addresses for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set addresses for.
+ /// * `config_data` - Configuration data for a vring.
+ fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()>;
+
+ /// Set the first index to look for available descriptors.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `base` - Index where available descriptors start.
+ fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()>;
+
+ /// Get the available vring base offset.
+ fn get_vring_base(&mut self, queue_index: usize) -> Result<u32>;
+
+ /// Set the eventfd to trigger when buffers have been used by the host.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd to trigger.
+ fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when buffers are
+ /// available for the host to process.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when error happens.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;
+}
+
+impl<T: VhostBackendMut> VhostBackend for RwLock<T> {
+ fn get_features(&self) -> Result<u64> {
+ self.write().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.write().unwrap().set_features(features)
+ }
+
+ fn set_owner(&self) -> Result<()> {
+ self.write().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.write().unwrap().reset_owner()
+ }
+
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ self.write().unwrap().set_mem_table(regions)
+ }
+
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ self.write().unwrap().set_log_base(base, fd)
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ self.write().unwrap().set_log_fd(fd)
+ }
+
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ self.write().unwrap().set_vring_num(queue_index, num)
+ }
+
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ self.write()
+ .unwrap()
+ .set_vring_addr(queue_index, config_data)
+ }
+
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ self.write().unwrap().set_vring_base(queue_index, base)
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ self.write().unwrap().get_vring_base(queue_index)
+ }
+
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_call(queue_index, fd)
+ }
+
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_kick(queue_index, fd)
+ }
+
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_err(queue_index, fd)
+ }
+}
+
+impl<T: VhostBackendMut> VhostBackend for RefCell<T> {
+ fn get_features(&self) -> Result<u64> {
+ self.borrow_mut().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.borrow_mut().set_features(features)
+ }
+
+ fn set_owner(&self) -> Result<()> {
+ self.borrow_mut().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.borrow_mut().reset_owner()
+ }
+
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ self.borrow_mut().set_mem_table(regions)
+ }
+
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ self.borrow_mut().set_log_base(base, fd)
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ self.borrow_mut().set_log_fd(fd)
+ }
+
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ self.borrow_mut().set_vring_num(queue_index, num)
+ }
+
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ self.borrow_mut().set_vring_addr(queue_index, config_data)
+ }
+
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ self.borrow_mut().set_vring_base(queue_index, base)
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ self.borrow_mut().get_vring_base(queue_index)
+ }
+
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_call(queue_index, fd)
+ }
+
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_kick(queue_index, fd)
+ }
+
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_err(queue_index, fd)
+ }
+}
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ struct MockBackend {}
+
+ impl VhostBackendMut for MockBackend {
+ fn get_features(&mut self) -> Result<u64> {
+ Ok(0x1)
+ }
+
+ fn set_features(&mut self, features: u64) -> Result<()> {
+ assert_eq!(features, 0x1);
+ Ok(())
+ }
+
+ fn set_owner(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_mem_table(&mut self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ assert_eq!(base, 0x100);
+ assert_eq!(fd, Some(100));
+ Ok(())
+ }
+
+ fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
+ assert_eq!(fd, 100);
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ assert_eq!(num, 256);
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ queue_index: usize,
+ _config_data: &VringConfigData,
+ ) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ assert_eq!(base, 2);
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
+ assert_eq!(queue_index, 1);
+ Ok(2)
+ }
+
+ fn set_vring_call(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_kick(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+ }
+
+ #[test]
+ fn test_vring_backend_mut() {
+ let b = RwLock::new(MockBackend {});
+
+ assert_eq!(b.get_features().unwrap(), 0x1);
+ b.set_features(0x1).unwrap();
+ b.set_owner().unwrap();
+ b.reset_owner().unwrap();
+ b.set_mem_table(&[]).unwrap();
+ b.set_log_base(0x100, Some(100)).unwrap();
+ b.set_log_fd(100).unwrap();
+ b.set_vring_num(1, 256).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 0x1000,
+ queue_size: 0x2000,
+ flags: 0x0,
+ desc_table_addr: 0x4000,
+ used_ring_addr: 0x5000,
+ avail_ring_addr: 0x6000,
+ log_addr: None,
+ };
+ b.set_vring_addr(1, &config).unwrap();
+
+ b.set_vring_base(1, 2).unwrap();
+ assert_eq!(b.get_vring_base(1).unwrap(), 2);
+
+ let eventfd = EventFd::new().unwrap();
+ b.set_vring_call(1, &eventfd).unwrap();
+ b.set_vring_kick(1, &eventfd).unwrap();
+ b.set_vring_err(1, &eventfd).unwrap();
+ }
+
+ #[test]
+ fn test_vring_config_data() {
+ let mut config = VringConfigData {
+ queue_max_size: 0x1000,
+ queue_size: 0x2000,
+ flags: 0x0,
+ desc_table_addr: 0x4000,
+ used_ring_addr: 0x5000,
+ avail_ring_addr: 0x6000,
+ log_addr: None,
+ };
+
+ assert_eq!(config.is_log_addr_valid(), true);
+ assert_eq!(config.get_log_addr(), 0);
+
+ config.flags = 0x1;
+ assert_eq!(config.is_log_addr_valid(), false);
+ assert_eq!(config.get_log_addr(), 0);
+
+ config.log_addr = Some(0x7000);
+ assert_eq!(config.is_log_addr_valid(), true);
+ assert_eq!(config.get_log_addr(), 0x7000);
+
+ config.flags = 0x0;
+ assert_eq!(config.is_log_addr_valid(), true);
+ assert_eq!(config.get_log_addr(), 0);
+ }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..a755f4f
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,162 @@
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+
+//! Virtio Vhost Backend Drivers
+//!
+//! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue
+//! is a set of three different single-producer, single-consumer ring structures designed to store
+//! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact
+//! virtqueue layout named "Packed Virtqueue", which is more friendly to memory cache system and
+//! hardware implemented virtio devices. The packed virtqueue uses read-write memory, that means
+//! the memory will be both read and written by both host and guest. The new Packed Virtqueue is
+//! preferred for performance.
+//!
+//! Vhost is a mechanism to improve performance of Virtio devices by delegating data plane operations
+//! to dedicated IO service processes. Only the configuration, I/O submission notification, and I/O
+//! completion interrupts are piped through the hypervisor.
+//! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
+//! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
+//! hypervisor process with an existing Virtio (PCI) driver.
+//!
+//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
+//! communicate with userspace applications. Dedicated kernel worker threads are created to handle
+//! IO requests from the guest.
+//!
+//! Later Vhost-user protocol is introduced to complement the ioctl interface used to control the
+//! vhost implementation in the Linux kernel. It implements the control plane needed to establish
+//! virtqueue sharing with a user space process on the same host. It uses communication over a
+//! Unix domain socket to share file descriptors in the ancillary data of the message.
+//! The protocol defines 2 sides of the communication, master and slave. Master is the application
+//! that shares its virtqueues. Slave is the consumer of the virtqueues. Master and slave can be
+//! either a client (i.e. connecting) or server (listening) in the socket communication.
+
+#![deny(missing_docs)]
+
+#[cfg_attr(feature = "vhost-user", macro_use)]
+extern crate bitflags;
+#[cfg_attr(feature = "vhost-kern", macro_use)]
+extern crate sys_util;
+
+mod backend;
+pub use backend::*;
+
+#[cfg(feature = "vhost-kern")]
+pub mod vhost_kern;
+#[cfg(feature = "vhost-user")]
+pub mod vhost_user;
+#[cfg(feature = "vhost-vsock")]
+pub mod vsock;
+
+/// Error codes for vhost operations
+#[derive(Debug)]
+pub enum Error {
+ /// Invalid operations.
+ InvalidOperation,
+ /// Invalid guest memory.
+ InvalidGuestMemory,
+ /// Invalid guest memory region.
+ InvalidGuestMemoryRegion,
+ /// Invalid queue.
+ InvalidQueue,
+ /// Invalid descriptor table address.
+ DescriptorTableAddress,
+ /// Invalid used address.
+ UsedAddress,
+ /// Invalid available address.
+ AvailAddress,
+ /// Invalid log address.
+ LogAddress,
+ #[cfg(feature = "vhost-kern")]
+ /// Error opening the vhost backend driver.
+ VhostOpen(std::io::Error),
+ #[cfg(feature = "vhost-kern")]
+ /// Error while running ioctl.
+ IoctlError(std::io::Error),
+ /// Error from IO subsystem.
+ IOError(std::io::Error),
+ #[cfg(feature = "vhost-user")]
+ /// Error from the vhost-user subsystem.
+ VhostUserProtocol(vhost_user::Error),
+}
+
+impl std::fmt::Display for Error {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ Error::InvalidOperation => write!(f, "invalid vhost operations"),
+ Error::InvalidGuestMemory => write!(f, "invalid guest memory object"),
+ Error::InvalidGuestMemoryRegion => write!(f, "invalid guest memory region"),
+ Error::InvalidQueue => write!(f, "invalid virtque"),
+ Error::DescriptorTableAddress => write!(f, "invalid virtque descriptor talbe address"),
+ Error::UsedAddress => write!(f, "invalid virtque used talbe address"),
+ Error::AvailAddress => write!(f, "invalid virtque available table address"),
+ Error::LogAddress => write!(f, "invalid virtque log address"),
+ Error::IOError(e) => write!(f, "IO error: {}", e),
+ #[cfg(feature = "vhost-kern")]
+ Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e),
+ #[cfg(feature = "vhost-kern")]
+ Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e),
+ #[cfg(feature = "vhost-user")]
+ Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e),
+ }
+ }
+}
+
+impl std::error::Error for Error {}
+
+#[cfg(feature = "vhost-user")]
+impl std::convert::From<vhost_user::Error> for Error {
+ fn from(err: vhost_user::Error) -> Self {
+ Error::VhostUserProtocol(err)
+ }
+}
+
+/// Result of vhost operations
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_error() {
+ assert_eq!(
+ format!("{}", Error::AvailAddress),
+ "invalid virtque available table address"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidOperation),
+ "invalid vhost operations"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidGuestMemory),
+ "invalid guest memory object"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidGuestMemoryRegion),
+ "invalid guest memory region"
+ );
+ assert_eq!(format!("{}", Error::InvalidQueue), "invalid virtque");
+ assert_eq!(
+ format!("{}", Error::DescriptorTableAddress),
+ "invalid virtque descriptor talbe address"
+ );
+ assert_eq!(
+ format!("{}", Error::UsedAddress),
+ "invalid virtque used talbe address"
+ );
+ assert_eq!(
+ format!("{}", Error::LogAddress),
+ "invalid virtque log address"
+ );
+
+ assert_eq!(format!("{:?}", Error::AvailAddress), "AvailAddress");
+ }
+
+ #[cfg(feature = "vhost-user")]
+ #[test]
+ fn test_convert_from_vhost_user_error() {
+ let e: Error = vhost_user::Error::OversizedMsg.into();
+
+ assert_eq!(format!("{}", e), "vhost-user: oversized message");
+ }
+}
diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs
new file mode 100644
index 0000000..5daca51
--- /dev/null
+++ b/src/vhost_kern/mod.rs
@@ -0,0 +1,283 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Traits and structs to control Linux in-kernel vhost drivers.
+//!
+//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
+//! communicate with userspace applications. This sub module provides ioctl based interfaces to
+//! control the in-kernel net, scsi, vsock vhost drivers.
+
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
+use sys_util::EventFd;
+use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize};
+
+use super::{
+ Error, Result, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData,
+ VHOST_MAX_MEMORY_REGIONS,
+};
+
+pub mod vhost_binding;
+use self::vhost_binding::*;
+
+#[cfg(feature = "vhost-vsock")]
+pub mod vsock;
+
+#[inline]
+fn ioctl_result<T>(rc: i32, res: T) -> Result<T> {
+ if rc < 0 {
+ Err(Error::IoctlError(std::io::Error::last_os_error()))
+ } else {
+ Ok(res)
+ }
+}
+
+/// Represent an in-kernel vhost device backend.
+pub trait VhostKernBackend: AsRawFd {
+ /// Associated type to access guest memory.
+ type AS: GuestAddressSpace;
+
+ /// Get the object to access the guest's memory.
+ fn mem(&self) -> &Self::AS;
+
+ /// Check whether the ring configuration is valid.
+ fn is_valid(&self, config_data: &VringConfigData) -> bool {
+ let queue_size = config_data.queue_size;
+ if queue_size > config_data.queue_max_size
+ || queue_size == 0
+ || (queue_size & (queue_size - 1)) != 0
+ {
+ return false;
+ }
+
+ let m = self.mem().memory();
+ let desc_table_size = 16 * u64::from(queue_size) as GuestUsize;
+ let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize;
+ let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize;
+ if GuestAddress(config_data.desc_table_addr)
+ .checked_add(desc_table_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+ if GuestAddress(config_data.avail_ring_addr)
+ .checked_add(avail_ring_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+ if GuestAddress(config_data.used_ring_addr)
+ .checked_add(used_ring_size)
+ .map_or(true, |v| !m.address_in_range(v))
+ {
+ return false;
+ }
+
+ config_data.is_log_addr_valid()
+ }
+}
+
impl<T: VhostKernBackend> VhostBackend for T {
    /// Get a bitmask of supported virtio/vhost features.
    fn get_features(&self) -> Result<u64> {
        let mut avail_features: u64 = 0;
        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) };
        ioctl_result(ret, avail_features)
    }

    /// Inform the vhost subsystem which features to enable. This should be a subset of
    /// supported features from VHOST_GET_FEATURES.
    ///
    /// # Arguments
    /// * `features` - Bitmask of features to set.
    fn set_features(&self, features: u64) -> Result<()> {
        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) };
        ioctl_result(ret, ())
    }

    /// Set the current process as the owner of this file descriptor.
    /// This must be run before any other vhost ioctls.
    fn set_owner(&self) -> Result<()> {
        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) };
        ioctl_result(ret, ())
    }

    /// Release the ownership previously taken with `set_owner`
    /// (issues VHOST_RESET_OWNER on the device fd).
    fn reset_owner(&self) -> Result<()> {
        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) };
        ioctl_result(ret, ())
    }

    /// Set the guest memory mappings for vhost to use.
    fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
        // Reject obviously invalid tables up front instead of relying on the
        // ioctl to fail: an empty table and a table with more entries than
        // VHOST_MAX_MEMORY_REGIONS are both errors.
        if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS {
            return Err(Error::InvalidGuestMemory);
        }

        // Build the C-layout vhost_memory header + flexible region array.
        let mut vhost_memory = VhostMemory::new(regions.len() as u16);
        for (index, region) in regions.iter().enumerate() {
            vhost_memory.set_region(
                index as u32,
                &vhost_memory_region {
                    guest_phys_addr: region.guest_phys_addr,
                    memory_size: region.memory_size,
                    userspace_addr: region.userspace_addr,
                    flags_padding: 0u64,
                },
            )?;
        }

        // This ioctl is called with a pointer that is valid for the lifetime
        // of this function. The kernel will make its own copy of the memory
        // tables. As always, check the return value.
        let ret = unsafe { ioctl_with_ptr(self, VHOST_SET_MEM_TABLE(), vhost_memory.as_ptr()) };
        ioctl_result(ret, ())
    }

    /// Set base address for page modification logging.
    ///
    /// # Arguments
    /// * `base` - Base address for page modification logging.
    /// * `fd` - Must be `None`; the in-kernel backend only takes a base
    ///          address, so a log fd (used by other implementations of this
    ///          trait) is rejected here.
    fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
        if fd.is_some() {
            return Err(Error::LogAddress);
        }

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_BASE(), &base) };
        ioctl_result(ret, ())
    }

    /// Specify an eventfd file descriptor to signal on log write.
    fn set_log_fd(&self, fd: RawFd) -> Result<()> {
        // This ioctl is called on a valid vhost fd and has its return value checked.
        let val: i32 = fd;
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) };
        ioctl_result(ret, ())
    }

    /// Set the number of descriptors in the vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set descriptor count for.
    /// * `num` - Number of descriptors in the queue.
    fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
        let vring_state = vhost_vring_state {
            index: queue_index as u32,
            num: u32::from(num),
        };

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_NUM(), &vring_state) };
        ioctl_result(ret, ())
    }

    /// Set the addresses for a given vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set addresses for.
    /// * `config_data` - Vring config data.
    fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
        // Validate the ring layout against guest memory before handing the
        // addresses to the kernel.
        if !self.is_valid(config_data) {
            return Err(Error::InvalidQueue);
        }

        let vring_addr = vhost_vring_addr {
            index: queue_index as u32,
            flags: config_data.flags,
            desc_user_addr: config_data.desc_table_addr,
            used_user_addr: config_data.used_ring_addr,
            avail_user_addr: config_data.avail_ring_addr,
            log_guest_addr: config_data.get_log_addr(),
        };

        // This ioctl is called on a valid vhost fd and has its
        // return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) };
        ioctl_result(ret, ())
    }

    /// Set the first index to look for available descriptors.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `base` - Index where available descriptors start.
    fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
        let vring_state = vhost_vring_state {
            index: queue_index as u32,
            num: u32::from(base),
        };

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_BASE(), &vring_state) };
        ioctl_result(ret, ())
    }

    /// Get the current base index of the specified vring.
    fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
        let vring_state = vhost_vring_state {
            index: queue_index as u32,
            num: 0,
        };
        // This ioctl is called on a valid vhost fd and has its return value checked.
        // NOTE(review): the kernel writes the result back through this pointer
        // (that is the only way `vring_state.num` below can be non-zero), even
        // though it is passed via ioctl_with_ref; a `mut` binding with
        // ioctl_with_mut_ref would express that intent more honestly.
        let ret = unsafe { ioctl_with_ref(self, VHOST_GET_VRING_BASE(), &vring_state) };
        ioctl_result(ret, vring_state.num)
    }

    /// Set the eventfd to trigger when buffers have been used by the host.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd to trigger.
    fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let vring_file = vhost_vring_file {
            index: queue_index as u32,
            fd: fd.as_raw_fd(),
        };

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_CALL(), &vring_file) };
        ioctl_result(ret, ())
    }

    /// Set the eventfd that will be signaled by the guest when buffers are
    /// available for the host to process.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd that will be signaled from guest.
    fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let vring_file = vhost_vring_file {
            index: queue_index as u32,
            fd: fd.as_raw_fd(),
        };

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_KICK(), &vring_file) };
        ioctl_result(ret, ())
    }

    /// Set the eventfd to signal an error from the vhost backend.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd that will be signaled from the backend.
    fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let vring_file = vhost_vring_file {
            index: queue_index as u32,
            fd: fd.as_raw_fd(),
        };

        // This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ERR(), &vring_file) };
        ioctl_result(ret, ())
    }
}
diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs
new file mode 100644
index 0000000..57ae698
--- /dev/null
+++ b/src/vhost_kern/vhost_binding.rs
@@ -0,0 +1,406 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+/* Auto-generated by bindgen then manually edited for simplicity */
+
+#![allow(non_upper_case_globals)]
+#![allow(non_camel_case_types)]
+#![allow(non_snake_case)]
+#![allow(missing_docs)]
+#![allow(clippy::missing_safety_doc)]
+
+use crate::{Error, Result};
+use std::os::raw;
+
// The ioctl "type" byte shared by all vhost requests.
pub const VHOST: raw::c_uint = 0xaf;
// Vring flag bit: the ring participates in dirty-page logging.
pub const VHOST_VRING_F_LOG: raw::c_uint = 0;
// IOTLB entry access permissions.
pub const VHOST_ACCESS_RO: raw::c_uint = 1;
pub const VHOST_ACCESS_WO: raw::c_uint = 2;
pub const VHOST_ACCESS_RW: raw::c_uint = 3;
// IOTLB message types.
pub const VHOST_IOTLB_MISS: raw::c_uint = 1;
pub const VHOST_IOTLB_UPDATE: raw::c_uint = 2;
pub const VHOST_IOTLB_INVALIDATE: raw::c_uint = 3;
pub const VHOST_IOTLB_ACCESS_FAIL: raw::c_uint = 4;
pub const VHOST_IOTLB_MSG: raw::c_uint = 1;
pub const VHOST_PAGE_SIZE: raw::c_uint = 4096;
// Decimal form of VHOST (0xaf == 175).
pub const VHOST_VIRTIO: raw::c_uint = 175;
// Vring byte-order selector values.
pub const VHOST_VRING_LITTLE_ENDIAN: raw::c_uint = 0;
pub const VHOST_VRING_BIG_ENDIAN: raw::c_uint = 1;
// Feature bit numbers.
pub const VHOST_F_LOG_ALL: raw::c_uint = 26;
pub const VHOST_NET_F_VIRTIO_NET_HDR: raw::c_uint = 27;
pub const VHOST_SCSI_ABI_VERSION: raw::c_uint = 1;

// ioctl request constructors, mirroring the request definitions in the
// kernel's linux/vhost.h UAPI header (generic, net, scsi and vsock families).
ioctl_ior_nr!(VHOST_GET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_SET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01);
ioctl_io_nr!(VHOST_RESET_OWNER, VHOST, 0x02);
ioctl_iow_nr!(VHOST_SET_MEM_TABLE, VHOST, 0x03, vhost_memory);
ioctl_iow_nr!(VHOST_SET_LOG_BASE, VHOST, 0x04, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_SET_LOG_FD, VHOST, 0x07, raw::c_int);
ioctl_iow_nr!(VHOST_SET_VRING_NUM, VHOST, 0x10, vhost_vring_state);
ioctl_iow_nr!(VHOST_SET_VRING_ADDR, VHOST, 0x11, vhost_vring_addr);
ioctl_iow_nr!(VHOST_SET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
ioctl_iow_nr!(VHOST_SET_VRING_KICK, VHOST, 0x20, vhost_vring_file);
ioctl_iow_nr!(VHOST_SET_VRING_CALL, VHOST, 0x21, vhost_vring_file);
ioctl_iow_nr!(VHOST_SET_VRING_ERR, VHOST, 0x22, vhost_vring_file);
ioctl_iow_nr!(VHOST_NET_SET_BACKEND, VHOST, 0x30, vhost_vring_file);
ioctl_iow_nr!(VHOST_SCSI_SET_ENDPOINT, VHOST, 0x40, vhost_scsi_target);
ioctl_iow_nr!(VHOST_SCSI_CLEAR_ENDPOINT, VHOST, 0x41, vhost_scsi_target);
ioctl_iow_nr!(VHOST_SCSI_GET_ABI_VERSION, VHOST, 0x42, raw::c_int);
ioctl_iow_nr!(VHOST_SCSI_SET_EVENTS_MISSED, VHOST, 0x43, raw::c_uint);
ioctl_iow_nr!(VHOST_SCSI_GET_EVENTS_MISSED, VHOST, 0x44, raw::c_uint);
ioctl_iow_nr!(VHOST_VSOCK_SET_GUEST_CID, VHOST, 0x60, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_VSOCK_SET_RUNNING, VHOST, 0x61, raw::c_int);
+
// Rust stand-in for a C flexible array member ("T field[];" at the end of a
// struct). It occupies no space itself; the elements live in memory allocated
// past the end of the containing struct (see VhostMemory below).
#[repr(C)]
#[derive(Default)]
pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>);

impl<T> __IncompleteArrayField<T> {
    #[inline]
    pub fn new() -> Self {
        __IncompleteArrayField(::std::marker::PhantomData)
    }

    // Pointer to where the trailing array would start. Safety: the caller
    // must guarantee the containing allocation actually extends past this
    // field before dereferencing.
    #[inline]
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[allow(clippy::useless_transmute)]
    pub unsafe fn as_ptr(&self) -> *const T {
        // A zero-sized field sits exactly at the array's start address, so
        // reinterpreting &self yields the element pointer.
        ::std::mem::transmute(self)
    }

    #[inline]
    #[allow(clippy::useless_transmute)]
    pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
        ::std::mem::transmute(self)
    }

    // Safety: the caller must ensure `len` elements of initialized memory
    // exist past the end of the containing struct.
    #[inline]
    pub unsafe fn as_slice(&self, len: usize) -> &[T] {
        ::std::slice::from_raw_parts(self.as_ptr(), len)
    }

    #[inline]
    pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
        ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
    }
}

impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> {
    fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        // The field has no inline data to print, so just emit its name.
        fmt.write_str("__IncompleteArrayField")
    }
}

impl<T> ::std::clone::Clone for __IncompleteArrayField<T> {
    #[inline]
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl<T> ::std::marker::Copy for __IncompleteArrayField<T> {}
+
// (vring index, count) pair used by VHOST_SET_VRING_NUM and
// VHOST_{SET,GET}_VRING_BASE.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_state {
    pub index: raw::c_uint,
    pub num: raw::c_uint,
}

// (vring index, fd) pair used by VHOST_SET_VRING_{KICK,CALL,ERR} and
// VHOST_NET_SET_BACKEND.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_file {
    pub index: raw::c_uint,
    pub fd: raw::c_int,
}

// Userspace addresses of the vring components, passed via
// VHOST_SET_VRING_ADDR.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_addr {
    pub index: raw::c_uint,
    pub flags: raw::c_uint,
    pub desc_user_addr: raw::c_ulonglong,
    pub used_user_addr: raw::c_ulonglong,
    pub avail_user_addr: raw::c_ulonglong,
    pub log_guest_addr: raw::c_ulonglong,
}

// Single IOTLB message; `perm` and `type_` take the VHOST_ACCESS_* and
// VHOST_IOTLB_* values defined above.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_iotlb_msg {
    pub iova: raw::c_ulonglong,
    pub size: raw::c_ulonglong,
    pub uaddr: raw::c_ulonglong,
    pub perm: raw::c_uchar,
    pub type_: raw::c_uchar,
}

// Message exchanged on the vhost fd; the payload is a C union, represented
// by the bindgen-style anonymous field below.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct vhost_msg {
    pub type_: raw::c_int,
    pub __bindgen_anon_1: vhost_msg__bindgen_ty_1,
}

impl Default for vhost_msg {
    fn default() -> Self {
        // Zeroing is valid: every field is a plain integer or byte array.
        unsafe { ::std::mem::zeroed() }
    }
}

// Union payload of vhost_msg: either an IOTLB message or raw padding bytes;
// the hidden member forces 8-byte alignment.
#[repr(C)]
#[derive(Copy, Clone)]
pub union vhost_msg__bindgen_ty_1 {
    pub iotlb: vhost_iotlb_msg,
    pub padding: [raw::c_uchar; 64usize],
    _bindgen_union_align: [u64; 8usize],
}

impl Default for vhost_msg__bindgen_ty_1 {
    fn default() -> Self {
        // Zeroing is valid: all union members are plain integers/arrays.
        unsafe { ::std::mem::zeroed() }
    }
}

// One guest memory region entry of a VHOST_SET_MEM_TABLE call.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_memory_region {
    pub guest_phys_addr: raw::c_ulonglong,
    pub memory_size: raw::c_ulonglong,
    pub userspace_addr: raw::c_ulonglong,
    pub flags_padding: raw::c_ulonglong,
}

// Header of a VHOST_SET_MEM_TABLE payload: `nregions` region entries follow
// in the flexible array member.
#[repr(C)]
#[derive(Debug, Default, Clone)]
pub struct vhost_memory {
    pub nregions: raw::c_uint,
    pub padding: raw::c_uint,
    pub regions: __IncompleteArrayField<vhost_memory_region>,
    // Zero-sized member forcing 8-byte alignment so the trailing
    // vhost_memory_region entries are correctly aligned.
    __force_alignment: [u64; 0],
}

// Target descriptor for the VHOST_SCSI_* endpoint ioctls.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct vhost_scsi_target {
    pub abi_version: raw::c_int,
    pub vhost_wwpn: [raw::c_char; 224usize],
    pub vhost_tpgt: raw::c_ushort,
    pub reserved: raw::c_ushort,
}

impl Default for vhost_scsi_target {
    fn default() -> Self {
        // Zeroing is valid: the struct holds only integers and a char array.
        unsafe { ::std::mem::zeroed() }
    }
}
+
/// Helper to support vhost::set_mem_table()
///
/// Owns one contiguous, correctly aligned buffer holding a `vhost_memory`
/// header immediately followed by its flexible array of
/// `vhost_memory_region` entries, matching the layout the kernel expects.
pub struct VhostMemory {
    buf: Vec<vhost_memory>,
}

impl VhostMemory {
    // Limit number of regions to u16 to simplify error handling
    pub fn new(entries: u16) -> Self {
        // Use a Vec<vhost_memory> as raw storage so alignment is inherited
        // from the header type. `count` is one element for the header plus
        // enough extra elements (rounded up) to cover `entries` region
        // records in the trailing flexible array.
        let size = std::mem::size_of::<vhost_memory_region>() * entries as usize;
        let count = (size + 2 * std::mem::size_of::<vhost_memory>() - 1)
            / std::mem::size_of::<vhost_memory>();
        let mut buf: Vec<vhost_memory> = vec![Default::default(); count];
        buf[0].nregions = u32::from(entries);
        VhostMemory { buf }
    }

    // NOTE(review): Rust `char` is a 4-byte Unicode scalar type, not C
    // `char`. The pointer is only consumed opaquely by ioctl_with_ptr, so
    // this is harmless in practice, but `*const c_void` (or `*const u8`)
    // would state the intent correctly.
    pub fn as_ptr(&self) -> *const char {
        &self.buf[0] as *const vhost_memory as *const char
    }

    // Borrow the vhost_memory header at the front of the buffer.
    pub fn get_header(&self) -> &vhost_memory {
        &self.buf[0]
    }

    // Read back region `index`, or None when the index is out of range.
    pub fn get_region(&self, index: u32) -> Option<&vhost_memory_region> {
        if index >= self.buf[0].nregions {
            return None;
        }
        // Safe because we have allocated enough space nregions
        let regions = unsafe { self.buf[0].regions.as_slice(self.buf[0].nregions as usize) };
        Some(&regions[index as usize])
    }

    // Copy `region` into slot `index`; fails with InvalidGuestMemory when the
    // index is beyond the table size fixed at construction.
    pub fn set_region(&mut self, index: u32, region: &vhost_memory_region) -> Result<()> {
        if index >= self.buf[0].nregions {
            return Err(Error::InvalidGuestMemory);
        }
        // Safe because we have allocated enough space nregions and checked the index.
        let regions = unsafe { self.buf[0].regions.as_mut_slice(index as usize + 1) };
        regions[index as usize] = *region;
        Ok(())
    }
}
+
#[cfg(test)]
mod tests {
    use super::*;

    // The bindgen_test_layout_* tests below pin the size and alignment of
    // each binding so that manual edits to this hand-maintained file cannot
    // silently diverge from the kernel's C struct layouts (values as produced
    // by bindgen on the targets this crate supports).
    #[test]
    fn bindgen_test_layout_vhost_vring_state() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_state>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_vring_state))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_state>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_vring_state))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_vring_file() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_file>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_vring_file))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_file>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_vring_file))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_vring_addr() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_addr>(),
            40usize,
            concat!("Size of: ", stringify!(vhost_vring_addr))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_addr>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_vring_addr))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_msg__bindgen_ty_1() {
        assert_eq!(
            ::std::mem::size_of::<vhost_msg__bindgen_ty_1>(),
            64usize,
            concat!("Size of: ", stringify!(vhost_msg__bindgen_ty_1))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_msg__bindgen_ty_1>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_msg__bindgen_ty_1))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_msg() {
        assert_eq!(
            ::std::mem::size_of::<vhost_msg>(),
            72usize,
            concat!("Size of: ", stringify!(vhost_msg))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_msg>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_msg))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_memory_region() {
        assert_eq!(
            ::std::mem::size_of::<vhost_memory_region>(),
            32usize,
            concat!("Size of: ", stringify!(vhost_memory_region))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_memory_region>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_memory_region))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_memory() {
        // Only the header is counted here: the flexible regions array is
        // zero-sized, so vhost_memory itself is 8 bytes.
        assert_eq!(
            ::std::mem::size_of::<vhost_memory>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_memory))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_memory>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_memory))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_iotlb_msg() {
        assert_eq!(
            ::std::mem::size_of::<vhost_iotlb_msg>(),
            32usize,
            concat!("Size of: ", stringify!(vhost_iotlb_msg))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_iotlb_msg>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_iotlb_msg))
        );
    }

    #[test]
    fn bindgen_test_layout_vhost_scsi_target() {
        assert_eq!(
            ::std::mem::size_of::<vhost_scsi_target>(),
            232usize,
            concat!("Size of: ", stringify!(vhost_scsi_target))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_scsi_target>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_scsi_target))
        );
    }

    // Exercises the VhostMemory bounds checks and the round-trip of a region
    // record through the flexible array.
    #[test]
    fn test_vhostmemory() {
        let mut obj = VhostMemory::new(2);
        let region = vhost_memory_region {
            guest_phys_addr: 0x1000u64,
            memory_size: 0x2000u64,
            userspace_addr: 0x300000u64,
            flags_padding: 0u64,
        };
        // Index 2 is out of range for a 2-entry table.
        assert!(obj.get_region(2).is_none());

        {
            let header = obj.get_header();
            assert_eq!(header.nregions, 2u32);
        }
        {
            assert!(obj.set_region(0, &region).is_ok());
            assert!(obj.set_region(1, &region).is_ok());
            assert!(obj.set_region(2, &region).is_err());
        }

        let region1 = obj.get_region(1).unwrap();
        assert_eq!(region1.guest_phys_addr, 0x1000u64);
        assert_eq!(region1.memory_size, 0x2000u64);
        assert_eq!(region1.userspace_addr, 0x300000u64);
    }
}
diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs
new file mode 100644
index 0000000..388d500
--- /dev/null
+++ b/src/vhost_kern/vsock.rs
@@ -0,0 +1,184 @@
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Kernel-based vhost-vsock backend.
+
+use std::fs::{File, OpenOptions};
+use std::os::unix::fs::OpenOptionsExt;
+use std::os::unix::io::{AsRawFd, RawFd};
+
+use sys_util::ioctl_with_ref;
+use vm_memory::GuestAddressSpace;
+
+use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
+use super::{ioctl_result, Error, Result, VhostKernBackend};
+use crate::vsock::VhostVsock;
+
// Device node exposed by the kernel vhost-vsock driver.
const VHOST_PATH: &str = "/dev/vhost-vsock";

/// Handle for running VHOST_VSOCK ioctls.
pub struct Vsock<AS: GuestAddressSpace> {
    // Open fd on /dev/vhost-vsock; all ioctls are issued on it.
    fd: File,
    // Guest address space used to validate vring configurations.
    mem: AS,
}
+
+impl<AS: GuestAddressSpace> Vsock<AS> {
+ /// Open a handle to a new VHOST-VSOCK instance.
+ pub fn new(mem: AS) -> Result<Self> {
+ Ok(Vsock {
+ fd: OpenOptions::new()
+ .read(true)
+ .write(true)
+ .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
+ .open(VHOST_PATH)
+ .map_err(Error::VhostOpen)?,
+ mem,
+ })
+ }
+
+ fn set_running(&self, running: bool) -> Result<()> {
+ let on: ::std::os::raw::c_int = if running { 1 } else { 0 };
+ let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) };
+ ioctl_result(ret, ())
+ }
+}
+
impl<AS: GuestAddressSpace> VhostVsock for Vsock<AS> {
    fn set_guest_cid(&self, cid: u64) -> Result<()> {
        // Safe because self.fd is a valid vhost-vsock fd and the return value
        // is checked by ioctl_result().
        let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) };
        ioctl_result(ret, ())
    }

    // Turn the running state on (VHOST_VSOCK_SET_RUNNING with 1).
    fn start(&self) -> Result<()> {
        self.set_running(true)
    }

    // Turn the running state off (VHOST_VSOCK_SET_RUNNING with 0).
    fn stop(&self) -> Result<()> {
        self.set_running(false)
    }
}
+
impl<AS: GuestAddressSpace> VhostKernBackend for Vsock<AS> {
    type AS = AS;

    // Expose the guest address space for the trait's vring validation.
    fn mem(&self) -> &Self::AS {
        &self.mem
    }
}

// All generic vhost ioctls (via the blanket VhostBackend impl) operate on
// the raw device fd exposed here.
impl<AS: GuestAddressSpace> AsRawFd for Vsock<AS> {
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}
+
#[cfg(test)]
mod tests {
    use sys_util::EventFd;
    use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};

    use super::*;
    use crate::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};

    // Ignore all tests because /dev/vhost-vsock is unavailable in Chrome OS chroot.
    #[test]
    #[ignore]
    fn test_vsock_new_device() {
        let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
        let vsock = Vsock::new(&m).unwrap();

        assert!(vsock.as_raw_fd() >= 0);
        assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some());
        // One past the end of the single 1 MiB range must not resolve.
        assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none());
    }

    #[test]
    #[ignore]
    fn test_vsock_is_valid() {
        let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
        let vsock = Vsock::new(&m).unwrap();

        let mut config = VringConfigData {
            queue_max_size: 32,
            queue_size: 32,
            flags: 0,
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: None,
        };
        assert_eq!(vsock.is_valid(&config), true);

        // Zero, non-power-of-two, and over-max queue sizes are all invalid.
        config.queue_size = 0;
        assert_eq!(vsock.is_valid(&config), false);
        config.queue_size = 31;
        assert_eq!(vsock.is_valid(&config), false);
        config.queue_size = 33;
        assert_eq!(vsock.is_valid(&config), false);
    }

    #[test]
    #[ignore]
    fn test_vsock_ioctls() {
        let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
        let vsock = Vsock::new(&m).unwrap();

        let features = vsock.get_features().unwrap();
        vsock.set_features(features).unwrap();

        vsock.set_owner().unwrap();

        // An empty region list must be rejected before reaching the kernel.
        vsock.set_mem_table(&[]).unwrap_err();

        /*
        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x0,
            memory_size: 0x10_0000,
            userspace_addr: 0,
            mmap_offset: 0,
            mmap_handle: -1,
        };
        vsock.set_mem_table(&[region]).unwrap_err();
        */

        let region = VhostUserMemoryRegionInfo {
            guest_phys_addr: 0x0,
            memory_size: 0x10_0000,
            userspace_addr: m.get_host_address(GuestAddress(0x0)).unwrap() as u64,
            mmap_offset: 0,
            mmap_handle: -1,
        };
        vsock.set_mem_table(&[region]).unwrap();

        // The in-kernel backend rejects a log fd; only a base address works.
        vsock.set_log_base(0x4000, Some(1)).unwrap_err();
        vsock.set_log_base(0x4000, None).unwrap();

        let eventfd = EventFd::new().unwrap();
        vsock.set_log_fd(eventfd.as_raw_fd()).unwrap();

        vsock.set_vring_num(0, 32).unwrap();

        let config = VringConfigData {
            queue_max_size: 32,
            queue_size: 32,
            flags: 0,
            desc_table_addr: 0x1000,
            used_ring_addr: 0x2000,
            avail_ring_addr: 0x3000,
            log_addr: None,
        };
        vsock.set_vring_addr(0, &config).unwrap();
        vsock.set_vring_base(0, 1).unwrap();
        vsock.set_vring_call(0, &eventfd).unwrap();
        vsock.set_vring_kick(0, &eventfd).unwrap();
        vsock.set_vring_err(0, &eventfd).unwrap();
        // The base set above must be read back unchanged.
        assert_eq!(vsock.get_vring_base(0).unwrap(), 1);
        vsock.set_guest_cid(0xdead).unwrap();
        //vsock.start().unwrap();
        //vsock.stop().unwrap();
    }
}
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
new file mode 100644
index 0000000..f92db45
--- /dev/null
+++ b/src/vhost_user/connection.rs
@@ -0,0 +1,858 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Structs for Unix Domain Socket listener and endpoint.
+
+#![allow(dead_code)]
+
+use std::io::ErrorKind;
+use std::marker::PhantomData;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::{UnixListener, UnixStream};
+use std::path::{Path, PathBuf};
+use std::{mem, slice};
+
+use libc::{c_void, iovec};
+use sys_util::ScmSocket;
+
+use super::message::*;
+use super::{Error, Result};
+
+/// Unix domain socket listener for accepting incoming connections.
+pub struct Listener {
+ fd: UnixListener,
+ path: PathBuf,
+}
+
+impl Listener {
+ /// Create a unix domain socket listener.
+ ///
+ /// # Return:
+ /// * - the new Listener object on success.
+ /// * - SocketError: failed to create listener socket.
+ pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> {
+ if unlink {
+ let _ = std::fs::remove_file(&path);
+ }
+ let fd = UnixListener::bind(&path).map_err(Error::SocketError)?;
+ Ok(Listener {
+ fd,
+ path: path.as_ref().to_owned(),
+ })
+ }
+
+ /// Accept an incoming connection.
+ ///
+ /// # Return:
+ /// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
+ /// * - None: no incoming connection available.
+ /// * - SocketError: errors from accept().
+ pub fn accept(&self) -> Result<Option<UnixStream>> {
+ loop {
+ match self.fd.accept() {
+ Ok((socket, _addr)) => return Ok(Some(socket)),
+ Err(e) => {
+ match e.kind() {
+ // No incoming connection available.
+ ErrorKind::WouldBlock => return Ok(None),
+ // New connection closed by peer.
+ ErrorKind::ConnectionAborted => return Ok(None),
+ // Interrupted by signals, retry
+ ErrorKind::Interrupted => continue,
+ _ => return Err(Error::SocketError(e)),
+ }
+ }
+ }
+ }
+ }
+
+ /// Change blocking status on the listener.
+ ///
+ /// # Return:
+ /// * - () on success.
+ /// * - SocketError: failure from set_nonblocking().
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> {
+ self.fd.set_nonblocking(block).map_err(Error::SocketError)
+ }
+}
+
+impl AsRawFd for Listener {
+ fn as_raw_fd(&self) -> RawFd {
+ self.fd.as_raw_fd()
+ }
+}
+
+impl Drop for Listener {
+ fn drop(&mut self) {
+ let _ = std::fs::remove_file(&self.path);
+ }
+}
+
+/// Unix domain socket endpoint for vhost-user connection.
+pub(super) struct Endpoint<R: Req> {
+ sock: UnixStream,
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> Endpoint<R> {
+    /// Create a new stream by connecting to the server at `path`.
+ ///
+ /// # Return:
+ /// * - the new Endpoint object on success.
+ /// * - SocketConnect: failed to connect to peer.
+ pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
+ let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
+ Ok(Self::from_stream(sock))
+ }
+
+ /// Create an endpoint from a stream object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Endpoint {
+ sock,
+ _r: PhantomData,
+ }
+ }
+
+ /// Sends bytes from scatter-gather vectors over the socket with optional attached file
+ /// descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let rfds = match fds {
+ Some(rfds) => rfds,
+ _ => &[],
+ };
+ self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into)
+ }
+
+ /// Sends all bytes from scatter-gather vectors over the socket with optional attached file
+    /// descriptors. Will loop until all data has been transferred.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
+ let mut data_sent = 0;
+ let mut data_total = 0;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_sent) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent);
+ let iov = &iovs[nr_skip][offset..];
+
+ let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat();
+ let sfds = if data_sent == 0 { fds } else { None };
+
+ let sent = self.send_iovec(data, sfds);
+ match sent {
+ Ok(0) => return Ok(data_sent),
+ Ok(n) => data_sent += n,
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok(data_sent)
+ }
+
+ /// Sends bytes from a slice over the socket with optional attached file descriptors.
+ ///
+ /// # Return:
+ /// * - number of bytes sent on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
+ self.send_iovec(&[data], fds)
+ }
+
+ /// Sends a header-only message with optional attached file descriptors.
+ ///
+ /// # Return:
+    /// * - Ok(()) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ pub fn send_header(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+        // Safe because there can't be other mutable reference to hdr.
+ let iovs = unsafe {
+ [slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ )]
+ };
+ let bytes = self.send_iovec_all(&iovs[..], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header and body. Optional file descriptors may be attached to
+ /// the message.
+ ///
+ /// # Return:
+    /// * - Ok(()) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+ pub fn send_message<T: Sized>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+        // Safe because there can't be other mutable reference to hdr and body.
+ let iovs = unsafe {
+ [
+ slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ ),
+ slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
+ ]
+ };
+ let bytes = self.send_iovec_all(&iovs[..], fds)?;
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Send a message with header, body and payload. Optional file descriptors
+ /// may also be attached to the message.
+ ///
+ /// # Return:
+    /// * - Ok(()) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - OversizedMsg: message size is too big.
+ /// * - PartialMessage: received a partial message.
+ /// * - IncorrectFds: wrong number of attached fds.
+ pub fn send_message_with_payload<T: Sized>(
+ &mut self,
+ hdr: &VhostUserMsgHeader<R>,
+ body: &T,
+ payload: &[u8],
+ fds: Option<&[RawFd]>,
+ ) -> Result<()> {
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
+ if len > MAX_MSG_SIZE - mem::size_of::<T>() {
+ return Err(Error::OversizedMsg);
+ }
+ if let Some(fd_arr) = fds {
+ if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+ return Err(Error::IncorrectFds);
+ }
+ }
+
+ // Safe because there can't be other mutable reference to hdr, body and payload.
+ let iovs = unsafe {
+ [
+ slice::from_raw_parts(
+ hdr as *const VhostUserMsgHeader<R> as *const u8,
+ mem::size_of::<VhostUserMsgHeader<R>>(),
+ ),
+ slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
+ slice::from_raw_parts(payload.as_ptr() as *const u8, len),
+ ]
+ };
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
+ let len = self.send_iovec_all(&iovs, fds)?;
+ if len != total {
+ return Err(Error::PartialMessage);
+ }
+ Ok(())
+ }
+
+ /// Reads bytes from the socket into the given scatter/gather vectors.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf) on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
+ let mut rbuf = vec![0u8; len];
+ let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?;
+ Ok((bytes, rbuf))
+ }
+
+ /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
+ /// file descriptors.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+    /// tricky to pass file descriptors through such a communication channel. Suppose a sender
+    /// sends a message with some file descriptors attached. To successfully receive those
+    /// attached file descriptors, the receiver must obey the following rules:
+ /// 1) file descriptors are attached to a message.
+ /// 2) message(packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received fds]) on success
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
+ let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
+ let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?;
+ let rfds = match fds {
+ 0 => None,
+ n => {
+ let mut fds = Vec::with_capacity(n);
+ fds.extend_from_slice(&fd_array[0..n]);
+ Some(fds)
+ }
+ };
+
+ Ok((bytes, rfds))
+ }
+
+ /// Reads all bytes from the socket into the given scatter/gather vectors with optional
+    /// attached file descriptors. Will loop until all data has been transferred.
+ ///
+ /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
+    /// tricky to pass file descriptors through such a communication channel. Suppose a sender
+    /// sends a message with some file descriptors attached. To successfully receive those
+    /// attached file descriptors, the receiver must obey the following rules:
+ /// 1) file descriptors are attached to a message.
+ /// 2) message(packet) boundaries must be respected on the receive side.
+ /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
+ /// attached file descriptors will get lost.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, [received fds]) on success
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_iovec_all(
+ &mut self,
+ iovs: &mut [iovec],
+ ) -> Result<(usize, Option<Vec<RawFd>>)> {
+ let mut data_read = 0;
+ let mut data_total = 0;
+ let mut rfds = None;
+ let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect();
+ for len in &iov_lens {
+ data_total += len;
+ }
+
+ while (data_total - data_read) > 0 {
+ let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read);
+ let iov = &mut iovs[nr_skip];
+
+ let mut data = [
+ &[iovec {
+ iov_base: (iov.iov_base as usize + offset) as *mut c_void,
+ iov_len: iov.iov_len - offset,
+ }],
+ &iovs[(nr_skip + 1)..],
+ ]
+ .concat();
+
+ let res = self.recv_into_iovec(&mut data);
+ match res {
+ Ok((0, _)) => return Ok((data_read, rfds)),
+ Ok((n, fds)) => {
+ if data_read == 0 {
+ rfds = fds;
+ }
+ data_read += n;
+ }
+ Err(e) => match e {
+ Error::SocketRetry(_) => {}
+ _ => return Err(e),
+ },
+ }
+ }
+ Ok((data_read, rfds))
+ }
+
+ /// Reads bytes from the socket into a new buffer with optional attached
+ /// file descriptors. Received file descriptors are set close-on-exec.
+ ///
+ /// # Return:
+ /// * - (number of bytes received, buf, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ pub fn recv_into_buf(
+ &mut self,
+ buf_size: usize,
+ ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
+ let mut buf = vec![0u8; buf_size];
+ let (bytes, rfds) = {
+ let mut iovs = [iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf_size,
+ }];
+ self.recv_into_iovec(&mut iovs)?
+ };
+ Ok((bytes, buf, rfds))
+ }
+
+ /// Receive a header-only message with optional attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+    /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+    /// * - InvalidMessage: received an invalid message.
+ pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ }];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, rfds))
+ }
+
+ /// Receive a message with optional attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+    /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, message body, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+ /// * - SocketBroken: the underline socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+    /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes != total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, rfds))
+ }
+
+ /// Receive a message with header and optional content. Callers need to
+ /// pre-allocate a big enough buffer to receive the message body and
+ /// optional payload. If there are attached file descriptor associated
+ /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
+    /// will be accepted and all other file descriptors will be discarded
+ /// silently.
+ ///
+ /// # Return:
+ /// * - (message header, message size, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+    /// * - InvalidMessage: received an invalid message.
+ pub fn recv_body_into_buf(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
+ }
+
+ /// Receive a message with optional payload and attached file descriptors.
+ /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
+    /// accepted and all other file descriptors will be discarded silently.
+ ///
+ /// # Return:
+ /// * - (message header, message body, size of payload, [received fds]) on success.
+ /// * - SocketRetry: temporary error caused by signals or short of resources.
+    /// * - SocketBroken: the underlying socket is broken.
+ /// * - SocketError: other socket related errors.
+ /// * - PartialMessage: received a partial message.
+    /// * - InvalidMessage: received an invalid message.
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
+ pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
+ &mut self,
+ buf: &mut [u8],
+ ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
+ let mut hdr = VhostUserMsgHeader::default();
+ let mut body: T = Default::default();
+ let mut iovs = [
+ iovec {
+ iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
+ iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
+ },
+ iovec {
+ iov_base: (&mut body as *mut T) as *mut c_void,
+ iov_len: mem::size_of::<T>(),
+ },
+ iovec {
+ iov_base: buf.as_mut_ptr() as *mut c_void,
+ iov_len: buf.len(),
+ },
+ ];
+ let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?;
+
+ let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
+ if bytes < total {
+ return Err(Error::PartialMessage);
+ } else if !hdr.is_valid() || !body.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+
+ Ok((hdr, body, bytes - total, rfds))
+ }
+
+ /// Close all raw file descriptors.
+ pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
+ if let Some(fds) = rfds {
+ for fd in fds {
+ // safe because the rawfds are valid and we don't care about the result.
+ let _ = unsafe { libc::close(fd) };
+ }
+ }
+ }
+}
+
+impl<T: Req> AsRawFd for Endpoint<T> {
+ fn as_raw_fd(&self) -> RawFd {
+ self.sock.as_raw_fd()
+ }
+}
+
+// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice.
+// For example:
+// let iov_lens = vec![4, 4, 5];
+// let size = 6;
+// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2));
+fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
+ let mut size = skip_size;
+ let mut nr_skip = 0;
+
+ for len in iov_lens {
+ if size >= *len {
+ size -= *len;
+ nr_skip += 1;
+ } else {
+ break;
+ }
+ }
+ (nr_skip, size)
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::fs::File;
+ use std::io::{Read, Seek, SeekFrom, Write};
+ use std::os::unix::io::FromRawFd;
+ use tempfile::{tempfile, Builder, TempDir};
+
+ fn temp_dir() -> TempDir {
+ Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+ }
+
+ #[test]
+ fn create_listener() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
+ }
+
+ #[test]
+ fn accept_connection() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+
+ // accept on a fd without incoming connection
+ let conn = listener.accept().unwrap();
+ assert!(conn.is_none());
+ }
+
+ #[test]
+ fn send_data() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let mut len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..bytes]);
+
+ len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ }
+
+ #[test]
+ fn send_fd() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut fd = tempfile().unwrap();
+ write!(fd, "test").unwrap();
+
+ // Normal case for sending/receiving file descriptors
+ let buf1 = vec![0x1, 0x2, 0x3, 0x4];
+ let len = master
+ .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 1);
+ let mut file = unsafe { File::from_raw_fd(fds[0]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+
+ // Following communication pattern should work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header) with fds, data(body)
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 3);
+ let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // Following communication pattern should not work:
+ // Sending side: data(header, body) with fds
+ // Receiving side: data(header), data(body) with fds
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf4) = slave.recv_data(2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf4[..]);
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // Following communication pattern should work:
+ // Sending side: data, data with fds
+ // Receiving side: data, data with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert_eq!(&buf1[..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[..2], &buf2[..]);
+ assert!(rfds.is_some());
+ let fds = rfds.unwrap();
+ {
+ assert_eq!(fds.len(), 3);
+ let mut file = unsafe { File::from_raw_fd(fds[1]) };
+ let mut content = String::new();
+ file.seek(SeekFrom::Start(0)).unwrap();
+ file.read_to_string(&mut content).unwrap();
+ assert_eq!(content, "test");
+ }
+ let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
+ assert_eq!(bytes, 2);
+ assert_eq!(&buf1[2..], &buf2[..]);
+ assert!(rfds.is_none());
+
+ // Following communication pattern should not work:
+ // Sending side: data1, data2 with fds
+ // Receiving side: data + partial of data2, left of data2 with fds
+ let len = master.send_slice(&buf1[..], None).unwrap();
+ assert_eq!(len, 4);
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _) = slave.recv_data(5).unwrap();
+ assert_eq!(bytes, 5);
+
+ let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 3);
+ assert!(rfds.is_none());
+
+ // If the target fd array is too small, extra file descriptors will get lost.
+ let len = master
+ .send_slice(
+ &buf1[..],
+ Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
+ )
+ .unwrap();
+ assert_eq!(len, 4);
+
+ let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
+ assert_eq!(bytes, 4);
+ assert!(rfds.is_some());
+
+ Endpoint::<MasterReq>::close_rfds(rfds);
+ Endpoint::<MasterReq>::close_rfds(None);
+ }
+
+ #[test]
+ fn send_recv() {
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let listener = Listener::new(&path, true).unwrap();
+ listener.set_nonblocking(true).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
+ let sock = listener.accept().unwrap().unwrap();
+ let mut slave = Endpoint::<MasterReq>::from_stream(sock);
+
+ let mut hdr1 =
+ VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
+ hdr1.set_need_reply(true);
+ let features1 = 0x1u64;
+ master.send_message(&hdr1, &features1, None).unwrap();
+
+ let mut features2 = 0u64;
+ let slice = unsafe {
+ slice::from_raw_parts_mut(
+ (&mut features2 as *mut u64) as *mut u8,
+ mem::size_of::<u64>(),
+ )
+ };
+ let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert_eq!(bytes, 8);
+ assert_eq!(features1, features2);
+ assert!(rfds.is_none());
+
+ master.send_header(&hdr1, None).unwrap();
+ let (hdr2, rfds) = slave.recv_header().unwrap();
+ assert_eq!(hdr1, hdr2);
+ assert!(rfds.is_none());
+ }
+}
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
new file mode 100644
index 0000000..b2b83d2
--- /dev/null
+++ b/src/vhost_user/dummy_slave.rs
@@ -0,0 +1,259 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::os::unix::io::RawFd;
+
+use super::message::*;
+use super::*;
+
+pub const MAX_QUEUE_NUM: usize = 2;
+pub const MAX_VRING_NUM: usize = 256;
+pub const MAX_MEM_SLOTS: usize = 32;
+pub const VIRTIO_FEATURES: u64 = 0x40000003;
+
+#[derive(Default)]
+pub struct DummySlaveReqHandler {
+ pub owned: bool,
+ pub features_acked: bool,
+ pub acked_features: u64,
+ pub acked_protocol_features: u64,
+ pub queue_num: usize,
+ pub vring_num: [u32; MAX_QUEUE_NUM],
+ pub vring_base: [u32; MAX_QUEUE_NUM],
+ pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM],
+ pub vring_started: [bool; MAX_QUEUE_NUM],
+ pub vring_enabled: [bool; MAX_QUEUE_NUM],
+}
+
+impl DummySlaveReqHandler {
+ pub fn new() -> Self {
+ DummySlaveReqHandler {
+ queue_num: MAX_QUEUE_NUM,
+ ..Default::default()
+ }
+ }
+}
+
+impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
+ fn set_owner(&mut self) -> Result<()> {
+ if self.owned {
+ return Err(Error::InvalidOperation);
+ }
+ self.owned = true;
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> Result<()> {
+ self.owned = false;
+ self.features_acked = false;
+ self.acked_features = 0;
+ self.acked_protocol_features = 0;
+ Ok(())
+ }
+
+ fn get_features(&mut self) -> Result<u64> {
+ Ok(VIRTIO_FEATURES)
+ }
+
+ fn set_features(&mut self, features: u64) -> Result<()> {
+ if !self.owned || self.features_acked {
+ return Err(Error::InvalidOperation);
+ } else if (features & !VIRTIO_FEATURES) != 0 {
+ return Err(Error::InvalidParam);
+ }
+
+ self.acked_features = features;
+ self.features_acked = true;
+
+ // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated,
+ // the ring is initialized in an enabled state.
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated,
+ // the ring is initialized in a disabled state. Client must not
+ // pass data to/from the backend until ring is enabled by
+ // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has
+ // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+ let vring_enabled =
+ self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
+ for enabled in &mut self.vring_enabled {
+ *enabled = vring_enabled;
+ }
+
+ Ok(())
+ }
+
+ fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
+ if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_num[index as usize] = num;
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ index: u32,
+ _flags: VhostUserVringAddrFlags,
+ _descriptor: u64,
+ _used: u64,
+ _available: u64,
+ _log: u64,
+ ) -> Result<()> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ Ok(())
+ }
+
+ fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
+ if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM {
+ return Err(Error::InvalidParam);
+ }
+ self.vring_base[index as usize] = base;
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
+ if index as usize >= self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ self.vring_started[index as usize] = false;
+ Ok(VhostUserVringState::new(
+ index,
+ self.vring_base[index as usize],
+ ))
+ }
+
+ fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.kick_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) };
+ }
+ self.kick_fd[index as usize] = fd;
+
+ // Quotation from vhost-user spec:
+ // Client must start ring upon receiving a kick (that is, detecting
+ // that file descriptor is readable) on the descriptor specified by
+ // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
+ // VHOST_USER_GET_VRING_BASE.
+ //
+ // So we should add fd to event monitor(select, poll, epoll) here.
+ self.vring_started[index as usize] = true;
+ Ok(())
+ }
+
+ fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.call_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) };
+ }
+ self.call_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+ if self.err_fd[index as usize].is_some() {
+ // Close file descriptor set by previous operations.
+ let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) };
+ }
+ self.err_fd[index as usize] = fd;
+ Ok(())
+ }
+
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ Ok(VhostUserProtocolFeatures::all())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> Result<()> {
+ // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ // What happens if the master calls set_features() with
+ // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
+ // interface?
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ Ok(MAX_QUEUE_NUM as u64)
+ }
+
+ fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
+ // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
+ // has been negotiated.
+ if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if index as usize >= self.queue_num || index as usize > self.queue_num {
+ return Err(Error::InvalidParam);
+ }
+
+ // Slave must not pass data to/from the backend until ring is
+ // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
+ // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
+ // with parameter 0.
+ self.vring_enabled[index as usize] = enable;
+ Ok(())
+ }
+
+ fn get_config(
+ &mut self,
+ offset: u32,
+ size: u32,
+ _flags: VhostUserConfigFlags,
+ ) -> Result<Vec<u8>> {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ Ok(vec![0xa5; size as usize])
+ }
+
+ fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
+ let size = buf.len() as u32;
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ return Err(Error::InvalidOperation);
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
+ || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
+ || size + offset > VHOST_USER_CONFIG_SIZE
+ {
+ return Err(Error::InvalidParam);
+ }
+ Ok(())
+ }
+
+ fn get_max_mem_slots(&mut self) -> Result<u64> {
+ Ok(MAX_MEM_SLOTS as u64)
+ }
+
+ fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> {
+ Ok(())
+ }
+
+ fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> {
+ Ok(())
+ }
+}
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
new file mode 100644
index 0000000..cc79871
--- /dev/null
+++ b/src/vhost_user/master.rs
@@ -0,0 +1,1071 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Struct for vhost-user master.
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::path::Path;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use sys_util::EventFd;
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error as VhostUserError, Result as VhostUserResult};
+use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+use crate::{Error, Result};
+
+/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet.
+pub trait VhostUserMaster: VhostBackend {
+    /// Get the protocol feature bitmask from the underlying vhost implementation.
+    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
+
+    /// Enable protocol features in the underlying vhost implementation.
+    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>;
+
+    /// Query how many queues the backend supports.
+    fn get_queue_num(&mut self) -> Result<u64>;
+
+    /// Signal slave to enable or disable corresponding vring.
+    ///
+    /// Slave must not pass data to/from the backend until ring is enabled by
+    /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been
+    /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
+    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>;
+
+    /// Fetch the contents of the virtio device configuration space.
+    ///
+    /// `buf` is sent to the slave as the request payload; the reply carries the
+    /// configuration space bytes back as `VhostUserConfigPayload`.
+    fn get_config(
+        &mut self,
+        offset: u32,
+        size: u32,
+        flags: VhostUserConfigFlags,
+        buf: &[u8],
+    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>;
+
+    /// Change the virtio device configuration space. It also can be used for live migration on the
+    /// destination host to set readonly configuration space fields.
+    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
+
+    /// Setup slave communication channel.
+    ///
+    /// `fd` is the socket over which the slave will send requests back to the master.
+    fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>;
+
+    /// Query the maximum amount of memory slots supported by the backend.
+    fn get_max_mem_slots(&mut self) -> Result<u64>;
+
+    /// Add a new guest memory mapping for vhost to use.
+    fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+
+    /// Remove a guest memory mapping from vhost.
+    fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
+}
+
+fn error_code<T>(err: VhostUserError) -> Result<T> {
+ Err(Error::VhostUserProtocol(err))
+}
+
+/// Struct for the vhost-user master endpoint.
+///
+/// Cloning is cheap: all clones share the same internal state behind an
+/// `Arc<Mutex<..>>`, so a `Master` can be handed to multiple threads.
+#[derive(Clone)]
+pub struct Master {
+    // Shared, lock-protected endpoint state (socket + cached feature bits).
+    node: Arc<Mutex<MasterInternal>>,
+}
+
+impl Master {
+    /// Create a new instance.
+    fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self {
+        Master {
+            node: Arc::new(Mutex::new(MasterInternal {
+                main_sock: ep,
+                // Feature/protocol state starts empty; it is populated during
+                // the feature negotiation handshake.
+                virtio_features: 0,
+                acked_virtio_features: 0,
+                protocol_features: 0,
+                acked_protocol_features: 0,
+                protocol_features_ready: false,
+                max_queue_num,
+                error: None,
+            })),
+        }
+    }
+
+    // Acquire the internal state lock. unwrap() asserts that no thread ever
+    // panics while holding the lock (a poisoned mutex is a bug here).
+    fn node(&self) -> MutexGuard<MasterInternal> {
+        self.node.lock().unwrap()
+    }
+
+    /// Create a new instance from a Unix stream socket.
+    pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
+        Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
+    }
+
+    /// Create a new vhost-user master endpoint.
+    ///
+    /// Will retry as the backend may not be ready to accept the connection.
+    ///
+    /// # Arguments
+    /// * `path` - path of Unix domain socket listener to connect to
+    pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> {
+        let mut retry_count = 5;
+        let endpoint = loop {
+            match Endpoint::<MasterReq>::connect(&path) {
+                Ok(endpoint) => break Ok(endpoint),
+                Err(e) => match &e {
+                    VhostUserError::SocketConnect(why) => {
+                        // Only ConnectionRefused is treated as transient (the
+                        // backend listener may not be up yet); back off 100ms
+                        // and retry up to 5 times.
+                        if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 {
+                            std::thread::sleep(std::time::Duration::from_millis(100));
+                            retry_count -= 1;
+                            continue;
+                        } else {
+                            break Err(e);
+                        }
+                    }
+                    // Any other error is fatal immediately.
+                    _ => break Err(e),
+                },
+            }
+        }?;
+
+        Ok(Self::new(endpoint, max_queue_num))
+    }
+}
+
+impl VhostBackend for Master {
+    /// Get from the underlying vhost implementation the feature bitmask.
+    ///
+    /// The reply is cached in `virtio_features` for later negotiation checks.
+    fn get_features(&self) -> Result<u64> {
+        let mut node = self.node();
+        let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
+        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+        node.virtio_features = val.value;
+        Ok(node.virtio_features)
+    }
+
+    /// Enable features in the underlying vhost implementation using a bitmask.
+    fn set_features(&self, features: u64) -> Result<()> {
+        let mut node = self.node();
+        let val = VhostUserU64::new(features);
+        let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
+        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+        // completed yet.
+        // Only bits the slave advertised can actually be acked.
+        node.acked_virtio_features = features & node.virtio_features;
+        Ok(())
+    }
+
+    /// Set the current Master as an owner of the session.
+    fn set_owner(&self) -> Result<()> {
+        // We unwrap() the return value to assert that we are not expecting threads to ever fail
+        // while holding the lock.
+        let mut node = self.node();
+        let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
+        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+        // completed yet.
+        Ok(())
+    }
+
+    /// Reset the session ownership (VHOST_USER_RESET_OWNER).
+    fn reset_owner(&self) -> Result<()> {
+        let mut node = self.node();
+        let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
+        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+        // completed yet.
+        Ok(())
+    }
+
+    /// Set the memory map regions on the slave so it can translate the vring
+    /// addresses. In the ancillary data there is an array of file descriptors
+    fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+        // The message carries one fd per region, so the region count is capped
+        // by the maximum number of attachable fds.
+        if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let mut ctx = VhostUserMemoryContext::new();
+        for region in regions.iter() {
+            if region.memory_size == 0 || region.mmap_handle < 0 {
+                return error_code(VhostUserError::InvalidParam);
+            }
+            let reg = VhostUserMemoryRegion {
+                guest_phys_addr: region.guest_phys_addr,
+                memory_size: region.memory_size,
+                user_addr: region.userspace_addr,
+                mmap_offset: region.mmap_offset,
+            };
+            ctx.append(&reg, region.mmap_handle);
+        }
+
+        let mut node = self.node();
+        let body = VhostUserMemory::new(ctx.regions.len() as u32);
+        // Reinterpret the region array as raw bytes for the wire payload.
+        // NOTE(review): assumes VhostUserMemoryRegion is a plain #[repr(C)]
+        // struct with no padding so align_to::<u8>() covers every region
+        // byte-exactly — verify against its definition in message.rs.
+        let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
+        let hdr = node.send_request_with_payload(
+            MasterReq::SET_MEM_TABLE,
+            &body,
+            payload,
+            Some(ctx.fds.as_slice()),
+        )?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    // Clippy doesn't seem to know that if let with && is still experimental
+    #[allow(clippy::unnecessary_unwrap)]
+    fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+        let mut node = self.node();
+        let val = VhostUserU64::new(base);
+
+        // The log fd is only attached when the LOG_SHMFD protocol feature has
+        // been negotiated; otherwise only the base address is sent.
+        if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
+            && fd.is_some()
+        {
+            let fds = [fd.unwrap()];
+            let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?;
+        } else {
+            let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?;
+        }
+        Ok(())
+    }
+
+    /// Send a file descriptor for the dirty-page log (VHOST_USER_SET_LOG_FD).
+    fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+        let mut node = self.node();
+        let fds = [fd];
+        node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
+        Ok(())
+    }
+
+    /// Set the size of the queue.
+    fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let val = VhostUserVringState::new(queue_index as u32, num.into());
+        let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Sets the addresses of the different aspects of the vring.
+    fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+        let mut node = self.node();
+        // Reject out-of-range queues and any flag bits outside the defined set.
+        if queue_index as u64 >= node.max_queue_num
+            || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
+        {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
+        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Sets the base offset in the available vring.
+    fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let val = VhostUserVringState::new(queue_index as u32, base.into());
+        let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Query the current base offset of the vring (VHOST_USER_GET_VRING_BASE).
+    fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let req = VhostUserVringState::new(queue_index as u32, 0);
+        let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?;
+        let reply = node.recv_reply::<VhostUserVringState>(&hdr)?;
+        Ok(reply.num)
+    }
+
+    /// Set the event file descriptor to signal when buffers are used.
+    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+    /// is set when there is no file descriptor in the ancillary data. This signals that polling
+    /// will be used instead of waiting for the call.
+    fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
+        Ok(())
+    }
+
+    /// Set the event file descriptor for adding buffers to the vring.
+    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+    /// is set when there is no file descriptor in the ancillary data. This signals that polling
+    /// should be used instead of waiting for a kick.
+    fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
+        Ok(())
+    }
+
+    /// Set the event file descriptor to signal when error occurs.
+    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
+    /// is set when there is no file descriptor in the ancillary data.
+    fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+        let mut node = self.node();
+        if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
+        Ok(())
+    }
+}
+
+impl VhostUserMaster for Master {
+    /// Query and cache the slave's protocol feature bitmask.
+    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+        let mut node = self.node();
+        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+        // Protocol features are only meaningful once PROTOCOL_FEATURES has been
+        // both offered by the slave and acked by us.
+        if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+        let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
+        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+        node.protocol_features = val.value;
+        // Should we support forward compatibility?
+        // If so just mask out unrecognized flags instead of return errors.
+        match VhostUserProtocolFeatures::from_bits(node.protocol_features) {
+            Some(val) => Ok(val),
+            None => error_code(VhostUserError::InvalidMessage),
+        }
+    }
+
+    /// Ack a set of protocol features with the slave and cache them locally.
+    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
+        let mut node = self.node();
+        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+        if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+        let val = VhostUserU64::new(features.bits());
+        let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
+        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
+        // completed yet.
+        node.acked_protocol_features = features.bits();
+        node.protocol_features_ready = true;
+        Ok(())
+    }
+
+    /// Query the number of queues from the slave; requires the MQ protocol feature.
+    fn get_queue_num(&mut self) -> Result<u64> {
+        let mut node = self.node();
+        if !node.is_feature_mq_available() {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
+        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+        // Sanity-check the slave's answer against the protocol-wide limit.
+        if val.value > VHOST_USER_MAX_VRINGS {
+            return error_code(VhostUserError::InvalidMessage);
+        }
+        node.max_queue_num = val.value;
+        Ok(node.max_queue_num)
+    }
+
+    /// Enable or disable a vring (VHOST_USER_SET_VRING_ENABLE).
+    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
+        let mut node = self.node();
+        // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
+        if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        } else if queue_index as u64 >= node.max_queue_num {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        // The wire format encodes enable/disable as 1/0 in the vring state num.
+        let flag = if enable { 1 } else { 0 };
+        let val = VhostUserVringState::new(queue_index as u32, flag);
+        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Read the device configuration space; requires the CONFIG protocol feature.
+    fn get_config(
+        &mut self,
+        offset: u32,
+        size: u32,
+        flags: VhostUserConfigFlags,
+        buf: &[u8],
+    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
+        let body = VhostUserConfig::new(offset, size, flags);
+        if !body.is_valid() {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let mut node = self.node();
+        // depends on VhostUserProtocolFeatures::CONFIG
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        // vhost-user spec states that:
+        // "Master payload: virtio device config space"
+        // "Slave payload: virtio device config space"
+        let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?;
+        let (body_reply, buf_reply, rfds) =
+            node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
+        // A config reply must carry no fds, a non-empty payload, and echo back
+        // the requested offset/size; anything else is a protocol violation.
+        if rfds.is_some() {
+            Endpoint::<MasterReq>::close_rfds(rfds);
+            return error_code(VhostUserError::InvalidMessage);
+        } else if body_reply.size == 0 {
+            return error_code(VhostUserError::SlaveInternalError);
+        } else if body_reply.size != body.size
+            || body_reply.size as usize != buf.len()
+            || body_reply.offset != body.offset
+        {
+            return error_code(VhostUserError::InvalidMessage);
+        }
+
+        Ok((body_reply, buf_reply))
+    }
+
+    /// Write the device configuration space; requires the CONFIG protocol feature.
+    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
+        // The payload must fit inside a single vhost-user message.
+        if buf.len() > MAX_MSG_SIZE {
+            return error_code(VhostUserError::InvalidParam);
+        }
+        let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
+        if !body.is_valid() {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let mut node = self.node();
+        // depends on VhostUserProtocolFeatures::CONFIG
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Hand the slave a socket for slave-initiated requests; requires SLAVE_REQ.
+    fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
+        let mut node = self.node();
+        if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        let fds = [fd];
+        node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
+        Ok(())
+    }
+
+    /// Query the backend's memory-slot limit; requires CONFIGURE_MEM_SLOTS.
+    fn get_max_mem_slots(&mut self) -> Result<u64> {
+        let mut node = self.node();
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+        {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+
+        let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
+        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
+
+        Ok(val.value)
+    }
+
+    /// Register one additional memory region; requires CONFIGURE_MEM_SLOTS.
+    fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+        let mut node = self.node();
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+        {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+        // An empty region or invalid backing fd cannot be mapped by the slave.
+        if region.memory_size == 0 || region.mmap_handle < 0 {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let body = VhostUserSingleMemoryRegion::new(
+            region.guest_phys_addr,
+            region.memory_size,
+            region.userspace_addr,
+            region.mmap_offset,
+        );
+        let fds = [region.mmap_handle];
+        let hdr = node.send_request_with_body(MasterReq::ADD_MEM_REG, &body, Some(&fds))?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+
+    /// Unregister a memory region; requires CONFIGURE_MEM_SLOTS.
+    fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
+        let mut node = self.node();
+        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
+        {
+            return error_code(VhostUserError::InvalidOperation);
+        }
+        if region.memory_size == 0 {
+            return error_code(VhostUserError::InvalidParam);
+        }
+
+        let body = VhostUserSingleMemoryRegion::new(
+            region.guest_phys_addr,
+            region.memory_size,
+            region.userspace_addr,
+            region.mmap_offset,
+        );
+        let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?;
+        node.wait_for_ack(&hdr).map_err(|e| e.into())
+    }
+}
+
+impl AsRawFd for Master {
+    /// Expose the raw fd of the master's main socket (e.g. for polling).
+    fn as_raw_fd(&self) -> RawFd {
+        self.node().main_sock.as_raw_fd()
+    }
+}
+
+/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
+struct VhostUserMemoryContext {
+    // Wire-format region descriptors, one per guest memory region.
+    regions: VhostUserMemoryPayload,
+    // File descriptors backing each region, kept in the same order as
+    // `regions` so they pair up in the ancillary data.
+    fds: Vec<RawFd>,
+}
+
+impl VhostUserMemoryContext {
+    /// Create an empty context object.
+    pub fn new() -> Self {
+        Self {
+            regions: VhostUserMemoryPayload::new(),
+            fds: Vec::new(),
+        }
+    }
+
+    /// Record a memory region together with the RawFd that backs it, keeping
+    /// the two vectors index-aligned.
+    pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
+        self.regions.push(*region);
+        self.fds.push(fd);
+    }
+}
+
+struct MasterInternal {
+    // Used to send requests to the slave.
+    main_sock: Endpoint<MasterReq>,
+    // Cached virtio features from the slave.
+    virtio_features: u64,
+    // Cached acked virtio features from the driver.
+    acked_virtio_features: u64,
+    // Cached vhost-user protocol features from the slave.
+    protocol_features: u64,
+    // Cached vhost-user protocol features acked by the driver.
+    acked_protocol_features: u64,
+    // Cached vhost-user protocol features are ready to use.
+    protocol_features_ready: bool,
+    // Cached maximum number of queues supported from the slave.
+    max_queue_num: u64,
+    // Internal flag to mark failure state (raw OS error code).
+    error: Option<i32>,
+}
+
+impl MasterInternal {
+    /// Send a header-only request, optionally attaching file descriptors.
+    fn send_request_header(
+        &mut self,
+        code: MasterReq,
+        fds: Option<&[RawFd]>,
+    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+        self.check_state()?;
+        let hdr = Self::new_request_header(code, 0);
+        self.main_sock.send_header(&hdr, fds)?;
+        Ok(hdr)
+    }
+
+    /// Send a request with a fixed-size body, optionally attaching fds.
+    fn send_request_with_body<T: Sized>(
+        &mut self,
+        code: MasterReq,
+        msg: &T,
+        fds: Option<&[RawFd]>,
+    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+        if mem::size_of::<T>() > MAX_MSG_SIZE {
+            return Err(VhostUserError::InvalidParam);
+        }
+        self.check_state()?;
+
+        let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32);
+        self.main_sock.send_message(&hdr, msg, fds)?;
+        Ok(hdr)
+    }
+
+    /// Send a request with a fixed-size body plus a variable-length payload.
+    fn send_request_with_payload<T: Sized>(
+        &mut self,
+        code: MasterReq,
+        msg: &T,
+        payload: &[u8],
+        fds: Option<&[RawFd]>,
+    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+        // Body and payload together must fit inside one vhost-user message.
+        let len = mem::size_of::<T>() + payload.len();
+        if len > MAX_MSG_SIZE {
+            return Err(VhostUserError::InvalidParam);
+        }
+        if let Some(ref fd_arr) = fds {
+            if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
+                return Err(VhostUserError::InvalidParam);
+            }
+        }
+        self.check_state()?;
+
+        let hdr = Self::new_request_header(code, len as u32);
+        self.main_sock
+            .send_message_with_payload(&hdr, msg, payload, fds)?;
+        Ok(hdr)
+    }
+
+    /// Send a per-vring request carrying a single event fd in ancillary data.
+    fn send_fd_for_vring(
+        &mut self,
+        code: MasterReq,
+        queue_index: usize,
+        fd: RawFd,
+    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
+        if queue_index as u64 >= self.max_queue_num {
+            return Err(VhostUserError::InvalidParam);
+        }
+        self.check_state()?;
+
+        // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
+        // This flag is set when there is no file descriptor in the ancillary data. This signals
+        // that polling will be used instead of waiting for the call.
+        let msg = VhostUserU64::new(queue_index as u64);
+        let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
+        self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
+        Ok(hdr)
+    }
+
+    /// Receive and validate a fixed-size reply matching the given request header.
+    fn recv_reply<T: Sized + Default + VhostUserMsgValidator>(
+        &mut self,
+        hdr: &VhostUserMsgHeader<MasterReq>,
+    ) -> VhostUserResult<T> {
+        if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+            return Err(VhostUserError::InvalidParam);
+        }
+        self.check_state()?;
+
+        let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
+        // Fix: pass `hdr` directly — it is already a reference, so `&hdr`
+        // was a needless double borrow (clippy::needless_borrow); matches
+        // the call in recv_reply_with_payload().
+        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
+            Endpoint::<MasterReq>::close_rfds(rfds);
+            return Err(VhostUserError::InvalidMessage);
+        }
+        Ok(body)
+    }
+
+    /// Receive a reply consisting of a fixed-size body plus a trailing payload.
+    fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
+        &mut self,
+        hdr: &VhostUserMsgHeader<MasterReq>,
+    ) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
+        if mem::size_of::<T>() > MAX_MSG_SIZE
+            || hdr.get_size() as usize <= mem::size_of::<T>()
+            || hdr.get_size() as usize > MAX_MSG_SIZE
+            || hdr.is_reply()
+        {
+            return Err(VhostUserError::InvalidParam);
+        }
+        self.check_state()?;
+
+        // The payload buffer is sized from the request header; the reply must
+        // fill it exactly.
+        let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
+        let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
+        if !reply.is_reply_for(hdr)
+            || reply.get_size() as usize != mem::size_of::<T>() + bytes
+            || rfds.is_some()
+            || !body.is_valid()
+        {
+            Endpoint::<MasterReq>::close_rfds(rfds);
+            return Err(VhostUserError::InvalidMessage);
+        } else if bytes != buf.len() {
+            return Err(VhostUserError::InvalidMessage);
+        }
+        Ok((body, buf, rfds))
+    }
+
+    /// Wait for a REPLY_ACK if the feature is negotiated and the request asked
+    /// for one; otherwise return immediately.
+    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
+        if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
+            || !hdr.is_need_reply()
+        {
+            return Ok(());
+        }
+        self.check_state()?;
+
+        let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
+        // Fix: `hdr` instead of `&hdr` (needless double borrow), consistent
+        // with recv_reply_with_payload().
+        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
+            Endpoint::<MasterReq>::close_rfds(rfds);
+            return Err(VhostUserError::InvalidMessage);
+        }
+        // A non-zero ack value signals a slave-side failure.
+        if body.value != 0 {
+            return Err(VhostUserError::SlaveInternalError);
+        }
+        Ok(())
+    }
+
+    /// Whether the MQ protocol feature has been acked.
+    fn is_feature_mq_available(&self) -> bool {
+        self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0
+    }
+
+    /// Fail fast if the endpoint has been marked broken.
+    fn check_state(&self) -> VhostUserResult<()> {
+        match self.error {
+            Some(e) => Err(VhostUserError::SocketBroken(
+                std::io::Error::from_raw_os_error(e),
+            )),
+            None => Ok(()),
+        }
+    }
+
+    #[inline]
+    fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
+        // TODO: handle NEED_REPLY flag
+        VhostUserMsgHeader::new(request, 0x1, size)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::connection::Listener;
+    use super::*;
+    use tempfile::{Builder, TempDir};
+
+    // Create a unique temporary directory to host a test socket.
+    fn temp_dir() -> TempDir {
+        Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+    }
+
+    // Create a connected (master, raw slave endpoint) pair over `path`.
+    fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) {
+        let listener = Listener::new(&path, true).unwrap();
+        listener.set_nonblocking(true).unwrap();
+        let master = Master::connect(path, 2).unwrap();
+        let slave = listener.accept().unwrap().unwrap();
+        (master, Endpoint::from_stream(slave))
+    }
+
+    #[test]
+    fn create_master() {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let listener = Listener::new(&path, true).unwrap();
+        listener.set_nonblocking(true).unwrap();
+
+        let master = Master::connect(&path, 1).unwrap();
+        let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
+
+        assert!(master.as_raw_fd() > 0);
+        // Send two messages continuously
+        master.set_owner().unwrap();
+        master.reset_owner().unwrap();
+
+        let (hdr, rfds) = slave.recv_header().unwrap();
+        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+        assert_eq!(hdr.get_size(), 0);
+        assert_eq!(hdr.get_version(), 0x1);
+        assert!(rfds.is_none());
+
+        let (hdr, rfds) = slave.recv_header().unwrap();
+        assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
+        assert_eq!(hdr.get_size(), 0);
+        assert_eq!(hdr.get_version(), 0x1);
+        assert!(rfds.is_none());
+    }
+
+    #[test]
+    fn test_create_failure() {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let _ = Listener::new(&path, true).unwrap();
+        // Fix: the result of is_err() was silently discarded with `let _ =`;
+        // assert it so a second non-unlinking listener actually failing is
+        // verified rather than ignored.
+        assert!(Listener::new(&path, false).is_err());
+        assert!(Master::connect(&path, 1).is_err());
+
+        let listener = Listener::new(&path, true).unwrap();
+        assert!(Listener::new(&path, false).is_err());
+        listener.set_nonblocking(true).unwrap();
+
+        let _master = Master::connect(&path, 1).unwrap();
+        let _slave = listener.accept().unwrap().unwrap();
+    }
+
+    #[test]
+    fn test_features() {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let (master, mut peer) = create_pair(&path);
+
+        master.set_owner().unwrap();
+        let (hdr, rfds) = peer.recv_header().unwrap();
+        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+        assert_eq!(hdr.get_size(), 0);
+        assert_eq!(hdr.get_version(), 0x1);
+        assert!(rfds.is_none());
+
+        // Queue up the reply before the master issues GET_FEATURES.
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+        let msg = VhostUserU64::new(0x15);
+        peer.send_message(&hdr, &msg, None).unwrap();
+        let features = master.get_features().unwrap();
+        assert_eq!(features, 0x15u64);
+        let (_hdr, rfds) = peer.recv_header().unwrap();
+        assert!(rfds.is_none());
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
+        let msg = VhostUserU64::new(0x15);
+        peer.send_message(&hdr, &msg, None).unwrap();
+        master.set_features(0x15).unwrap();
+        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+        assert!(rfds.is_none());
+        let val = msg.value;
+        assert_eq!(val, 0x15);
+
+        // A reply body of the wrong size must be rejected.
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+        let msg = 0x15u32;
+        peer.send_message(&hdr, &msg, None).unwrap();
+        assert!(master.get_features().is_err());
+    }
+
+    #[test]
+    fn test_protocol_features() {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let (mut master, mut peer) = create_pair(&path);
+
+        master.set_owner().unwrap();
+        let (hdr, rfds) = peer.recv_header().unwrap();
+        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
+        assert!(rfds.is_none());
+
+        // Protocol feature calls are invalid before PROTOCOL_FEATURES is acked.
+        assert!(master.get_protocol_features().is_err());
+        assert!(master
+            .set_protocol_features(VhostUserProtocolFeatures::all())
+            .is_err());
+
+        let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
+        let msg = VhostUserU64::new(vfeatures);
+        peer.send_message(&hdr, &msg, None).unwrap();
+        let features = master.get_features().unwrap();
+        assert_eq!(features, vfeatures);
+        let (_hdr, rfds) = peer.recv_header().unwrap();
+        assert!(rfds.is_none());
+
+        master.set_features(vfeatures).unwrap();
+        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+        assert!(rfds.is_none());
+        let val = msg.value;
+        assert_eq!(val, vfeatures);
+
+        let pfeatures = VhostUserProtocolFeatures::all();
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
+        let msg = VhostUserU64::new(pfeatures.bits());
+        peer.send_message(&hdr, &msg, None).unwrap();
+        let features = master.get_protocol_features().unwrap();
+        assert_eq!(features, pfeatures);
+        let (_hdr, rfds) = peer.recv_header().unwrap();
+        assert!(rfds.is_none());
+
+        master.set_protocol_features(pfeatures).unwrap();
+        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
+        assert!(rfds.is_none());
+        let val = msg.value;
+        assert_eq!(val, pfeatures.bits());
+
+        // A reply carrying the wrong request code must be rejected.
+        let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
+        let msg = VhostUserU64::new(pfeatures.bits());
+        peer.send_message(&hdr, &msg, None).unwrap();
+        assert!(master.get_protocol_features().is_err());
+    }
+
+    #[test]
+    fn test_master_set_config_negative() {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let (mut master, _peer) = create_pair(&path);
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        master
+            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .unwrap_err();
+
+        // Force all features on so the remaining checks exercise parameter
+        // validation rather than feature gating.
+        {
+            let mut node = master.node();
+            node.virtio_features = 0xffff_ffff;
+            node.acked_virtio_features = 0xffff_ffff;
+            node.protocol_features = 0xffff_ffff;
+            node.acked_protocol_features = 0xffff_ffff;
+        }
+
+        master
+            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .unwrap();
+        master
+            .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .unwrap_err();
+        master
+            .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .unwrap_err();
+        master
+            .set_config(
+                0x100,
+                unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
+                &buf[0..4],
+            )
+            .unwrap_err();
+        master
+            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf)
+            .unwrap_err();
+        master
+            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[])
+            .unwrap_err();
+    }
+
+    // Like create_pair(), but with all virtio/protocol features pre-acked.
+    fn create_pair2() -> (Master, Endpoint<MasterReq>) {
+        let dir = temp_dir();
+        let mut path = dir.path().to_owned();
+        path.push("sock");
+        let (master, peer) = create_pair(&path);
+
+        {
+            let mut node = master.node();
+            node.virtio_features = 0xffff_ffff;
+            node.acked_virtio_features = 0xffff_ffff;
+            node.protocol_features = 0xffff_ffff;
+            node.acked_protocol_features = 0xffff_ffff;
+        }
+
+        (master, peer)
+    }
+
+    #[test]
+    fn test_master_get_config_negative0() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A reply with a mismatched request code must be rejected.
+        hdr.set_code(MasterReq::GET_FEATURES);
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+        // Fix: removed the trailing `hdr.set_code(MasterReq::GET_CONFIG);`
+        // dead store — `hdr` was never used again after it.
+    }
+
+    #[test]
+    fn test_master_get_config_negative1() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A message without the reply flag set must be rejected.
+        hdr.set_reply(false);
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+    }
+
+    #[test]
+    fn test_master_get_config_negative2() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+    }
+
+    #[test]
+    fn test_master_get_config_negative3() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A reply echoing back a different offset must be rejected.
+        msg.offset = 0;
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+    }
+
+    #[test]
+    fn test_master_get_config_negative4() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A reply with an off-by-one offset must be rejected.
+        msg.offset = 0x101;
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+    }
+
+    #[test]
+    fn test_master_get_config_negative5() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A reply with an out-of-range offset must be rejected.
+        msg.offset = (MAX_MSG_SIZE + 1) as u32;
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+    }
+
+    #[test]
+    fn test_master_get_config_negative6() {
+        let (mut master, mut peer) = create_pair2();
+        let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+        let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_ok());
+
+        // A reply whose size differs from the request's must be rejected.
+        msg.size = 6;
+        peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
+            .unwrap();
+        assert!(master
+            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+            .is_err());
+    }
+
+    // Fix: renamed from test_maset_set_mem_table_failure (typo: "maset").
+    #[test]
+    fn test_master_set_mem_table_failure() {
+        let (master, _peer) = create_pair2();
+
+        master.set_mem_table(&[]).unwrap_err();
+        let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
+        master.set_mem_table(&tables).unwrap_err();
+    }
+}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
new file mode 100644
index 0000000..8cba188
--- /dev/null
+++ b/src/vhost_user/master_req_handler.rs
@@ -0,0 +1,477 @@
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result};
+
+/// Define services provided by masters for the slave communication channel.
+///
+/// The vhost-user specification defines a slave communication channel, by which slaves could
+/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
+/// by masters, and it's used both on the master side and slave side.
+/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
+/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
+/// - on the master side, the [MasterReqHandler] will forward service requests to a handler
+/// implementing [VhostUserMasterReqHandler].
+///
+/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
+pub trait VhostUserMasterReqHandler {
+ /// Handle device configuration change notifications.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ ///
+ /// The default implementation closes the received fd (to avoid leaking it)
+ /// and rejects the request with `ENOSYS`.
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we own `fd`: it was just received over the socket for this
+ // request and is not shared with anyone else.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ ///
+ /// The default implementation closes the received fd (to avoid leaking it)
+ /// and rejects the request with `ENOSYS`.
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we own `fd`: it was just received over the socket for this
+ // request and is not shared with anyone else.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub trait VhostUserMasterReqHandlerMut {
+ /// Handle device configuration change notifications.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn handle_config_change(&mut self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ ///
+ /// The default implementation closes the received fd (to avoid leaking it)
+ /// and rejects the request with `ENOSYS`.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we own `fd`: it was just received over the socket for this
+ // request and is not shared with anyone else.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ ///
+ /// The default implementation rejects the request with `ENOSYS`.
+ fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ ///
+ /// The default implementation closes the received fd (to avoid leaking it)
+ /// and rejects the request with `ENOSYS`.
+ fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we own `fd`: it was just received over the socket for this
+ // request and is not shared with anyone else.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+// Adapt any [VhostUserMasterReqHandlerMut] into a [VhostUserMasterReqHandler]
+// by serializing every service call through a mutex.
+impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ let mut inner = self.lock().unwrap();
+ inner.handle_config_change()
+ }
+
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ let mut inner = self.lock().unwrap();
+ inner.fs_slave_map(fs, fd)
+ }
+
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ let mut inner = self.lock().unwrap();
+ inner.fs_slave_unmap(fs)
+ }
+
+ fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ let mut inner = self.lock().unwrap();
+ inner.fs_slave_sync(fs)
+ }
+
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ let mut inner = self.lock().unwrap();
+ inner.fs_slave_io(fs, fd)
+ }
+}
+
+/// Server to handle service requests from slaves from the slave communication channel.
+///
+/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
+/// slaves on the slave communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserMasterReqHandler] to do the real work.
+///
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
+ // Receive side of the slave communication channel; requests arrive here.
+ sub_sock: Endpoint<SlaveReq>,
+ // Send side of the channel; its fd is handed to the slave via
+ // set_slave_request_fd() so the slave can reach this handler.
+ tx_sock: UnixStream,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
+ // the VirtIO backend device object that performs the real work
+ backend: Arc<S>,
+ // whether the endpoint has encountered any failure (raw errno, None = healthy)
+ error: Option<i32>,
+}
+
+impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
+ /// Create a server to handle service requests from slaves on the slave communication channel.
+ ///
+ /// This opens a pair of connected anonymous sockets to form the slave communication channel.
+ /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
+ /// [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn new(backend: Arc<S>) -> Result<Self> {
+ let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
+
+ Ok(MasterReqHandler {
+ sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
+ tx_sock: tx,
+ reply_ack_negotiated: false,
+ backend,
+ error: None,
+ })
+ }
+
+ /// Get the socket fd for the slave to communication with the master.
+ ///
+ /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn get_tx_raw_fd(&self) -> RawFd {
+ self.tx_sock.as_raw_fd()
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&mut self, enable: bool) {
+ self.reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed or in normal state.
+ ///
+ /// `error` is a raw errno value; 0 clears the failure state.
+ pub fn set_failed(&mut self, error: i32) {
+ if error == 0 {
+ self.error = None;
+ } else {
+ self.error = Some(error);
+ }
+ }
+
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
+ /// Receives, validates and dispatches one request, then sends a REPLY_ACK
+ /// response if one was requested and negotiated.
+ ///
+ /// The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when an error happens
+ /// - optionally recover from failure
+ pub fn handle_request(&mut self) -> Result<u64> {
+ // Return error if the endpoint is already in failed state.
+ self.check_state()?;
+
+ // The underlying communication channel is a Unix domain socket in
+ // stream mode, and recvmsg() is a little tricky here. To successfully
+ // receive attached file descriptors, we need to receive messages and
+ // corresponding attached file descriptors in this way:
+ // . recv messsage header and optional attached file
+ // . validate message header
+ // . recv optional message body and payload according size field in
+ // message header
+ // . validate message body and optional payload
+ let (hdr, rfds) = self.sub_sock.recv_header()?;
+ let rfds = self.check_attached_rfds(&hdr, rfds)?;
+ let (size, buf) = match hdr.get_size() {
+ 0 => (0, vec![0u8; 0]),
+ len => {
+ if len as usize > MAX_MSG_SIZE {
+ return Err(Error::InvalidMessage);
+ }
+ // A short read means the peer sent a truncated body.
+ let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
+ if size2 != len as usize {
+ return Err(Error::InvalidMessage);
+ }
+ (size2, rbuf)
+ }
+ };
+
+ let res = match hdr.get_code() {
+ SlaveReq::CONFIG_CHANGE_MSG => {
+ self.check_msg_size(&hdr, size, 0)?;
+ self.backend
+ .handle_config_change()
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_MAP => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
+ self.backend
+ .fs_slave_map(&msg, rfds.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_UNMAP => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_unmap(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_SYNC => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ self.backend
+ .fs_slave_sync(&msg)
+ .map_err(Error::ReqHandlerError)
+ }
+ SlaveReq::FS_IO => {
+ let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
+ self.backend
+ .fs_slave_io(&msg, rfds.unwrap()[0])
+ .map_err(Error::ReqHandlerError)
+ }
+ _ => Err(Error::InvalidMessage),
+ };
+
+ // Report the outcome back to the slave when REPLY_ACK applies.
+ self.send_ack_message(&hdr, &res)?;
+
+ res
+ }
+
+ // Fail fast if the endpoint was marked broken via set_failed().
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
+ // Validate that `hdr` describes a version-1 request (not a reply) whose
+ // declared payload size matches both `expected` and the `size` actually read.
+ fn check_msg_size(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ expected: usize,
+ ) -> Result<()> {
+ if hdr.get_size() as usize != expected
+ || hdr.is_reply()
+ || hdr.get_version() != 0x1
+ || size != expected
+ {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(())
+ }
+
+ // Enforce the fd-attachment policy per request type: FS_MAP/FS_IO must carry
+ // exactly one fd, every other request must carry none. Unexpected fds are
+ // closed so they do not leak.
+ fn check_attached_rfds(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<Option<Vec<RawFd>>> {
+ match hdr.get_code() {
+ SlaveReq::FS_MAP | SlaveReq::FS_IO => {
+ // Expect an fd set with a single fd.
+ match rfds {
+ None => Err(Error::InvalidMessage),
+ Some(fds) => {
+ if fds.len() != 1 {
+ Endpoint::<SlaveReq>::close_rfds(Some(fds));
+ Err(Error::InvalidMessage)
+ } else {
+ Ok(Some(fds))
+ }
+ }
+ }
+ }
+ _ => {
+ if rfds.is_some() {
+ Endpoint::<SlaveReq>::close_rfds(rfds);
+ Err(Error::InvalidMessage)
+ } else {
+ Ok(rfds)
+ }
+ }
+ }
+ }
+
+ // Decode a fixed-size message body of type T from `buf` and validate it.
+ fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
+ &self,
+ hdr: &VhostUserMsgHeader<SlaveReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<T> {
+ self.check_msg_size(hdr, size, mem::size_of::<T>())?;
+ // Safe because check_msg_size() guarantees `size` (and hence the buffer
+ // read in handle_request()) is exactly size_of::<T>() bytes, and
+ // read_unaligned tolerates the packed message layouts.
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ Ok(msg)
+ }
+
+ // Build a REPLY header for `req` carrying a payload of size_of::<T>() bytes.
+ fn new_reply_header<T: Sized>(
+ &self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ ) -> Result<VhostUserMsgHeader<SlaveReq>> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::InvalidParam);
+ }
+ self.check_state()?;
+ Ok(VhostUserMsgHeader::new(
+ req.get_code(),
+ VhostUserHeaderFlag::REPLY.bits(),
+ mem::size_of::<T>() as u32,
+ ))
+ }
+
+ // Send a REPLY_ACK response for `req` when the feature was negotiated and
+ // the request asked for one; otherwise this is a no-op.
+ fn send_ack_message(
+ &mut self,
+ req: &VhostUserMsgHeader<SlaveReq>,
+ res: &Result<u64>,
+ ) -> Result<()> {
+ if self.reply_ack_negotiated && req.is_need_reply() {
+ let hdr = self.new_reply_header::<VhostUserU64>(req)?;
+ let def_err = libc::EINVAL;
+ // Failures are reported as a negated errno; the `as u64` cast wraps
+ // the negative value to its two's-complement representation.
+ let val = match res {
+ Ok(n) => *n,
+ Err(e) => match &*e {
+ Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
+ Some(rawerr) => -rawerr as u64,
+ None => -def_err as u64,
+ },
+ _ => -def_err as u64,
+ },
+ };
+ let msg = VhostUserU64::new(val);
+ self.sub_sock.send_message(&hdr, &msg, None)?;
+ }
+ Ok(())
+ }
+}
+
+// Expose the receive side of the slave channel so the handler can be
+// registered with a polling loop (epoll/poll) by the caller.
+impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
+ fn as_raw_fd(&self) -> RawFd {
+ let Self { sub_sock, .. } = self;
+ sub_sock.as_raw_fd()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "vhost-user-slave")]
+ use crate::vhost_user::SlaveFsCacheReq;
+ #[cfg(feature = "vhost-user-slave")]
+ use std::os::unix::io::FromRawFd;
+
+ // Minimal backend: accepts FS_MAP (closing the received fd) and rejects
+ // FS_UNMAP with ENOSYS, so both success and failure paths can be exercised.
+ struct MockMasterReqHandler {}
+
+ impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
+ /// Handle virtio-fs map file requests from the slave.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we own the fd just received over the socket.
+ unsafe { libc::close(fd) };
+ Ok(0)
+ }
+
+ /// Handle virtio-fs unmap file requests from the slave.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+ }
+
+ #[test]
+ fn test_new_master_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ assert!(handler.get_tx_raw_fd() >= 0);
+ assert!(handler.as_raw_fd() >= 0);
+ handler.check_state().unwrap();
+
+ // set_failed() must flip the handler into (and out of) the failed state.
+ assert_eq!(handler.error, None);
+ handler.set_failed(libc::EAGAIN);
+ assert_eq!(handler.error, Some(libc::EAGAIN));
+ handler.check_state().unwrap_err();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicate tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ // Keep the JoinHandle: without joining, assertion failures inside the
+ // handler thread would never fail this test.
+ let slave = std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
+ // slave side.
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap();
+
+ // Propagate any panic from the handler thread.
+ slave.join().unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler_with_ack() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+ handler.set_reply_ack_flag(true);
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicate tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ // Keep the JoinHandle so the handler thread's assertions count.
+ let slave = std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ // With REPLY_ACK negotiated the ENOSYS failure is propagated back here.
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+
+ // Propagate any panic from the handler thread.
+ slave.join().unwrap();
+ }
+}
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
new file mode 100644
index 0000000..ea2df4e
--- /dev/null
+++ b/src/vhost_user/message.rs
@@ -0,0 +1,1042 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Define communication messages for the vhost-user protocol.
+//!
+//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt).
+
+#![allow(dead_code)]
+#![allow(non_camel_case_types)]
+
+use std::fmt::Debug;
+use std::marker::PhantomData;
+
+use crate::VringConfigData;
+
+/// The vhost-user specification uses a field of u32 to store message length.
+/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain
+/// socket. Preallocating a 4GB buffer for each vhost-user message is really just an overhead.
+/// Among all defined vhost-user messages, only the VhostUserConfig and VhostUserMemory has variable
+/// message size. For the VhostUserConfig, a maximum size of 4K is enough because the user
+/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory,
+/// 4K should be enough too because it can support 255 memory regions at most.
+pub const MAX_MSG_SIZE: usize = 0x1000;
+
+/// The VhostUserMemory message has variable message size and variable number of attached file
+/// descriptors. Each user memory region entry in the message payload occupies 32 bytes,
+/// so setting maximum number of attached file descriptors based on the maximum message size.
+/// But rust only implements Default and AsMut traits for arrays with 0 - 32 entries, so further
+/// reduce the maximum number...
+// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32;
+pub const MAX_ATTACHED_FD_ENTRIES: usize = 32;
+
+/// Starting position (inclusive) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100;
+
+/// Ending position (exclusive) of the device configuration space in virtio devices.
+pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000;
+
+/// Maximum number of vrings supported.
+pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64;
+
+/// Common trait for vhost-user request codes ([MasterReq] and [SlaveReq]);
+/// `Into<u32>` yields the on-wire value carried in the message header.
+pub(super) trait Req:
+ Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32>
+{
+ /// Whether the code lies strictly between the NOOP and MAX_CMD sentinels.
+ fn is_valid(&self) -> bool;
+}
+
+/// Type of requests sent from masters to slaves.
+///
+/// The discriminant doubles as the on-wire request code, hence `repr(u32)`.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum MasterReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Get from the underlying vhost implementation the features bit mask.
+ GET_FEATURES = 1,
+ /// Enable features in the underlying vhost implementation using a bit mask.
+ SET_FEATURES = 2,
+ /// Set the current Master as an owner of the session.
+ SET_OWNER = 3,
+ /// No longer used.
+ RESET_OWNER = 4,
+ /// Set the memory map regions on the slave so it can translate the vring addresses.
+ SET_MEM_TABLE = 5,
+ /// Set logging shared memory space.
+ SET_LOG_BASE = 6,
+ /// Set the logging file descriptor, which is passed as ancillary data.
+ SET_LOG_FD = 7,
+ /// Set the size of the queue.
+ SET_VRING_NUM = 8,
+ /// Set the addresses of the different aspects of the vring.
+ SET_VRING_ADDR = 9,
+ /// Set the base offset in the available vring.
+ SET_VRING_BASE = 10,
+ /// Get the available vring base offset.
+ GET_VRING_BASE = 11,
+ /// Set the event file descriptor for adding buffers to the vring.
+ SET_VRING_KICK = 12,
+ /// Set the event file descriptor to signal when buffers are used.
+ SET_VRING_CALL = 13,
+ /// Set the event file descriptor to signal when error occurs.
+ SET_VRING_ERR = 14,
+ /// Get the protocol feature bit mask from the underlying vhost implementation.
+ GET_PROTOCOL_FEATURES = 15,
+ /// Enable protocol features in the underlying vhost implementation.
+ SET_PROTOCOL_FEATURES = 16,
+ /// Query how many queues the backend supports.
+ GET_QUEUE_NUM = 17,
+ /// Signal slave to enable or disable corresponding vring.
+ SET_VRING_ENABLE = 18,
+ /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated
+ /// for guest that does not support GUEST_ANNOUNCE.
+ SEND_RARP = 19,
+ /// Set host MTU value exposed to the guest.
+ NET_SET_MTU = 20,
+ /// Set the socket file descriptor for slave initiated requests.
+ SET_SLAVE_REQ_FD = 21,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 22,
+ /// Set the endianness of a VQ for legacy devices.
+ SET_VRING_ENDIAN = 23,
+ /// Fetch the contents of the virtio device configuration space.
+ GET_CONFIG = 24,
+ /// Change the contents of the virtio device configuration space.
+ SET_CONFIG = 25,
+ /// Create a session for crypto operation.
+ CREATE_CRYPTO_SESSION = 26,
+ /// Close a session for crypto operation.
+ CLOSE_CRYPTO_SESSION = 27,
+ /// Advise slave that a migration with postcopy enabled is underway.
+ POSTCOPY_ADVISE = 28,
+ /// Advise slave that a transition to postcopy mode has happened.
+ POSTCOPY_LISTEN = 29,
+ /// Advise that postcopy migration has now completed.
+ POSTCOPY_END = 30,
+ /// Get a shared buffer from slave.
+ GET_INFLIGHT_FD = 31,
+ /// Send the shared inflight buffer back to slave.
+ SET_INFLIGHT_FD = 32,
+ /// Sets the GPU protocol socket file descriptor.
+ GPU_SET_SOCKET = 33,
+ /// Ask the vhost user backend to disable all rings and reset all internal
+ /// device state to the initial state.
+ RESET_DEVICE = 34,
+ /// Indicate that a buffer was added to the vring instead of signalling it
+ /// using the vring’s kick file descriptor.
+ VRING_KICK = 35,
+ /// Return a u64 payload containing the maximum number of memory slots.
+ GET_MAX_MEM_SLOTS = 36,
+ /// Update the memory tables by adding the region described.
+ ADD_MEM_REG = 37,
+ /// Update the memory tables by removing the region described.
+ REM_MEM_REG = 38,
+ /// Notify the backend with updated device status as defined in the VIRTIO
+ /// specification.
+ SET_STATUS = 39,
+ /// Query the backend for its device status as defined in the VIRTIO
+ /// specification.
+ GET_STATUS = 40,
+ /// Upper bound of valid commands.
+ MAX_CMD = 41,
+}
+
+// Implementing `From` (rather than `Into`) is the idiomatic direction: the
+// standard library's blanket impl still provides `Into<u32>` for free, so the
+// `Req: Into<u32>` bound and existing `.into()` call sites keep working.
+impl From<MasterReq> for u32 {
+ /// Convert a request code into its on-wire u32 value.
+ fn from(req: MasterReq) -> u32 {
+ // Lossless: the enum is repr(u32), so the discriminant is the value.
+ req as u32
+ }
+}
+
+impl Req for MasterReq {
+ // Valid codes lie strictly between the NOOP and MAX_CMD sentinels. A range
+ // check is required because VhostUserMsgHeader::get_code() transmutes an
+ // arbitrary u32 from the wire into this type.
+ fn is_valid(&self) -> bool {
+ (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD)
+ }
+}
+
+/// Type of requests sent from slaves to masters.
+///
+/// The discriminant doubles as the on-wire request code, hence `repr(u32)`.
+#[repr(u32)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub enum SlaveReq {
+ /// Null operation.
+ NOOP = 0,
+ /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
+ IOTLB_MSG = 1,
+ /// Notify that the virtio device's configuration space has changed.
+ CONFIG_CHANGE_MSG = 2,
+ /// Set host notifier for a specified queue.
+ VRING_HOST_NOTIFIER_MSG = 3,
+ /// Indicate that a buffer was used from the vring.
+ VRING_CALL = 4,
+ /// Indicate that an error occurred on the specific vring.
+ VRING_ERR = 5,
+ /// Virtio-fs draft: map file content into the window.
+ FS_MAP = 6,
+ /// Virtio-fs draft: unmap file content from the window.
+ FS_UNMAP = 7,
+ /// Virtio-fs draft: sync file content.
+ FS_SYNC = 8,
+ /// Virtio-fs draft: perform a read/write from an fd directly to GPA.
+ FS_IO = 9,
+ /// Upper bound of valid commands.
+ MAX_CMD = 10,
+}
+
+// Implementing `From` (rather than `Into`) is the idiomatic direction: the
+// standard library's blanket impl still provides `Into<u32>` for free, so the
+// `Req: Into<u32>` bound and existing `.into()` call sites keep working.
+impl From<SlaveReq> for u32 {
+ /// Convert a request code into its on-wire u32 value.
+ fn from(req: SlaveReq) -> u32 {
+ // Lossless: the enum is repr(u32), so the discriminant is the value.
+ req as u32
+ }
+}
+
+impl Req for SlaveReq {
+ // Valid codes lie strictly between the NOOP and MAX_CMD sentinels. A range
+ // check is required because VhostUserMsgHeader::get_code() transmutes an
+ // arbitrary u32 from the wire into this type.
+ fn is_valid(&self) -> bool {
+ (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD)
+ }
+}
+
+/// Vhost message validator.
+///
+/// Implemented by every message body type so that bytes read off the wire can
+/// be sanity-checked before use.
+pub trait VhostUserMsgValidator {
+ /// Validate message syntax only.
+ /// It doesn't validate message semantics such as protocol version number and dependency
+ /// on feature flags etc.
+ ///
+ /// The default implementation accepts everything; types with structural
+ /// invariants override it.
+ fn is_valid(&self) -> bool {
+ true
+ }
+}
+
+// Bit mask for common message flags.
+bitflags! {
+ /// Common message flags for vhost-user requests and replies.
+ pub struct VhostUserHeaderFlag: u32 {
+ /// Bits[0..2] is message version number.
+ const VERSION = 0x3;
+ /// Mark message as reply.
+ const REPLY = 0x4;
+ /// Sender anticipates a reply message from the peer.
+ const NEED_REPLY = 0x8;
+ /// The settable flag bits (REPLY | NEED_REPLY); version bits are managed
+ /// separately by VhostUserMsgHeader.
+ const ALL_FLAGS = 0xc;
+ /// All reserved bits.
+ const RESERVED_BITS = !0xf;
+ }
+}
+
+/// Common message header for vhost-user requests and replies.
+/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
+/// machine native byte order.
+#[allow(safe_packed_borrows)]
+#[repr(packed)]
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub(super) struct VhostUserMsgHeader<R: Req> {
+ // Request code; the u32 wire value of an R variant.
+ request: u32,
+ // Version, REPLY and NEED_REPLY bits; see VhostUserHeaderFlag.
+ flags: u32,
+ // Size in bytes of the payload that follows this header.
+ size: u32,
+ // Zero-sized marker tying the header to its request-code type.
+ _r: PhantomData<R>,
+}
+
+impl<R: Req> VhostUserMsgHeader<R> {
+ /// Create a new instance of `VhostUserMsgHeader`.
+ pub fn new(request: R, flags: u32, size: u32) -> Self {
+ // Default to protocol version 1: strip any bits outside REPLY/NEED_REPLY
+ // and force the version field to 1.
+ let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1;
+ VhostUserMsgHeader {
+ request: request.into(),
+ flags: fl,
+ size,
+ _r: PhantomData,
+ }
+ }
+
+ /// Get message type.
+ pub fn get_code(&self) -> R {
+ // It's safe because R is marked as repr(u32).
+ // NOTE(review): this does not check that the value is a declared variant;
+ // callers must gate on Req::is_valid() before trusting the result.
+ unsafe { std::mem::transmute_copy::<u32, R>(&self.request) }
+ }
+
+ /// Set message type.
+ pub fn set_code(&mut self, request: R) {
+ self.request = request.into();
+ }
+
+ /// Get message version number.
+ pub fn get_version(&self) -> u32 {
+ self.flags & 0x3
+ }
+
+ /// Set message version number.
+ pub fn set_version(&mut self, ver: u32) {
+ self.flags &= !0x3;
+ self.flags |= ver & 0x3;
+ }
+
+ /// Check whether it's a reply message.
+ pub fn is_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0
+ }
+
+ /// Mark message as reply.
+ pub fn set_reply(&mut self, is_reply: bool) {
+ if is_reply {
+ self.flags |= VhostUserHeaderFlag::REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::REPLY.bits();
+ }
+ }
+
+ /// Check whether reply for this message is requested.
+ pub fn is_need_reply(&self) -> bool {
+ (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0
+ }
+
+ /// Mark that reply for this message is needed.
+ pub fn set_need_reply(&mut self, need_reply: bool) {
+ if need_reply {
+ self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits();
+ } else {
+ self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits();
+ }
+ }
+
+ /// Check whether it's the reply message for the request `req`:
+ /// same request code, with the REPLY bit set only on this header.
+ pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool {
+ self.is_reply() && !req.is_reply() && self.get_code() == req.get_code()
+ }
+
+ /// Get message size.
+ pub fn get_size(&self) -> u32 {
+ self.size
+ }
+
+ /// Set message size.
+ pub fn set_size(&mut self, size: u32) {
+ self.size = size;
+ }
+}
+
+impl<R: Req> Default for VhostUserMsgHeader<R> {
+ /// A default header: NOOP request code, protocol version 1, empty payload.
+ fn default() -> Self {
+ Self {
+ request: 0,
+ flags: 0x1,
+ size: 0,
+ _r: PhantomData,
+ }
+ }
+}
+
+impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
+ /// A header is syntactically valid when the request code is known, the
+ /// payload size fits the transport limit, the version is 1 and no reserved
+ /// flag bits are set.
+ fn is_valid(&self) -> bool {
+ self.get_code().is_valid()
+ && self.size as usize <= MAX_MSG_SIZE
+ && self.get_version() == 0x1
+ && (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) == 0
+ }
+}
+
+// Bit mask for transport specific flags in VirtIO feature set defined by vhost-user.
+bitflags! {
+ /// Transport specific flags in VirtIO feature set defined by vhost-user.
+ pub struct VhostUserVirtioFeatures: u64 {
+ /// Feature flag for the protocol feature (bit 30 of the feature set).
+ const PROTOCOL_FEATURES = 0x4000_0000;
+ }
+}
+
+// Bit mask for vhost-user protocol feature flags.
+bitflags! {
+ /// Vhost-user protocol feature flags, negotiated via
+ /// GET/SET_PROTOCOL_FEATURES once PROTOCOL_FEATURES itself is agreed.
+ pub struct VhostUserProtocolFeatures: u64 {
+ /// Support multiple queues.
+ const MQ = 0x0000_0001;
+ /// Support logging through shared memory fd.
+ const LOG_SHMFD = 0x0000_0002;
+ /// Support broadcasting fake RARP packet.
+ const RARP = 0x0000_0004;
+ /// Support sending reply messages for requests with NEED_REPLY flag set.
+ const REPLY_ACK = 0x0000_0008;
+ /// Support setting MTU for virtio-net devices.
+ const MTU = 0x0000_0010;
+ /// Allow the slave to send requests to the master by an optional communication channel.
+ const SLAVE_REQ = 0x0000_0020;
+ /// Support setting slave endian by SET_VRING_ENDIAN.
+ const CROSS_ENDIAN = 0x0000_0040;
+ /// Support crypto operations.
+ const CRYPTO_SESSION = 0x0000_0080;
+ /// Support sending userfault_fd from slaves to masters.
+ const PAGEFAULT = 0x0000_0100;
+ /// Support Virtio device configuration.
+ const CONFIG = 0x0000_0200;
+ /// Allow the slave to send fds (at most 8 descriptors in each message) to the master.
+ const SLAVE_SEND_FD = 0x0000_0400;
+ /// Allow the slave to register a host notifier.
+ const HOST_NOTIFIER = 0x0000_0800;
+ /// Support inflight shmfd.
+ const INFLIGHT_SHMFD = 0x0000_1000;
+ /// Support resetting the device.
+ const RESET_DEVICE = 0x0000_2000;
+ /// Support inband notifications.
+ const INBAND_NOTIFICATIONS = 0x0000_4000;
+ /// Support configuring memory slots.
+ const CONFIGURE_MEM_SLOTS = 0x0000_8000;
+ /// Support reporting status.
+ const STATUS = 0x0001_0000;
+ }
+}
+
+/// A generic message to encapsulate a 64-bit value.
+///
+/// Used as the body for feature masks, REPLY_ACK results and similar
+/// single-value exchanges; packed to match the wire layout.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserU64 {
+ /// The encapsulated 64-bit common value.
+ pub value: u64,
+}
+
+impl VhostUserU64 {
+ /// Create a new instance.
+ pub fn new(value: u64) -> Self {
+ VhostUserU64 { value }
+ }
+}
+
+// Any 64-bit value is syntactically valid, so the default validator applies.
+impl VhostUserMsgValidator for VhostUserU64 {}
+
+/// Memory region descriptor for the SET_MEM_TABLE request.
+///
+/// This fixed-size header is followed on the wire by `num_regions`
+/// [VhostUserMemoryRegion] entries.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserMemory {
+ /// Number of memory regions in the payload.
+ pub num_regions: u32,
+ /// Padding for alignment.
+ pub padding1: u32,
+}
+
+impl VhostUserMemory {
+ /// Create a new instance with `cnt` regions.
+ pub fn new(cnt: u32) -> Self {
+ VhostUserMemory {
+ num_regions: cnt,
+ padding1: 0,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserMemory {
+ /// Valid when the padding is zero and the region count is non-zero and no
+ /// larger than the number of fds that can ride along in one message.
+ fn is_valid(&self) -> bool {
+ self.padding1 == 0
+ && self.num_regions != 0
+ && self.num_regions <= MAX_ATTACHED_FD_ENTRIES as u32
+ }
+}
+
+/// Memory region descriptors as payload for the SET_MEM_TABLE request.
+#[repr(packed)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserMemoryRegion {
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub user_addr: u64,
+ /// Offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+}
+
+impl VhostUserMemoryRegion {
+ /// Create a new instance.
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserMemoryRegion {
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserMemoryRegion {
+ /// A region is valid when it is non-empty and its end does not wrap around
+ /// in any of the three address spaces it participates in.
+ fn is_valid(&self) -> bool {
+ // Copy the packed fields to locals before calling methods on them, so no
+ // reference to an unaligned field is ever created.
+ let size = self.memory_size;
+ let (gpa, uva, off) = (self.guest_phys_addr, self.user_addr, self.mmap_offset);
+ size != 0
+ && gpa.checked_add(size).is_some()
+ && uva.checked_add(size).is_some()
+ && off.checked_add(size).is_some()
+ }
+}
+
+/// Payload of the VhostUserMemory message: the per-region entries that follow
+/// the fixed header.
+pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>;
+
+/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG
+/// requests.
+#[repr(C)]
+#[derive(Default, Clone, Copy)]
+pub struct VhostUserSingleMemoryRegion {
+ /// Padding for correct alignment
+ // NOTE(review): the leading 8-byte pad appears to align this body with the
+ // wire format of the memory-slot messages — confirm against the vhost-user spec.
+ padding: u64,
+ /// Guest physical address of the memory region.
+ pub guest_phys_addr: u64,
+ /// Size of the memory region.
+ pub memory_size: u64,
+ /// Virtual address in the current process.
+ pub user_addr: u64,
+ /// Offset where region starts in the mapped memory.
+ pub mmap_offset: u64,
+}
+
+impl VhostUserSingleMemoryRegion {
+ /// Create a new instance (padding is always zeroed).
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
+ VhostUserSingleMemoryRegion {
+ padding: 0,
+ guest_phys_addr,
+ memory_size,
+ user_addr,
+ mmap_offset,
+ }
+ }
+}
+
+impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {
+ /// A region is valid when it is non-empty and its end does not wrap around
+ /// in any of the three address spaces it participates in.
+ fn is_valid(&self) -> bool {
+ self.memory_size != 0
+ && self.guest_phys_addr.checked_add(self.memory_size).is_some()
+ && self.user_addr.checked_add(self.memory_size).is_some()
+ && self.mmap_offset.checked_add(self.memory_size).is_some()
+ }
+}
+
+/// Vring state descriptor.
+///
+/// Carries a single 32bit value (`num`) for the queue identified by `index`.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserVringState {
+ /// Vring index.
+ pub index: u32,
+ /// A common 32bit value to encapsulate vring state etc.
+ pub num: u32,
+}
+
+impl VhostUserVringState {
+    /// Build a vring state message for queue `index` carrying `num`.
+    pub fn new(index: u32, num: u32) -> Self {
+        Self { index, num }
+    }
+}
+
+// Every bit pattern of (index, num) is acceptable, so the trait's default
+// `is_valid` implementation (always true) is used.
+impl VhostUserMsgValidator for VhostUserVringState {}
+
+// Bit mask for vring address flags (carried as raw bits in
+// VhostUserVringAddr::flags).
+bitflags! {
+ /// Flags for vring address.
+ pub struct VhostUserVringAddrFlags: u32 {
+ /// Support log of vring operations.
+ /// Modifications to "used" vring should be logged.
+ const VHOST_VRING_F_LOG = 0x1;
+ }
+}
+
+/// Vring address descriptor.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserVringAddr {
+ /// Vring index.
+ pub index: u32,
+ /// Vring flags defined by VhostUserVringAddrFlags, stored as raw bits.
+ pub flags: u32,
+ /// Ring address of the vring descriptor table.
+ pub descriptor: u64,
+ /// Ring address of the vring used ring.
+ pub used: u64,
+ /// Ring address of the vring available ring.
+ pub available: u64,
+ /// Guest address for logging.
+ pub log: u64,
+}
+
+impl VhostUserVringAddr {
+    /// Build a vring address message from explicit ring addresses; `flags`
+    /// is stored as its raw bit pattern.
+    pub fn new(
+        index: u32,
+        flags: VhostUserVringAddrFlags,
+        descriptor: u64,
+        used: u64,
+        available: u64,
+        log: u64,
+    ) -> Self {
+        Self {
+            index,
+            flags: flags.bits(),
+            descriptor,
+            used,
+            available,
+            log,
+        }
+    }
+
+    /// Build a vring address message from a `VringConfigData`; a missing
+    /// `log_addr` is encoded as 0.
+    #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))]
+    pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self {
+        Self {
+            index,
+            flags: config_data.flags,
+            descriptor: config_data.desc_table_addr,
+            used: config_data.used_ring_addr,
+            available: config_data.avail_ring_addr,
+            log: config_data.log_addr.unwrap_or(0),
+        }
+    }
+}
+
+impl VhostUserMsgValidator for VhostUserVringAddr {
+    /// Checks that only known flag bits are set and that each ring address
+    /// satisfies its alignment: descriptor table 16-byte, available ring
+    /// 2-byte, used ring 4-byte.
+    fn is_valid(&self) -> bool {
+        (self.flags & !VhostUserVringAddrFlags::all().bits()) == 0
+            && self.descriptor & 0xf == 0
+            && self.available & 0x1 == 0
+            && self.used & 0x3 == 0
+    }
+}
+
+// Bit mask for the vhost-user device configuration message.
+bitflags! {
+ /// Flags for the device configuration message.
+ pub struct VhostUserConfigFlags: u32 {
+ /// Vhost master messages used for writable fields.
+ const WRITABLE = 0x1;
+ /// Vhost master messages used for live migration.
+ const LIVE_MIGRATION = 0x2;
+ }
+}
+
+/// Message to read/write device configuration space.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserConfig {
+ /// Offset of virtio device's configuration space.
+ pub offset: u32,
+ /// Configuration space access size in bytes.
+ pub size: u32,
+ /// Flags for the device configuration operation; raw bits of
+ /// VhostUserConfigFlags.
+ pub flags: u32,
+}
+
+impl VhostUserConfig {
+    /// Build a configuration space access message from an offset, an access
+    /// size in bytes and the operation flags.
+    pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self {
+        let flags = flags.bits();
+        Self {
+            offset,
+            size,
+            flags,
+        }
+    }
+}
+
+impl VhostUserMsgValidator for VhostUserConfig {
+    /// Validate a device configuration space access request.
+    ///
+    /// Rejects unknown flag bits, accesses below the device-specific
+    /// configuration offset (0x100), zero or oversized lengths, and any
+    /// (offset, size) window not fully contained in
+    /// [0, VHOST_USER_CONFIG_SIZE].
+    #[allow(clippy::if_same_then_else)]
+    fn is_valid(&self) -> bool {
+        if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
+            return false;
+        } else if self.offset < 0x100 {
+            return false;
+        } else if self.size == 0 || self.size > VHOST_USER_CONFIG_SIZE {
+            return false;
+        }
+        // BUGFIX: `self.size + self.offset` can overflow u32 (e.g. with
+        // offset near u32::MAX); the wrapped sum would pass the range check
+        // in release builds and panic in debug builds. Use checked math so
+        // an overflowing window is always rejected.
+        match self.size.checked_add(self.offset) {
+            Some(end) => end <= VHOST_USER_CONFIG_SIZE,
+            None => false,
+        }
+    }
+}
+
+/// Payload for the VhostUserConfig message: the raw bytes of the accessed
+/// configuration space window.
+pub type VhostUserConfigPayload = Vec<u8>;
+
+/*
+ * TODO: support dirty log, live migration and IOTLB operations.
+#[repr(packed)]
+pub struct VhostUserVringArea {
+ pub index: u32,
+ pub flags: u32,
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserLog {
+ pub size: u64,
+ pub offset: u64,
+}
+
+#[repr(packed)]
+pub struct VhostUserIotlb {
+ pub iova: u64,
+ pub size: u64,
+ pub user_addr: u64,
+ pub permission: u8,
+ pub optype: u8,
+}
+*/
+
+// Bit mask for flags in virtio-fs slave messages
+bitflags! {
+ #[derive(Default)]
+ /// Flags for virtio-fs slave messages: access permissions for the
+ /// requested mmap operation.
+ pub struct VhostUserFSSlaveMsgFlags: u64 {
+ /// Empty permission.
+ const EMPTY = 0x0;
+ /// Read permission.
+ const MAP_R = 0x1;
+ /// Write permission.
+ const MAP_W = 0x2;
+ }
+}
+
+/// Max entries in one virtio-fs slave request; each VhostUserFSSlaveMsg
+/// carries up to this many (fd_offset, cache_offset, len, flags) tuples.
+pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;
+
+/// Slave request message to update the MMIO window.
+///
+/// The arrays are parallel: entry `i` is described by `fd_offset[i]`,
+/// `cache_offset[i]`, `len[i]` and `flags[i]`.
+#[repr(packed)]
+#[derive(Default)]
+pub struct VhostUserFSSlaveMsg {
+ /// File offset.
+ pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Offset into the DAX window.
+ pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Size of region to map.
+ pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
+ /// Flags for the mmap operation
+ pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES],
+}
+
+impl VhostUserMsgValidator for VhostUserFSSlaveMsg {
+    /// Every entry must carry only known flag bits, and adding its length to
+    /// the file offset or the cache offset must not overflow.
+    fn is_valid(&self) -> bool {
+        (0..VHOST_USER_FS_SLAVE_ENTRIES).all(|i| {
+            // Copy the flags value out of the packed struct (braces force a
+            // by-value copy) before calling methods on it.
+            let flags = { self.flags[i] };
+            let len = self.len[i];
+            flags.bits() & !VhostUserFSSlaveMsgFlags::all().bits() == 0
+                && self.fd_offset[i].checked_add(len).is_some()
+                && self.cache_offset[i].checked_add(len).is_some()
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ // Unit tests for the message structs and validators defined above.
+ use super::*;
+ use std::mem;
+
+ #[test]
+ fn check_master_request_code() {
+ let code = MasterReq::NOOP;
+ assert!(!code.is_valid());
+ let code = MasterReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > MasterReq::NOOP);
+ let code = MasterReq::GET_FEATURES;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ // An out-of-range discriminant must be reported as invalid.
+ let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn check_slave_request_code() {
+ let code = SlaveReq::NOOP;
+ assert!(!code.is_valid());
+ let code = SlaveReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > SlaveReq::NOOP);
+ let code = SlaveReq::CONFIG_CHANGE_MSG;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ // An out-of-range discriminant must be reported as invalid.
+ let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn msg_header_ops() {
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
+ assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES);
+ hdr.set_code(MasterReq::SET_FEATURES);
+ assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES);
+
+ assert_eq!(hdr.get_version(), 0x1);
+
+ assert_eq!(hdr.is_reply(), false);
+ hdr.set_reply(true);
+ assert_eq!(hdr.is_reply(), true);
+ hdr.set_reply(false);
+
+ assert_eq!(hdr.is_need_reply(), false);
+ hdr.set_need_reply(true);
+ assert_eq!(hdr.is_need_reply(), true);
+ hdr.set_need_reply(false);
+
+ assert_eq!(hdr.get_size(), 0x100);
+ hdr.set_size(0x200);
+ assert_eq!(hdr.get_size(), 0x200);
+
+ assert_eq!(hdr.is_need_reply(), false);
+ assert_eq!(hdr.is_reply(), false);
+ assert_eq!(hdr.get_version(), 0x1);
+
+ // Check message length
+ assert!(hdr.is_valid());
+ hdr.set_size(0x2000);
+ assert!(!hdr.is_valid());
+ hdr.set_size(0x100);
+ assert_eq!(hdr.get_size(), 0x100);
+ assert!(hdr.is_valid());
+ hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32);
+ assert!(hdr.is_valid());
+ hdr.set_size(0x0);
+ assert!(hdr.is_valid());
+
+ // Check version
+ hdr.set_version(0x0);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x2);
+ assert!(!hdr.is_valid());
+ hdr.set_version(0x1);
+ assert!(hdr.is_valid());
+
+ assert_eq!(hdr, hdr.clone());
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
+ }
+
+ #[test]
+ fn check_user_memory() {
+ let mut msg = VhostUserMemory::new(1);
+ assert!(msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ assert!(msg.is_valid());
+
+ msg.num_regions += 1;
+ assert!(!msg.is_valid());
+ msg.num_regions = 0xFFFFFFFF;
+ assert!(!msg.is_valid());
+ msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
+ msg.padding1 = 1;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn check_user_memory_region() {
+ let mut msg = VhostUserMemoryRegion {
+ guest_phys_addr: 0,
+ memory_size: 0x1000,
+ user_addr: 0,
+ mmap_offset: 0,
+ };
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF;
+ assert!(msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFFF000;
+ assert!(!msg.is_valid());
+ msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
+ msg.memory_size = 0;
+ assert!(!msg.is_valid());
+ // Fields are copied into locals before asserting; taking references
+ // into the message directly could be unaligned if the struct is
+ // packed (NOTE(review): confirm the struct's repr).
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert_eq!(state.is_valid(), true);
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = state.num;
+ assert_eq!(a, 0);
+ assert_eq!(state.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_addr() {
+ let mut addr = VhostUserVringAddr::new(
+ 2,
+ VhostUserVringAddrFlags::VHOST_VRING_F_LOG,
+ 0x1000,
+ 0x2000,
+ 0x3000,
+ 0x4000,
+ );
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+
+ // Each ring address must respect its alignment requirement.
+ addr.descriptor = 0x1001;
+ assert_eq!(addr.is_valid(), false);
+ addr.descriptor = 0x1000;
+
+ addr.available = 0x3001;
+ assert_eq!(addr.is_valid(), false);
+ addr.available = 0x3000;
+
+ addr.used = 0x2001;
+ assert_eq!(addr.is_valid(), false);
+ addr.used = 0x2000;
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_state_from_config() {
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ let addr = VhostUserVringAddr::from_config_data(2, &config);
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn check_user_vring_addr() {
+ let mut msg =
+ VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0);
+ assert!(msg.is_valid());
+
+ msg.descriptor = 1;
+ assert!(!msg.is_valid());
+ msg.descriptor = 0;
+
+ msg.available = 1;
+ assert!(!msg.is_valid());
+ msg.available = 0;
+
+ msg.used = 1;
+ assert!(!msg.is_valid());
+ msg.used = 0;
+
+ // Unknown flag bits must be rejected.
+ msg.flags |= 0x80000000;
+ assert!(!msg.is_valid());
+ msg.flags &= !0x80000000;
+ }
+
+ #[test]
+ fn check_user_config_msg() {
+ let mut msg = VhostUserConfig::new(
+ VHOST_USER_CONFIG_OFFSET,
+ VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET,
+ VhostUserConfigFlags::WRITABLE,
+ );
+
+ assert!(msg.is_valid());
+ msg.size = 0;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ assert!(msg.is_valid());
+ msg.offset = 0;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE;
+ assert!(!msg.is_valid());
+ msg.offset = VHOST_USER_CONFIG_SIZE - 1;
+ assert!(msg.is_valid());
+ msg.size = 2;
+ assert!(!msg.is_valid());
+ msg.size = 1;
+ msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits();
+ assert!(msg.is_valid());
+ msg.flags |= 0x4;
+ assert!(!msg.is_valid());
+ }
+
+ #[test]
+ fn test_vhost_user_fs_slave() {
+ let mut fs_slave = VhostUserFSSlaveMsg::default();
+
+ assert_eq!(fs_slave.is_valid(), true);
+
+ // fd_offset + len overflows u64, so the entry must be rejected.
+ fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff;
+ fs_slave.len[0] = 0x1;
+ assert_eq!(fs_slave.is_valid(), false);
+
+ assert_ne!(
+ VhostUserFSSlaveMsgFlags::MAP_R,
+ VhostUserFSSlaveMsgFlags::MAP_W
+ );
+ assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0);
+ }
+}
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
new file mode 100644
index 0000000..9ef6453
--- /dev/null
+++ b/src/vhost_user/mod.rs
@@ -0,0 +1,456 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux
+//! Kernel. The protocol defines two sides of the communication, master and slave. Master is
+//! the application that shares its virtqueues. Slave is the consumer of the virtqueues.
+//!
+//! The communication channel between the master and the slave includes two sub channels. One is
+//! used to send requests from the master to the slave and optional replies from the slave to the
+//! master. This sub channel is created on master startup by connecting to the slave service
+//! endpoint. The other is used to send requests from the slave to the master and optional replies
+//! from the master to the slave. This sub channel is created by the master issuing a
+//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor.
+//!
+//! Unix domain socket is used as the underlying communication channel because the master needs to
+//! send file descriptors to the slave.
+//!
+//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
+//! equivalent ioctl to the kernel implementation.
+
+use std::io::Error as IOError;
+
+pub mod message;
+
+mod connection;
+pub use self::connection::Listener;
+
+#[cfg(feature = "vhost-user-master")]
+mod master;
+#[cfg(feature = "vhost-user-master")]
+pub use self::master::{Master, VhostUserMaster};
+#[cfg(feature = "vhost-user")]
+mod master_req_handler;
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
+
+#[cfg(feature = "vhost-user-slave")]
+mod slave;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave::SlaveListener;
+#[cfg(feature = "vhost-user-slave")]
+mod slave_req_handler;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
+#[cfg(feature = "vhost-user-slave")]
+mod slave_fs_cache;
+#[cfg(feature = "vhost-user-slave")]
+pub use self::slave_fs_cache::SlaveFsCacheReq;
+
+/// Errors for vhost-user operations
+#[derive(Debug)]
+pub enum Error {
+ /// Invalid parameters.
+ InvalidParam,
+ /// Unsupported operation because the required protocol feature hasn't been negotiated.
+ InvalidOperation,
+ /// Invalid message format, flag or content.
+ InvalidMessage,
+ /// Only part of a message has been sent or received successfully
+ PartialMessage,
+ /// Message is too large
+ OversizedMsg,
+ /// Fd array in question is too big or too small
+ IncorrectFds,
+ /// Can't connect to peer.
+ SocketConnect(std::io::Error),
+ /// Generic socket errors.
+ SocketError(std::io::Error),
+ /// The socket is broken or has been closed.
+ SocketBroken(std::io::Error),
+ /// Should retry the socket operation again.
+ SocketRetry(std::io::Error),
+ /// Failure from the slave side.
+ SlaveInternalError,
+ /// Failure from the master side.
+ MasterInternalError,
+ /// Virtio/protocol features mismatch.
+ FeatureMismatch,
+ /// Error from request handler
+ ReqHandlerError(IOError),
+}
+
+impl std::fmt::Display for Error {
+ // Render a short human-readable description; variants wrapping an I/O
+ // error append the underlying error's message.
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ Error::InvalidParam => write!(f, "invalid parameters"),
+ Error::InvalidOperation => write!(f, "invalid operation"),
+ Error::InvalidMessage => write!(f, "invalid message"),
+ Error::PartialMessage => write!(f, "partial message"),
+ Error::OversizedMsg => write!(f, "oversized message"),
+ Error::IncorrectFds => write!(f, "wrong number of attached fds"),
+ Error::SocketError(e) => write!(f, "socket error: {}", e),
+ Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e),
+ Error::SocketBroken(e) => write!(f, "socket is broken: {}", e),
+ Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e),
+ Error::SlaveInternalError => write!(f, "slave internal error"),
+ Error::MasterInternalError => write!(f, "Master internal error"),
+ Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
+ Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e),
+ }
+ }
+}
+
+// Marker impl only; the Display implementation above supplies the message.
+impl std::error::Error for Error {}
+
+impl Error {
+    /// Determine whether the underlying communication channel should be
+    /// rebuilt after this error.
+    ///
+    /// Reconnecting is worthwhile for transient transport failures (partial
+    /// messages, broken sockets) and for internal errors on either side,
+    /// which may clear up on a fresh connection. Everything else is either
+    /// retriable in place (`SocketRetry`) or permanent (invalid input,
+    /// feature mismatch, connect/handler failures).
+    pub fn should_reconnect(&self) -> bool {
+        matches!(
+            *self,
+            Error::PartialMessage
+                | Error::SocketBroken(_)
+                | Error::SlaveInternalError
+                | Error::MasterInternalError
+        )
+    }
+}
+
+impl std::convert::From<sys_util::Error> for Error {
+ /// Convert raw socket errors into meaningful vhost-user errors.
+ ///
+ /// The sys_util::Error is a simple wrapper over the raw errno, which doesn't mean
+ /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify
+ /// the connection manager logic.
+ ///
+ /// # Return:
+ /// * - Error::SocketRetry: temporary error caused by signals or short of resources.
+ /// * - Error::SocketBroken: the underlying socket is broken.
+ /// * - Error::SocketError: other socket related errors.
+ #[allow(unreachable_patterns)] // EWOULDBLOCK equals EAGAIN on Linux, making that arm unreachable there
+ fn from(err: sys_util::Error) -> Self {
+ match err.errno() {
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
+ // The socket is marked nonblocking and the requested operation would block.
+ libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
+ // A signal occurred before any data was transmitted
+ libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
+ // The output queue for a network interface was full. This generally indicates
+ // that the interface has stopped sending, but may be caused by transient congestion.
+ libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
+ // No memory available.
+ libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
+ // Connection reset by peer.
+ libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
+ // The local end has been shut down on a connection oriented socket. In this case the
+ // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
+ libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
+ // Write permission is denied on the destination socket file, or search permission is
+ // denied for one of the directories the path prefix.
+ libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
+ // Catch all other errors
+ e => Error::SocketError(IOError::from_raw_os_error(e)),
+ }
+ }
+}
+
+/// Result of vhost-user operations, with [Error] as the error type.
+pub type Result<T> = std::result::Result<T, Error>;
+
+/// Result of request handler callbacks; carries a raw `std::io::Error` so
+/// handlers can surface OS-level failures directly.
+pub type HandlerResult<T> = std::result::Result<T, IOError>;
+
+#[cfg(all(test, feature = "vhost-user-slave"))]
+mod dummy_slave;
+
+#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
+mod tests {
+ // Integration tests that drive a Master against a SlaveReqHandler over a
+ // unix domain socket created in a temporary directory.
+ use std::os::unix::io::AsRawFd;
+ use std::path::Path;
+ use std::sync::{Arc, Barrier, Mutex};
+ use std::thread;
+
+ use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
+ use super::message::*;
+ use super::*;
+ use crate::backend::VhostBackend;
+ use crate::{VhostUserMemoryRegionInfo, VringConfigData};
+ use tempfile::{tempfile, Builder, TempDir};
+
+ fn temp_dir() -> TempDir {
+ Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap()
+ }
+
+ // Bind a listener at `path`, connect a Master to it and accept the
+ // resulting slave-side handler.
+ fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>)
+ where
+ P: AsRef<Path>,
+ S: VhostUserSlaveReqHandler,
+ {
+ let listener = Listener::new(&path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+ let master = Master::connect(&path, 1).unwrap();
+ (master, slave_listener.accept().unwrap().unwrap())
+ }
+
+ #[test]
+ fn create_dummy_slave() {
+ let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+
+ slave.set_owner().unwrap();
+ assert!(slave.set_owner().is_err());
+ }
+
+ #[test]
+ fn test_set_owner() {
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let (master, mut slave) = create_slave(&path, slave_be.clone());
+
+ assert_eq!(slave_be.lock().unwrap().owned, false);
+ master.set_owner().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+ // A second SET_OWNER must be rejected by the slave.
+ master.set_owner().unwrap();
+ assert!(slave.handle_request().is_err());
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+ }
+
+ #[test]
+ fn test_set_features() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ // The slave thread services exactly the requests the master side
+ // below issues, in order.
+ thread::spawn(move || {
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_master_slave_process() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let dir = temp_dir();
+ let mut path = dir.path().to_owned();
+ path.push("sock");
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ thread::spawn(move || {
+ // set_own()
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ // get/set_features()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ // get_queue_num()
+ slave.handle_request().unwrap();
+
+ // set_mem_table()
+ slave.handle_request().unwrap();
+
+ // get/set_config()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // set_slave_request_fd
+ slave.handle_request().unwrap();
+
+ // set_vring_enable
+ slave.handle_request().unwrap();
+
+ // set_log_base,set_log_fd()
+ slave.handle_request().unwrap_err();
+ slave.handle_request().unwrap_err();
+
+ // set_vring_xxx
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // get_max_mem_slots()
+ slave.handle_request().unwrap();
+
+ // add_mem_region()
+ slave.handle_request().unwrap();
+
+ // remove_mem_region()
+ slave.handle_request().unwrap();
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ let num = master.get_queue_num().unwrap();
+ assert_eq!(num, 2);
+
+ // The eventfd merely provides a valid fd to attach to messages.
+ let eventfd = sys_util::EventFd::new().unwrap();
+ let mem = [VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: eventfd.as_raw_fd(),
+ }];
+ master.set_mem_table(&mem).unwrap();
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
+ .unwrap();
+ let buf = [0x0u8; 4];
+ let (reply_body, reply_payload) = master
+ .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ let offset = reply_body.offset;
+ assert_eq!(offset, 0x100);
+ assert_eq!(reply_payload[0], 0xa5);
+
+ master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+ // unimplemented yet
+ master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap();
+ master.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+ master.set_vring_base(0, 0).unwrap();
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ master.set_vring_addr(0, &config).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+
+ let max_mem_slots = master.get_max_mem_slots().unwrap();
+ assert_eq!(max_mem_slots, 32);
+
+ let region_file = tempfile().unwrap();
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0x10_0000,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: region_file.as_raw_fd(),
+ };
+ master.add_mem_region(&region).unwrap();
+
+ master.remove_mem_region(&region).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_error_display() {
+ assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
+ assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
+ }
+
+ #[test]
+ fn test_should_reconnect() {
+ assert_eq!(Error::PartialMessage.should_reconnect(), true);
+ assert_eq!(Error::SlaveInternalError.should_reconnect(), true);
+ assert_eq!(Error::MasterInternalError.should_reconnect(), true);
+ assert_eq!(Error::InvalidParam.should_reconnect(), false);
+ assert_eq!(Error::InvalidOperation.should_reconnect(), false);
+ assert_eq!(Error::InvalidMessage.should_reconnect(), false);
+ assert_eq!(Error::IncorrectFds.should_reconnect(), false);
+ assert_eq!(Error::OversizedMsg.should_reconnect(), false);
+ assert_eq!(Error::FeatureMismatch.should_reconnect(), false);
+ }
+
+ #[test]
+ fn test_error_from_sys_util_error() {
+ let e: Error = sys_util::Error::new(libc::EAGAIN).into();
+ if let Error::SocketRetry(e1) = e {
+ assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
+ } else {
+ panic!("invalid error code conversion!");
+ }
+ }
+}
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
new file mode 100644
index 0000000..fb65c41
--- /dev/null
+++ b/src/vhost_user/slave.rs
@@ -0,0 +1,86 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Traits and Structs for vhost-user slave.
+
+use std::sync::Arc;
+
+use super::connection::{Endpoint, Listener};
+use super::message::*;
+use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
+
+/// Vhost-user slave side connection listener.
+///
+/// Wraps a `Listener` together with the backend request handler that will be
+/// handed to the first accepted connection.
+pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
+ listener: Listener,
+ // Taken (moved into the SlaveReqHandler) by the first successful accept.
+ backend: Option<Arc<S>>,
+}
+
+/// Sets up a listener for incoming master connections, and handles construction
+/// of a Slave on success.
+impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
+ /// Create a unix domain socket for incoming master connections.
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
+ Ok(SlaveListener {
+ listener,
+ backend: Some(backend),
+ })
+ }
+
+ /// Accept an incoming connection from the master, returning Some(Slave) on
+ /// success, or None if the socket is nonblocking and no incoming connection
+ /// was detected
+ ///
+ /// NOTE(review): the backend is moved into the first accepted handler via
+ /// `take().unwrap()`, so a second successful accept on the same listener
+ /// panics — confirm that one connection per listener is the intended use.
+ pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> {
+ if let Some(fd) = self.listener.accept()? {
+ return Ok(Some(SlaveReqHandler::new(
+ Endpoint::<MasterReq>::from_stream(fd),
+ self.backend.take().unwrap(),
+ )));
+ }
+ Ok(None)
+ }
+
+ /// Change blocking status on the listener.
+ ///
+ /// Note: the `block` flag is forwarded unchanged to
+ /// `Listener::set_nonblocking`, so `true` makes the listener nonblocking.
+ pub fn set_nonblocking(&self, block: bool) -> Result<()> {
+ self.listener.set_nonblocking(block)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ // NOTE(review): these tests bind hard-coded socket paths under /tmp, so
+ // two concurrent test runs on the same machine could collide.
+ use std::sync::Mutex;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_listener_set_nonblocking() {
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener =
+ Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
+ let slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ // Toggling repeatedly in both directions must always succeed.
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-master")]
+ #[test]
+ fn test_slave_listener_accept() {
+ use super::super::Master;
+
+ let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener = Listener::new(path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ // With no pending connection, a nonblocking accept returns None.
+ slave_listener.set_nonblocking(true).unwrap();
+ assert!(slave_listener.accept().unwrap().is_none());
+ assert!(slave_listener.accept().unwrap().is_none());
+
+ let _master = Master::connect(path, 1).unwrap();
+ let _slave = slave_listener.accept().unwrap().unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
new file mode 100644
index 0000000..a9c4ed2
--- /dev/null
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -0,0 +1,226 @@
+// Copyright (C) 2020 Alibaba Cloud. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::io;
+use std::mem;
+use std::os::unix::io::RawFd;
+use std::os::unix::net::UnixStream;
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
+
// Shared state of a SlaveFsCacheReq proxy; always accessed behind a mutex.
struct SlaveFsCacheReqInternal {
    // Underlying endpoint carrying slave-to-master requests.
    sock: Endpoint<SlaveReq>,

    // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
    reply_ack_negotiated: bool,

    // whether the endpoint has encountered any failure
    error: Option<i32>,
}
+
impl SlaveFsCacheReqInternal {
    // Fail fast once the endpoint has been marked broken.
    fn check_state(&self) -> Result<u64> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(0),
        }
    }

    // Send one slave-to-master request and, when REPLY_ACK has been
    // negotiated, wait for the master's acknowledgement.
    fn send_message(
        &mut self,
        request: SlaveReq,
        fs: &VhostUserFSSlaveMsg,
        fds: Option<&[RawFd]>,
    ) -> Result<u64> {
        self.check_state()?;

        let len = mem::size_of::<VhostUserFSSlaveMsg>();
        let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
        if self.reply_ack_negotiated {
            // Ask the master to acknowledge this request.
            hdr.set_need_reply(true);
        }
        self.sock.send_message(&hdr, fs, fds)?;

        self.wait_for_ack(&hdr)
    }

    // Wait for the REPLY_ACK response matching `hdr`; a no-op when the
    // feature has not been negotiated.
    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
        self.check_state()?;
        if !self.reply_ack_negotiated {
            return Ok(0);
        }

        let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
        // An ack must match the request header, carry no file descriptors,
        // and hold a valid payload.
        if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
            Endpoint::<SlaveReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }
        // A non-zero ack value reports a failure on the master side.
        if body.value != 0 {
            return Err(Error::MasterInternalError);
        }

        Ok(body.value)
    }
}
+
/// Request proxy to send vhost-user-fs slave requests to the master through the slave
/// communication channel.
///
/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the
/// master through the vhost-user slave communication channel. The forwarded messages will be
/// handled by the [MasterReqHandler] server.
///
/// Cloning a [SlaveFsCacheReq] is cheap: clones share the same underlying
/// connection state through the inner `Arc`.
///
/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
/// [MasterReqHandler]: struct.MasterReqHandler.html
#[derive(Clone)]
pub struct SlaveFsCacheReq {
    // underlying Unix domain socket for communication
    node: Arc<Mutex<SlaveFsCacheReqInternal>>,
}
+
+impl SlaveFsCacheReq {
+ fn new(ep: Endpoint<SlaveReq>) -> Self {
+ SlaveFsCacheReq {
+ node: Arc::new(Mutex::new(SlaveFsCacheReqInternal {
+ sock: ep,
+ reply_ack_negotiated: false,
+ error: None,
+ })),
+ }
+ }
+
+ fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> {
+ self.node.lock().unwrap()
+ }
+
+ fn send_message(
+ &self,
+ request: SlaveReq,
+ fs: &VhostUserFSSlaveMsg,
+ fds: Option<&[RawFd]>,
+ ) -> io::Result<u64> {
+ self.node()
+ .send_message(request, fs, fds)
+ .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
+ }
+
+ /// Create a new instance from a `UnixStream` object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Self::new(Endpoint::<SlaveReq>::from_stream(sock))
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&self, enable: bool) {
+ self.node().reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed with specified error code.
+ pub fn set_failed(&self, error: i32) {
+ self.node().error = Some(error);
+ }
+}
+
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
    /// Forward vhost-user-fs map file requests to the master.
    fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
        self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
    }

    /// Forward vhost-user-fs unmap file requests to the master.
    fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        self.send_message(SlaveReq::FS_UNMAP, fs, None)
    }
}
+
#[cfg(test)]
mod tests {
    use std::os::unix::io::AsRawFd;

    use super::*;

    #[test]
    fn test_slave_fs_cache_req_set_failed() {
        let (p1, _p2) = UnixStream::pair().unwrap();
        let fs_cache = SlaveFsCacheReq::from_stream(p1);

        // set_failed() records the error code in the shared state.
        assert!(fs_cache.node().error.is_none());
        fs_cache.set_failed(libc::EAGAIN);
        assert_eq!(fs_cache.node().error, Some(libc::EAGAIN));
    }

    #[test]
    fn test_slave_fs_cache_send_failure() {
        let (p1, p2) = UnixStream::pair().unwrap();
        let fd = p2.as_raw_fd();
        let fs_cache = SlaveFsCacheReq::from_stream(p1);

        // A failed endpoint must reject every request.
        fs_cache.set_failed(libc::ECONNRESET);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap_err();
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();
        fs_cache.node().error = None;

        // After the peer hangs up, sends must fail too.
        drop(p2);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap_err();
        fs_cache
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();
    }

    #[test]
    fn test_slave_fs_cache_recv_negative() {
        let (p1, p2) = UnixStream::pair().unwrap();
        let fd = p2.as_raw_fd();
        let fs_cache = SlaveFsCacheReq::from_stream(p1);
        let mut master = Endpoint::<SlaveReq>::from_stream(p2);

        let len = mem::size_of::<VhostUserFSSlaveMsg>();
        let mut hdr = VhostUserMsgHeader::new(
            SlaveReq::FS_MAP,
            VhostUserHeaderFlag::REPLY.bits(),
            len as u32,
        );
        let body = VhostUserU64::new(0);

        // Without REPLY_ACK negotiated, no reply is consumed and the call succeeds.
        master.send_message(&hdr, &body, Some(&[fd])).unwrap();
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap();

        // With REPLY_ACK on, a reply carrying fds is invalid.
        fs_cache.set_reply_ack_flag(true);
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap_err();

        // A reply whose request code does not match is invalid.
        hdr.set_code(SlaveReq::FS_UNMAP);
        master.send_message(&hdr, &body, None).unwrap();
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap_err();
        hdr.set_code(SlaveReq::FS_MAP);

        // A non-zero ack value means the master reported an internal error.
        let body = VhostUserU64::new(1);
        master.send_message(&hdr, &body, None).unwrap();
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap_err();

        // A matching zero-valued ack succeeds.
        let body = VhostUserU64::new(0);
        master.send_message(&hdr, &body, None).unwrap();
        fs_cache
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
            .unwrap();
    }
}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
new file mode 100644
index 0000000..18459a2
--- /dev/null
+++ b/src/vhost_user/slave_req_handler.rs
@@ -0,0 +1,828 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+use std::mem;
+use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
+use std::os::unix::net::UnixStream;
+use std::slice;
+use std::sync::{Arc, Mutex};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::slave_fs_cache::SlaveFsCacheReq;
+use super::{Error, Result};
+
/// Services provided to the master by the slave with interior mutability.
///
/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
/// but without interior mutability.
/// The vhost-user specification defines a master communication channel, by which masters could
/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
/// slaves, and it's used both on the master side and slave side.
///
/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
///   service requests to slaves.
/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
///   implementing [VhostUserSlaveReqHandler].
///
/// The [VhostUserSlaveReqHandler] trait is designed with interior mutability to improve performance
/// for multi-threading.
///
/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
/// [SlaveReqHandler]: struct.SlaveReqHandler.html
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
    // Basic vhost-user requests.
    fn set_owner(&self) -> Result<()>;
    fn reset_owner(&self) -> Result<()>;
    fn get_features(&self) -> Result<u64>;
    fn set_features(&self, features: u64) -> Result<()>;
    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
    fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
    fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
    fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
    fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;

    // Requests gated by negotiated protocol features.
    fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
    fn set_protocol_features(&self, features: u64) -> Result<()>;
    fn get_queue_num(&self) -> Result<u64>;
    fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
    fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
    fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
    fn get_max_mem_slots(&self) -> Result<u64>;
    fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
    fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
+
/// Services provided to the master by the slave without interior mutability.
///
/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
/// A blanket impl below makes any `Mutex<T: VhostUserSlaveReqHandlerMut>` usable
/// where a [VhostUserSlaveReqHandler] is expected.
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandlerMut {
    fn set_owner(&mut self) -> Result<()>;
    fn reset_owner(&mut self) -> Result<()>;
    fn get_features(&mut self) -> Result<u64>;
    fn set_features(&mut self, features: u64) -> Result<()>;
    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
    fn set_vring_addr(
        &mut self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
    fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
    fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
    fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;

    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
    fn get_queue_num(&mut self) -> Result<u64>;
    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
    ) -> Result<Vec<u8>>;
    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
    fn get_max_mem_slots(&mut self) -> Result<u64>;
    fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>;
    fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
}
+
// Blanket adapter: provide interior mutability for any `VhostUserSlaveReqHandlerMut`
// by delegating every request to the wrapped handler while holding the mutex.
impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
    fn set_owner(&self) -> Result<()> {
        self.lock().unwrap().set_owner()
    }

    fn reset_owner(&self) -> Result<()> {
        self.lock().unwrap().reset_owner()
    }

    fn get_features(&self) -> Result<u64> {
        self.lock().unwrap().get_features()
    }

    fn set_features(&self, features: u64) -> Result<()> {
        self.lock().unwrap().set_features(features)
    }

    fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
        self.lock().unwrap().set_mem_table(ctx, fds)
    }

    fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
        self.lock().unwrap().set_vring_num(index, num)
    }

    fn set_vring_addr(
        &self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()> {
        self.lock()
            .unwrap()
            .set_vring_addr(index, flags, descriptor, used, available, log)
    }

    fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
        self.lock().unwrap().set_vring_base(index, base)
    }

    fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
        self.lock().unwrap().get_vring_base(index)
    }

    fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
        self.lock().unwrap().set_vring_kick(index, fd)
    }

    fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
        self.lock().unwrap().set_vring_call(index, fd)
    }

    fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
        self.lock().unwrap().set_vring_err(index, fd)
    }

    fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
        self.lock().unwrap().get_protocol_features()
    }

    fn set_protocol_features(&self, features: u64) -> Result<()> {
        self.lock().unwrap().set_protocol_features(features)
    }

    fn get_queue_num(&self) -> Result<u64> {
        self.lock().unwrap().get_queue_num()
    }

    fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
        self.lock().unwrap().set_vring_enable(index, enable)
    }

    fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
        self.lock().unwrap().get_config(offset, size, flags)
    }

    fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
        self.lock().unwrap().set_config(offset, buf, flags)
    }

    fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
        self.lock().unwrap().set_slave_req_fd(vu_req)
    }

    fn get_max_mem_slots(&self) -> Result<u64> {
        self.lock().unwrap().get_max_mem_slots()
    }

    fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> {
        self.lock().unwrap().add_mem_region(region, fd)
    }

    fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
        self.lock().unwrap().remove_mem_region(region)
    }
}
+
/// Server to handle service requests from masters from the master communication channel.
///
/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
/// masters on the master communication channel. It's actually a proxy invoking the registered
/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underlying Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
///
/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
/// [SlaveReqHandler]: struct.SlaveReqHandler.html
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
    // underlying Unix domain socket for communication
    main_sock: Endpoint<MasterReq>,
    // the vhost-user backend device object
    backend: Arc<S>,

    // Features offered by / acknowledged on this device, mirrored from the
    // GET/SET_FEATURES and GET/SET_PROTOCOL_FEATURES exchanges.
    virtio_features: u64,
    acked_virtio_features: u64,
    protocol_features: VhostUserProtocolFeatures,
    acked_protocol_features: u64,

    // sending ack for messages without payload
    reply_ack_enabled: bool,
    // whether the endpoint has encountered any failure
    error: Option<i32>,
}
+
+impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
    /// Create a vhost-user slave endpoint from an already-connected endpoint
    /// and a backend request handler. All feature state starts unnegotiated.
    pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
        SlaveReqHandler {
            main_sock,
            backend,
            virtio_features: 0,
            acked_virtio_features: 0,
            protocol_features: VhostUserProtocolFeatures::empty(),
            acked_protocol_features: 0,
            reply_ack_enabled: false,
            error: None,
        }
    }
+
    /// Create a new vhost-user slave endpoint.
    ///
    /// # Arguments
    /// * - `path` - path of Unix domain socket listener to connect to
    /// * - `backend` - handler for requests from the master to the slave
    pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
        Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
    }
+
    /// Mark endpoint as failed with specified error code.
    ///
    /// Once failed, every subsequent `handle_request` call returns
    /// `Error::SocketBroken` until the handler is recreated.
    pub fn set_failed(&mut self, error: i32) {
        self.error = Some(error);
    }
+
    /// Main entry point to serve one request arriving on the master
    /// communication channel.
    ///
    /// Receive and handle one incoming request message from the master. The caller needs to:
    /// - serialize calls to this function
    /// - decide what to do when error happens
    /// - optionally recover from failure
    pub fn handle_request(&mut self) -> Result<()> {
        // Return error if the endpoint is already in failed state.
        self.check_state()?;

        // The underlying communication channel is a Unix domain socket in
        // stream mode, and recvmsg() is a little tricky here. To successfully
        // receive attached file descriptors, we need to receive messages and
        // corresponding attached file descriptors in this way:
        // . recv message header and optional attached file
        // . validate message header
        // . recv optional message body and payload according size field in
        //   message header
        // . validate message body and optional payload
        let (hdr, rfds) = self.main_sock.recv_header()?;
        let rfds = self.check_attached_rfds(&hdr, rfds)?;
        let (size, buf) = match hdr.get_size() {
            0 => (0, vec![0u8; 0]),
            len => {
                let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
                if size2 != len as usize {
                    return Err(Error::InvalidMessage);
                }
                (size2, rbuf)
            }
        };

        match hdr.get_code() {
            MasterReq::SET_OWNER => {
                self.check_request_size(&hdr, size, 0)?;
                self.backend.set_owner()?;
            }
            MasterReq::RESET_OWNER => {
                self.check_request_size(&hdr, size, 0)?;
                self.backend.reset_owner()?;
            }
            MasterReq::GET_FEATURES => {
                self.check_request_size(&hdr, size, 0)?;
                let features = self.backend.get_features()?;
                let msg = VhostUserU64::new(features);
                self.send_reply_message(&hdr, &msg)?;
                // Cache the offered features; REPLY_ACK depends on them.
                self.virtio_features = features;
                self.update_reply_ack_flag();
            }
            MasterReq::SET_FEATURES => {
                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
                self.backend.set_features(msg.value)?;
                self.acked_virtio_features = msg.value;
                self.update_reply_ack_flag();
            }
            MasterReq::SET_MEM_TABLE => {
                let res = self.set_mem_table(&hdr, size, &buf, rfds);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::SET_VRING_NUM => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_num(msg.index, msg.num);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::SET_VRING_ADDR => {
                let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
                let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
                    Some(val) => val,
                    None => return Err(Error::InvalidMessage),
                };
                let res = self.backend.set_vring_addr(
                    msg.index,
                    flags,
                    msg.descriptor,
                    msg.used,
                    msg.available,
                    msg.log,
                );
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::SET_VRING_BASE => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let res = self.backend.set_vring_base(msg.index, msg.num);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::GET_VRING_BASE => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                let reply = self.backend.get_vring_base(msg.index)?;
                self.send_reply_message(&hdr, &reply)?;
            }
            MasterReq::SET_VRING_CALL => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
                let res = self.backend.set_vring_call(index, rfds);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::SET_VRING_KICK => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
                let res = self.backend.set_vring_kick(index, rfds);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::SET_VRING_ERR => {
                self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
                let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
                let res = self.backend.set_vring_err(index, rfds);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::GET_PROTOCOL_FEATURES => {
                self.check_request_size(&hdr, size, 0)?;
                let features = self.backend.get_protocol_features()?;
                let msg = VhostUserU64::new(features.bits());
                self.send_reply_message(&hdr, &msg)?;
                self.protocol_features = features;
                self.update_reply_ack_flag();
            }
            MasterReq::SET_PROTOCOL_FEATURES => {
                let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
                self.backend.set_protocol_features(msg.value)?;
                self.acked_protocol_features = msg.value;
                self.update_reply_ack_flag();
            }
            // The remaining requests are only valid after the corresponding
            // protocol feature has been negotiated.
            MasterReq::GET_QUEUE_NUM => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_queue_num()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            MasterReq::SET_VRING_ENABLE => {
                let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
                if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0
                    && msg.index > 0
                {
                    return Err(Error::InvalidOperation);
                }
                let enable = match msg.num {
                    1 => true,
                    0 => false,
                    _ => return Err(Error::InvalidParam),
                };

                let res = self.backend.set_vring_enable(msg.index, enable);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::GET_CONFIG => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                self.get_config(&hdr, &buf)?;
            }
            MasterReq::SET_CONFIG => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                self.set_config(&hdr, size, &buf)?;
            }
            MasterReq::SET_SLAVE_REQ_FD => {
                if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
                self.set_slave_req_fd(&hdr, rfds)?;
            }
            MasterReq::GET_MAX_MEM_SLOTS => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                self.check_request_size(&hdr, size, 0)?;
                let num = self.backend.get_max_mem_slots()?;
                let msg = VhostUserU64::new(num);
                self.send_reply_message(&hdr, &msg)?;
            }
            MasterReq::ADD_MEM_REG => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }
                // Exactly one fd must accompany the new region.
                let fd = if let Some(fds) = &rfds {
                    if fds.len() != 1 {
                        return Err(Error::InvalidParam);
                    }
                    fds[0]
                } else {
                    return Err(Error::InvalidParam);
                };

                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.add_mem_region(&msg, fd);
                self.send_ack_message(&hdr, res)?;
            }
            MasterReq::REM_MEM_REG => {
                if self.acked_protocol_features
                    & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
                    == 0
                {
                    return Err(Error::InvalidOperation);
                }

                let msg =
                    self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
                let res = self.backend.remove_mem_region(&msg);
                self.send_ack_message(&hdr, res)?;
            }
            _ => {
                return Err(Error::InvalidMessage);
            }
        }
        Ok(())
    }
+
    // Validate a SET_MEM_TABLE request (header, region descriptors, attached
    // fds) and hand the regions plus fds to the backend.
    fn set_mem_table(
        &mut self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        buf: &[u8],
        rfds: Option<Vec<RawFd>>,
    ) -> Result<()> {
        self.check_request_size(&hdr, size, hdr.get_size() as usize)?;

        // check message size is consistent
        let hdrsize = mem::size_of::<VhostUserMemory>();
        if size < hdrsize {
            Endpoint::<MasterReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }
        // SAFETY: buf holds at least `hdrsize` bytes (checked above).
        // NOTE(review): this assumes `buf` is suitably aligned for
        // VhostUserMemory — confirm, or switch to `read_unaligned` like the
        // other request parsers in this file.
        let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
        if !msg.is_valid() {
            Endpoint::<MasterReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }
        if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
            Endpoint::<MasterReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }

        // validate number of fds matching number of memory regions
        let fds = match rfds {
            None => return Err(Error::InvalidMessage),
            Some(fds) => {
                if fds.len() != msg.num_regions as usize {
                    Endpoint::<MasterReq>::close_rfds(Some(fds));
                    return Err(Error::InvalidMessage);
                }
                fds
            }
        };

        // Validate memory regions
        // SAFETY: the size check above guarantees `buf` holds exactly
        // `num_regions` VhostUserMemoryRegion entries after the header.
        let regions = unsafe {
            slice::from_raw_parts(
                buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
                msg.num_regions as usize,
            )
        };
        for region in regions.iter() {
            if !region.is_valid() {
                Endpoint::<MasterReq>::close_rfds(Some(fds));
                return Err(Error::InvalidMessage);
            }
        }

        self.backend.set_mem_table(&regions, &fds)
    }
+
+ fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
+ let payload_offset = mem::size_of::<VhostUserConfig>();
+ if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if buf.len() - payload_offset != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+ let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
+ Some(val) => val,
+ None => return Err(Error::InvalidMessage),
+ };
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
+
+ // vhost-user slave's payload size MUST match master's request
+ // on success, uses zero length of payload to indicate an error
+ // to vhost-user master.
+ match res {
+ Ok(ref buf) if buf.len() == msg.size as usize => {
+ let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
+ self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?;
+ }
+ Ok(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ Err(_) => {
+ let reply = VhostUserConfig::new(msg.offset, 0, flags);
+ self.send_reply_message(&hdr, &reply)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn set_config(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ size: usize,
+ buf: &[u8],
+ ) -> Result<()> {
+ if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
+ if !msg.is_valid() {
+ return Err(Error::InvalidMessage);
+ }
+ if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
+ return Err(Error::InvalidMessage);
+ }
+ let flags: VhostUserConfigFlags;
+ match VhostUserConfigFlags::from_bits(msg.flags) {
+ Some(val) => flags = val,
+ None => return Err(Error::InvalidMessage),
+ }
+
+ let res = self.backend.set_config(msg.offset, buf, flags);
+ self.send_ack_message(&hdr, res)?;
+ Ok(())
+ }
+
+ fn set_slave_req_fd(
+ &mut self,
+ hdr: &VhostUserMsgHeader<MasterReq>,
+ rfds: Option<Vec<RawFd>>,
+ ) -> Result<()> {
+ if let Some(fds) = rfds {
+ if fds.len() == 1 {
+ let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
+ let vu_req = SlaveFsCacheReq::from_stream(sock);
+ self.backend.set_slave_req_fd(vu_req);
+ self.send_ack_message(&hdr, Ok(()))
+ } else {
+ Err(Error::InvalidMessage)
+ }
+ } else {
+ Err(Error::InvalidMessage)
+ }
+ }
+
    // Parse a SET_VRING_{CALL,KICK,ERR} payload, returning the vring index
    // and the optional event fd attached to the message.
    fn handle_vring_fd_request(
        &mut self,
        buf: &[u8],
        rfds: Option<Vec<RawFd>>,
    ) -> Result<(u8, Option<RawFd>)> {
        if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
            return Err(Error::InvalidMessage);
        }
        // read_unaligned: `buf` carries no alignment guarantee.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }

        // Bits (0-7) of the payload contain the vring index. Bit 8 is the
        // invalid FD flag. This flag is set when there is no file descriptor
        // in the ancillary data. This signals that polling will be used
        // instead of waiting for the call.
        let nofd = (msg.value & 0x100u64) == 0x100u64;

        let mut rfd = None;
        match rfds {
            Some(fds) => {
                if !nofd && fds.len() == 1 {
                    rfd = Some(fds[0]);
                } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) {
                    // fd count disagrees with the invalid-FD flag.
                    Endpoint::<MasterReq>::close_rfds(Some(fds));
                    return Err(Error::InvalidMessage);
                }
            }
            None => {
                if !nofd {
                    return Err(Error::InvalidMessage);
                }
            }
        }
        Ok((msg.value as u8, rfd))
    }
+
+ fn check_state(&self) -> Result<()> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(()),
+ }
+ }
+
    // Validate a request header against the number of body bytes actually
    // received: the declared size must equal both `expected` and `size`, the
    // message must not be a reply, and the protocol version must be 1.
    fn check_request_size(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }
+
    // Only a fixed set of requests may carry file descriptors; any fds
    // attached to other requests are closed and the message is rejected.
    fn check_attached_rfds(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        rfds: Option<Vec<RawFd>>,
    ) -> Result<Option<Vec<RawFd>>> {
        match hdr.get_code() {
            MasterReq::SET_MEM_TABLE => Ok(rfds),
            MasterReq::SET_VRING_CALL => Ok(rfds),
            MasterReq::SET_VRING_KICK => Ok(rfds),
            MasterReq::SET_VRING_ERR => Ok(rfds),
            MasterReq::SET_LOG_BASE => Ok(rfds),
            MasterReq::SET_LOG_FD => Ok(rfds),
            MasterReq::SET_SLAVE_REQ_FD => Ok(rfds),
            MasterReq::SET_INFLIGHT_FD => Ok(rfds),
            MasterReq::ADD_MEM_REG => Ok(rfds),
            _ => {
                if rfds.is_some() {
                    Endpoint::<MasterReq>::close_rfds(rfds);
                    Err(Error::InvalidMessage)
                } else {
                    Ok(rfds)
                }
            }
        }
    }
+
    // Validate the request size for a body of type `T` and deserialize it
    // from `buf`.
    fn extract_request_body<T: Sized + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<MasterReq>,
        size: usize,
        buf: &[u8],
    ) -> Result<T> {
        self.check_request_size(hdr, size, mem::size_of::<T>())?;
        // read_unaligned: `buf` was received off the wire and carries no
        // alignment guarantee; the size check above makes the read in bounds.
        let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok(msg)
    }
+
+ fn update_reply_ack_flag(&mut self) {
+ let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
+ let pflag = VhostUserProtocolFeatures::REPLY_ACK;
+ if (self.virtio_features & vflag) != 0
+ && (self.acked_virtio_features & vflag) != 0
+ && self.protocol_features.contains(pflag)
+ && (self.acked_protocol_features & pflag.bits()) != 0
+ {
+ self.reply_ack_enabled = true;
+ } else {
+ self.reply_ack_enabled = false;
+ }
+ }
+
    // Build a REPLY header for `req` whose size covers a body of type `T`
    // plus `payload_size` trailing bytes, rejecting oversized messages.
    fn new_reply_header<T: Sized>(
        &self,
        req: &VhostUserMsgHeader<MasterReq>,
        payload_size: usize,
    ) -> Result<VhostUserMsgHeader<MasterReq>> {
        if mem::size_of::<T>() > MAX_MSG_SIZE
            || payload_size > MAX_MSG_SIZE
            || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
        {
            return Err(Error::InvalidParam);
        }
        self.check_state()?;
        Ok(VhostUserMsgHeader::new(
            req.get_code(),
            VhostUserHeaderFlag::REPLY.bits(),
            (mem::size_of::<T>() + payload_size) as u32,
        ))
    }
+
    // Send a REPLY_ACK (0 = success, 1 = failure) for `req`, but only when
    // the feature is negotiated and the master asked for a reply.
    fn send_ack_message(
        &mut self,
        req: &VhostUserMsgHeader<MasterReq>,
        res: Result<()>,
    ) -> Result<()> {
        if self.reply_ack_enabled && req.is_need_reply() {
            let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
            let val = match res {
                Ok(_) => 0,
                Err(_) => 1,
            };
            let msg = VhostUserU64::new(val);
            self.main_sock.send_message(&hdr, &msg, None)?;
        }
        Ok(())
    }
+
    // Send a reply carrying only a fixed-size body `msg`.
    fn send_reply_message<T>(
        &mut self,
        req: &VhostUserMsgHeader<MasterReq>,
        msg: &T,
    ) -> Result<()> {
        let hdr = self.new_reply_header::<T>(req, 0)?;
        self.main_sock.send_message(&hdr, msg, None)?;
        Ok(())
    }
+
    // Send a reply carrying a fixed-size body `msg` followed by a variable
    // `payload`.
    fn send_reply_with_payload<T: Sized>(
        &mut self,
        req: &VhostUserMsgHeader<MasterReq>,
        msg: &T,
        payload: &[u8],
    ) -> Result<()> {
        let hdr = self.new_reply_header::<T>(req, payload.len())?;
        self.main_sock
            .send_message_with_payload(&hdr, msg, payload, None)?;
        Ok(())
    }
+}
+
// Expose the underlying socket fd so the handler can be polled for readiness.
impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.main_sock.as_raw_fd()
    }
}
+
#[cfg(test)]
mod tests {
    use std::os::unix::io::AsRawFd;

    use super::*;
    use crate::vhost_user::dummy_slave::DummySlaveReqHandler;

    #[test]
    fn test_slave_req_handler_new() {
        let (p1, _p2) = UnixStream::pair().unwrap();
        let endpoint = Endpoint::<MasterReq>::from_stream(p1);
        let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let mut handler = SlaveReqHandler::new(endpoint, backend);

        // A fresh handler is healthy; set_failed() flips it to the error state.
        handler.check_state().unwrap();
        handler.set_failed(libc::EAGAIN);
        handler.check_state().unwrap_err();
        assert!(handler.as_raw_fd() >= 0);
    }
}
diff --git a/src/vsock.rs b/src/vsock.rs
new file mode 100644
index 0000000..1e1b0b9
--- /dev/null
+++ b/src/vsock.rs
@@ -0,0 +1,30 @@
+// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
+//
+// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+//
+// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE-BSD-Google file.
+
+//! Trait to control vhost-vsock backend drivers.
+
+use crate::backend::VhostBackend;
+use crate::Result;
+
/// Trait to control vhost-vsock backend drivers.
pub trait VhostVsock: VhostBackend {
    /// Set the CID for the guest.
    /// This number is used for routing all data destined for running in the guest.
    /// Each guest on a hypervisor must have a unique CID.
    ///
    /// # Arguments
    /// * `cid` - CID to assign to the guest
    fn set_guest_cid(&self, cid: u64) -> Result<()>;

    /// Tell the VHOST driver to start performing data transfer.
    fn start(&self) -> Result<()>;

    /// Tell the VHOST driver to stop performing data transfer.
    fn stop(&self) -> Result<()>;
}