diff options
author | Jorge E. Moreira <jemoreira@google.com> | 2021-04-15 18:16:05 +0000 |
---|---|---|
committer | Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com> | 2021-04-15 18:16:05 +0000 |
commit | 63e5d173880667d287719756463561e5493c6a5b (patch) | |
tree | 56be29b9f4f9f6308a7ce7adb8a77cdf115f1bc6 | |
parent | 9dc6172cfed4ab25b52a45f3e74dc992b7fb485d (diff) | |
parent | 86868ea15c099f99ed1ddf83d870e1b025d9d53b (diff) | |
download | vmm_vhost-63e5d173880667d287719756463561e5493c6a5b.tar.gz |
Merge remote-tracking branch 'aosp/upstream-main' am: 86868ea15c
Original change: https://android-review.googlesource.com/c/platform/external/rust/crates/vmm_vhost/+/1676994
Change-Id: I1896033874d9f7edc287a4dbf3d3294251eed104
31 files changed, 7390 insertions, 0 deletions
diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 0000000..bf8523e --- /dev/null +++ b/.cargo/config @@ -0,0 +1,5 @@ +# This workaround is needed because the linker is unable to find __addtf3, +# __multf3 and __subtf3. +# Related issue: https://github.com/rust-lang/compiler-builtins/issues/201 +[target.aarch64-unknown-linux-musl] +rustflags = [ "-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"] diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f738aa8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/build +/kcov_build +/target +.idea +**/*.rs.bk +Cargo.lock diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..bda97eb --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "rust-vmm-ci"] + path = rust-vmm-ci + url = https://github.com/rust-vmm/rust-vmm-ci.git diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..7174a1b --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,2 @@ +# Add the list of code owners here (using their GitHub username) +* gatekeeper-PullAssigner @jiangliu @eryugey @sboeuf @slp diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..917ea25 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "vmm_vhost" +version = "0.1.0" +keywords = ["vhost", "vhost-user", "virtio", "vdpa"] +description = "a pure rust library for vdpa, vhost and vhost-user" +authors = ["Liu Jiang <gerry@linux.alibaba.com>"] +repository = "https://github.com/rust-vmm/vhost" +documentation = "https://docs.rs/vhost" +readme = "README.md" +license = "Apache-2.0 or BSD-3-Clause" +edition = "2018" + +[features] +default = [] +vhost-vsock = [] +vhost-kern = ["vm-memory"] +vhost-user = [] +vhost-user-master = ["vhost-user"] +vhost-user-slave = ["vhost-user"] + +[dependencies] +bitflags = ">=1.0.1" +libc = ">=0.2.39" + +sys_util = { path = "../../../platform/crosvm/sys_util" } # provided by ebuild +tempfile = { path = "../../../platform/crosvm/tempfile" } # provided by 
ebuild +vm-memory = { version = "0.2.0", optional = true } + +[dev-dependencies] +vm-memory = { version = "0.2.0", features=["backend-mmap"] } @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/LICENSE-BSD-3-Clause b/LICENSE-BSD-3-Clause new file mode 100644 index 0000000..1ff0cd7 --- /dev/null +++ b/LICENSE-BSD-3-Clause @@ -0,0 +1,27 @@ +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Alibaba Inc. 
nor the names of its contributors +// may be used to endorse or promote products derived from this software +// without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/LICENSE-BSD-Chromium b/LICENSE-BSD-Chromium new file mode 100644 index 0000000..8bafca3 --- /dev/null +++ b/LICENSE-BSD-Chromium @@ -0,0 +1,27 @@ +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -0,0 +1,3 @@ +jemoreira@google.com +chirantan@google.com +dgreid@google.com diff --git a/PRESUBMIT.cfg b/PRESUBMIT.cfg new file mode 100644 index 0000000..e52636e --- /dev/null +++ b/PRESUBMIT.cfg @@ -0,0 +1,6 @@ +[Hook Overrides] +cargo_clippy_check: true + +[Hook Overrides Options] +cargo_clippy_check: + --project=. diff --git a/README.md b/README.md new file mode 100644 index 0000000..b0f4dfa --- /dev/null +++ b/README.md @@ -0,0 +1,32 @@ +# vHost +A pure Rust library for vDPA, vhost and vhost-user. + +The `vhost` crate aims to help implement the dataplane for virtio backend drivers. It supports three different types of dataplane drivers: +- vhost: the dataplane is implemented by the Linux kernel +- vhost-user: the dataplane is implemented by dedicated vhost-user servers +- vDPA (vhost DataPath Accelerator): the dataplane is implemented by hardware + +The main relationship among Traits and Structs exported by the `vhost` crate is shown below: + +![vhost Architecture](/docs/vhost_architecture.png) +## Kernel-based vHost Backend Drivers +The vhost drivers in Linux provide in-kernel virtio device emulation.
Normally +the hypervisor userspace process emulates I/O accesses from the guest. +Vhost puts virtio emulation code into the kernel, taking hypervisor userspace +out of the picture. This allows device emulation code to directly call into +kernel subsystems instead of performing system calls from userspace. +The hypervisor relies on ioctl-based interfaces to control those in-kernel +vhost drivers, such as vhost-net, vhost-scsi and vhost-vsock. + +## vHost-user Backend Drivers +The [vhost-user protocol](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html#communication) aims to implement vhost backend drivers in +userspace, which complements the ioctl interface used to control the vhost +implementation in the Linux kernel. It implements the control plane needed +to establish virtqueue sharing with a userspace process on the same host. +It uses communication over a Unix domain socket to share file descriptors in +the ancillary data of the message. + +The protocol defines two sides of the communication, master and slave. +Master is the application that shares its virtqueues; slave is the consumer +of the virtqueues. Master and slave can be either a client (i.e. connecting) +or server (listening) in the socket communication. diff --git a/coverage_config_aarch64.json b/coverage_config_aarch64.json new file mode 100644 index 0000000..67543fc --- /dev/null +++ b/coverage_config_aarch64.json @@ -0,0 +1 @@ +{"coverage_score": 39.8, "exclude_path": "", "crate_features": "vhost-vsock,vhost-kern,vhost-user-master,vhost-user-slave"}
\ No newline at end of file diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json new file mode 100644 index 0000000..2b2c164 --- /dev/null +++ b/coverage_config_x86_64.json @@ -0,0 +1 @@ +{"coverage_score": 81.2, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio new file mode 100644 index 0000000..5008d28 --- /dev/null +++ b/docs/vhost_architecture.drawio @@ -0,0 +1,171 @@ +<mxfile host="65bd71144e" modified="2021-02-22T05:37:26.833Z" agent="5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.53.0 Chrome/87.0.4280.141 Electron/11.2.1 Safari/537.36" etag="HWRXqybJYJqQhnlJWfmB" version="14.2.4" type="embed"> + <diagram id="xCgrIAQPDQM0eynUYBOE" name="Page-1"> + <mxGraphModel dx="3446" dy="1284" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0"> + <root> + <mxCell id="0"/> + <mxCell id="1" parent="0"/> + <mxCell id="46" value="<br><br><br><br><br><br><br><br><br><br><br><br><br><br><br><br>" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1"> + <mxGeometry x="1620" y="27" width="450" height="990" as="geometry"/> + </mxCell> + <mxCell id="47" value="" style="shape=hexagon;perimeter=hexagonPerimeter2;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;labelBackgroundColor=none;sketch=0;fillColor=none;fontSize=25;dashed=1;strokeWidth=6;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="790" y="237" width="1260" height="750" as="geometry"/> + </mxCell> + <mxCell id="44" value="" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1"> 
+ <mxGeometry x="-10" y="37" width="1250" height="670" as="geometry"/> + </mxCell> + <mxCell id="2" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">MasterReqHandler</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="830" y="477" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="4" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostUserMasterReqHandler</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="840" y="597" width="360" height="60" as="geometry"/> + </mxCell> + <mxCell id="6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="2"> + <mxGeometry relative="1" as="geometry"> + <Array as="points"> + <mxPoint x="1280" y="792"/> + <mxPoint x="1280" y="502"/> + </Array> + </mxGeometry> + </mxCell> + <mxCell id="5" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">SlaveFsCacheReq</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1"> + <mxGeometry x="1715" y="767" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="7" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostUserMasterReqHandlerMut</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + 
<mxGeometry x="1630" y="657" width="390" height="60" as="geometry"/> + </mxCell> + <mxCell id="8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="2" target="4"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="950" y="657" as="sourcePoint"/> + <mxPoint x="680" y="717" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="10" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">SlaveListener</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="472" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="11" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">SlaveReqHandler</pre></pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1712" y="387" width="210" height="50" as="geometry"/> + </mxCell> + <mxCell id="14" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;"><pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostUserSlaveReqHandler</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1652" y="537" width="330" height="60" as="geometry"/> + </mxCell> + <mxCell id="15" 
style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="11" target="14"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1202" y="567" as="sourcePoint"/> + <mxPoint x="1202" y="667" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="16" value="<pre style="font-size: 16.5pt; font-weight: 700; font-family: &quot;jetbrains mono&quot;, monospace;">VhostBackend</pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#00994D;strokeColor=#009900;" vertex="1" parent="1"> + <mxGeometry x="390" y="197" width="250" height="60" as="geometry"/> + </mxCell> + <mxCell id="17" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostKernBackend</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="530" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="18" value="<pre style="font-family: &quot;jetbrains mono&quot;, monospace; font-size: 16.5pt;">VhostVdpaBackend</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="270" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="19" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Master</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="820" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="20" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">VhostSoftBackend</pre>" 
style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="10" y="387" width="220" height="50" as="geometry"/> + </mxCell> + <mxCell id="21" value="Handle virtque in VMM" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="10" y="557" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="23" value="Handle virtque in hardware" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1"> + <mxGeometry x="270" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="24" value="Handle virtque in kernel" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="530" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#0000CC;" edge="1" parent="1" source="24" target="17"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="930" y="647" as="sourcePoint"/> + <mxPoint x="930" y="747" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="11"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="840" y="917" as="sourcePoint"/> + <mxPoint 
x="840" y="1017" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="27" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="23" target="18"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="420" y="807" as="sourcePoint"/> + <mxPoint x="420" y="907" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="21" target="20"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="240" y="857" as="sourcePoint"/> + <mxPoint x="240" y="957" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="20" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="910" y="647" as="sourcePoint"/> + <mxPoint x="910" y="747" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;strokeColor=#00994D;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="18" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1000" y="177" as="sourcePoint"/> + <mxPoint x="530" y="227" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="32" 
style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="17" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1010" y="127" as="sourcePoint"/> + <mxPoint x="1505" y="-73" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="35" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Endpoint</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="552" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="36" value="<pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">Message</pre></pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="1360" y="632" width="190" height="50" as="geometry"/> + </mxCell> + <mxCell id="37" value="<pre style="font-size: 16.5pt ; font-weight: 700 ; font-family: &quot;jetbrains mono&quot; , monospace"><pre style="font-family: &quot;jetbrains mono&quot; , monospace ; font-size: 16.5pt">VhostUserMaster</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="980" y="257" width="230" height="60" as="geometry"/> + </mxCell> + <mxCell id="38" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" 
parent="1" source="19" target="16"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1030" y="527" as="sourcePoint"/> + <mxPoint x="515" y="257" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="39" value="Handle virtque in remote process" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="850" y="807" width="220" height="120" as="geometry"/> + </mxCell> + <mxCell id="41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="7"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1860" y="187" as="sourcePoint"/> + <mxPoint x="1860" y="267" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="43" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="37"> + <mxGeometry relative="1" as="geometry"> + <mxPoint x="1430" y="187" as="sourcePoint"/> + <mxPoint x="2102" y="187" as="targetPoint"/> + </mxGeometry> + </mxCell> + <mxCell id="49" value="<pre style="font-size: 16.5pt ; font-weight: 700 ; font-family: &#34;jetbrains mono&#34; , monospace"><pre style="font-family: &#34;jetbrains mono&#34; , monospace ; font-size: 16.5pt">Trait</pre></pre>" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1"> + <mxGeometry x="60" y="1017" width="130" height="60" as="geometry"/> + </mxCell> + <mxCell id="51" value="Vhost-user protocol" 
style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fontColor=#FF00FF;fillColor=none;strokeColor=none;" vertex="1" parent="1"> + <mxGeometry x="1220" y="817" width="330" height="150" as="geometry"/> + </mxCell> + <mxCell id="52" value="Vhost-user server" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1"> + <mxGeometry x="1680" y="57" width="330" height="150" as="geometry"/> + </mxCell> + <mxCell id="53" value="VMM" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1"> + <mxGeometry x="20" y="47" width="240" height="150" as="geometry"/> + </mxCell> + <mxCell id="54" value="<pre style="font-family: &#34;jetbrains mono&#34; , monospace ; font-size: 16.5pt">Struct</pre>" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1"> + <mxGeometry x="240" y="1022" width="140" height="55" as="geometry"/> + </mxCell> + </root> + </mxGraphModel> + </diagram> +</mxfile>
\ No newline at end of file diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png Binary files differnew file mode 100644 index 0000000..4d1e2bc --- /dev/null +++ b/docs/vhost_architecture.png diff --git a/rust-vmm-ci b/rust-vmm-ci new file mode 160000 +Subproject ebc701641fa57f78d03f3f5ecac617b7bf7470b diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..1ae306f --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,506 @@ +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Common traits and structs for vhost-kern and vhost-user backend drivers. + +use std::cell::RefCell; +use std::os::unix::io::RawFd; +use std::sync::RwLock; + +use sys_util::EventFd; + +use super::Result; + +/// Maximum number of memory regions supported. +pub const VHOST_MAX_MEMORY_REGIONS: usize = 255; + +/// Vring configuration data. +pub struct VringConfigData { + /// Maximum queue size supported by the driver. + pub queue_max_size: u16, + /// Actual queue size negotiated by the driver. + pub queue_size: u16, + /// Bitmask of vring flags. + pub flags: u32, + /// Descriptor table address. + pub desc_table_addr: u64, + /// Used ring buffer address. + pub used_ring_addr: u64, + /// Available ring buffer address. + pub avail_ring_addr: u64, + /// Optional address for logging. + pub log_addr: Option<u64>, +} + +impl VringConfigData { + /// Check whether the log (flag, address) pair is valid. + pub fn is_log_addr_valid(&self) -> bool { + if self.flags & 0x1 != 0 && self.log_addr.is_none() { + return false; + } + + true + } + + /// Get the log address, default to zero if not available. 
+ pub fn get_log_addr(&self) -> u64 { + if self.flags & 0x1 != 0 && self.log_addr.is_some() { + self.log_addr.unwrap() + } else { + 0 + } + } +} + +/// Memory region configuration data. +#[derive(Default, Clone, Copy)] +pub struct VhostUserMemoryRegionInfo { + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub userspace_addr: u64, + /// Optional offset where region starts in the mapped memory. + pub mmap_offset: u64, + /// Optional file descriptor for mmap. + pub mmap_handle: RawFd, +} + +/// An interface for setting up vhost-based backend drivers with interior mutability. +/// +/// Vhost devices are a subset of virtio devices, which improve virtio device performance by +/// delegating data plane operations to dedicated IO service processes. Vhost devices use the +/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to +/// virtio devices. +/// +/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the +/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service +/// processes, and slow paths for device configuration are still handled by the VMM process. It may +/// also be used to control access permissions of virtio backend devices. +pub trait VhostBackend: std::marker::Sized { + /// Get a bitmask of supported virtio/vhost features. + fn get_features(&self) -> Result<u64>; + + /// Inform the vhost subsystem which features to enable. + /// This should be a subset of supported features from get_features(). + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_features(&self, features: u64) -> Result<()>; + + /// Set the current process as the owner of the vhost backend. + /// This must be run before any other vhost commands.
+ fn set_owner(&self) -> Result<()>; + + /// Used to be sent to request disabling all rings. + /// This is no longer used. + fn reset_owner(&self) -> Result<()>; + + /// Set the guest memory mappings for vhost to use. + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>; + + /// Set base address for page modification logging. + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()>; + + /// Specify an eventfd file descriptor to signal on log write. + fn set_log_fd(&self, fd: RawFd) -> Result<()>; + + /// Set the number of descriptors in the vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set descriptor count for. + /// * `num` - Number of descriptors in the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>; + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Configuration data for a vring. + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()>; + + /// Set the first index to look for available descriptors. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `base` - Index where available descriptors start. + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()>; + + /// Get the available vring base offset. + fn get_vring_base(&self, queue_index: usize) -> Result<u32>; + + /// Set the eventfd to trigger when buffers have been used by the host. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd to trigger. + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the guest when buffers are + /// available for the host to process. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the backend when an error happens. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from the backend. + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()>; +} + +/// An interface for setting up vhost-based backend drivers. +/// +/// Vhost devices are a subset of virtio devices, which improve virtio device performance by +/// delegating data plane operations to dedicated IO service processes. Vhost devices use the +/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to +/// virtio devices. +/// +/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the +/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service +/// processes, and slow paths for device configuration are still handled by the VMM process. It may +/// also be used to control access permissions of virtio backend devices. +pub trait VhostBackendMut: std::marker::Sized { + /// Get a bitmask of supported virtio/vhost features. + fn get_features(&mut self) -> Result<u64>; + + /// Inform the vhost subsystem which features to enable. + /// This should be a subset of supported features from get_features(). + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_features(&mut self, features: u64) -> Result<()>; + + /// Set the current process as the owner of the vhost backend. + /// This must be run before any other vhost commands. + fn set_owner(&mut self) -> Result<()>; + + /// Used to be sent to request disabling all rings. + /// This is no longer used. + fn reset_owner(&mut self) -> Result<()>; + + /// Set the guest memory mappings for vhost to use.
+ fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>; + + /// Set base address for page modification logging. + fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()>; + + /// Specify an eventfd file descriptor to signal on log write. + fn set_log_fd(&mut self, fd: RawFd) -> Result<()>; + + /// Set the number of descriptors in the vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set descriptor count for. + /// * `num` - Number of descriptors in the queue. + fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()>; + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Configuration data for a vring. + fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()>; + + /// Set the first index to look for available descriptors. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `base` - Index where available descriptors start. + fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()>; + + /// Get the available vring base offset. + fn get_vring_base(&mut self, queue_index: usize) -> Result<u32>; + + /// Set the eventfd to trigger when buffers have been used by the host. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd to trigger. + fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the guest when buffers are + /// available for the host to process. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest. + fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>; + + /// Set the eventfd that will be signaled by the backend when an error happens.
+ /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from the backend. + fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>; +} + +impl<T: VhostBackendMut> VhostBackend for RwLock<T> { + fn get_features(&self) -> Result<u64> { + self.write().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.write().unwrap().set_features(features) + } + + fn set_owner(&self) -> Result<()> { + self.write().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.write().unwrap().reset_owner() + } + + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + self.write().unwrap().set_mem_table(regions) + } + + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + self.write().unwrap().set_log_base(base, fd) + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + self.write().unwrap().set_log_fd(fd) + } + + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + self.write().unwrap().set_vring_num(queue_index, num) + } + + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + self.write() + .unwrap() + .set_vring_addr(queue_index, config_data) + } + + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + self.write().unwrap().set_vring_base(queue_index, base) + } + + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + self.write().unwrap().get_vring_base(queue_index) + } + + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_call(queue_index, fd) + } + + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_kick(queue_index, fd) + } + + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.write().unwrap().set_vring_err(queue_index, fd) + } +} + +impl<T: VhostBackendMut>
VhostBackend for RefCell<T> { + fn get_features(&self) -> Result<u64> { + self.borrow_mut().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.borrow_mut().set_features(features) + } + + fn set_owner(&self) -> Result<()> { + self.borrow_mut().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.borrow_mut().reset_owner() + } + + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + self.borrow_mut().set_mem_table(regions) + } + + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + self.borrow_mut().set_log_base(base, fd) + } + + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + self.borrow_mut().set_log_fd(fd) + } + + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + self.borrow_mut().set_vring_num(queue_index, num) + } + + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + self.borrow_mut().set_vring_addr(queue_index, config_data) + } + + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + self.borrow_mut().set_vring_base(queue_index, base) + } + + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + self.borrow_mut().get_vring_base(queue_index) + } + + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_call(queue_index, fd) + } + + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_kick(queue_index, fd) + } + + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + self.borrow_mut().set_vring_err(queue_index, fd) + } +} +#[cfg(test)] +mod tests { + use super::*; + + struct MockBackend {} + + impl VhostBackendMut for MockBackend { + fn get_features(&mut self) -> Result<u64> { + Ok(0x1) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + assert_eq!(features, 0x1); + Ok(()) + } + + fn set_owner(&mut self) -> Result<()> { + Ok(()) + } + + 
fn reset_owner(&mut self) -> Result<()> { + Ok(()) + } + + fn set_mem_table(&mut self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + Ok(()) + } + + fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> { + assert_eq!(base, 0x100); + assert_eq!(fd, Some(100)); + Ok(()) + } + + fn set_log_fd(&mut self, fd: RawFd) -> Result<()> { + assert_eq!(fd, 100); + Ok(()) + } + + fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> { + assert_eq!(queue_index, 1); + assert_eq!(num, 256); + Ok(()) + } + + fn set_vring_addr( + &mut self, + queue_index: usize, + _config_data: &VringConfigData, + ) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> { + assert_eq!(queue_index, 1); + assert_eq!(base, 2); + Ok(()) + } + + fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> { + assert_eq!(queue_index, 1); + Ok(2) + } + + fn set_vring_call(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_kick(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + + fn set_vring_err(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> { + assert_eq!(queue_index, 1); + Ok(()) + } + } + + #[test] + fn test_vring_backend_mut() { + let b = RwLock::new(MockBackend {}); + + assert_eq!(b.get_features().unwrap(), 0x1); + b.set_features(0x1).unwrap(); + b.set_owner().unwrap(); + b.reset_owner().unwrap(); + b.set_mem_table(&[]).unwrap(); + b.set_log_base(0x100, Some(100)).unwrap(); + b.set_log_fd(100).unwrap(); + b.set_vring_num(1, 256).unwrap(); + + let config = VringConfigData { + queue_max_size: 0x1000, + queue_size: 0x2000, + flags: 0x0, + desc_table_addr: 0x4000, + used_ring_addr: 0x5000, + avail_ring_addr: 0x6000, + log_addr: None, + }; + b.set_vring_addr(1, &config).unwrap(); + + b.set_vring_base(1, 2).unwrap(); + 
assert_eq!(b.get_vring_base(1).unwrap(), 2); + + let eventfd = EventFd::new().unwrap(); + b.set_vring_call(1, &eventfd).unwrap(); + b.set_vring_kick(1, &eventfd).unwrap(); + b.set_vring_err(1, &eventfd).unwrap(); + } + + #[test] + fn test_vring_config_data() { + let mut config = VringConfigData { + queue_max_size: 0x1000, + queue_size: 0x2000, + flags: 0x0, + desc_table_addr: 0x4000, + used_ring_addr: 0x5000, + avail_ring_addr: 0x6000, + log_addr: None, + }; + + assert_eq!(config.is_log_addr_valid(), true); + assert_eq!(config.get_log_addr(), 0); + + config.flags = 0x1; + assert_eq!(config.is_log_addr_valid(), false); + assert_eq!(config.get_log_addr(), 0); + + config.log_addr = Some(0x7000); + assert_eq!(config.is_log_addr_valid(), true); + assert_eq!(config.get_log_addr(), 0x7000); + + config.flags = 0x0; + assert_eq!(config.is_log_addr_valid(), true); + assert_eq!(config.get_log_addr(), 0); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..a755f4f --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,162 @@ +// Copyright (C) 2019 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Virtio Vhost Backend Drivers +//! +//! Virtio devices use virtqueues to transport data efficiently. The first generation of virtqueue +//! is a set of three different single-producer, single-consumer ring structures designed to store +//! generic scatter-gather I/O. The virtio specification 1.1 introduces an alternative compact +//! virtqueue layout named "Packed Virtqueue", which is friendlier to the memory cache system and +//! hardware-implemented virtio devices. The packed virtqueue uses read-write memory, which means +//! the memory will be both read and written by both host and guest. The new Packed Virtqueue is +//! preferred for performance. +//! +//! Vhost is a mechanism to improve performance of Virtio devices by delegating data plane operations +//! to dedicated IO service processes.
Only the configuration, I/O submission notification, and I/O +//! completion interruption are piped through the hypervisor. +//! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to +//! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a +//! hypervisor process with an existing Virtio (PCI) driver. +//! +//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to +//! communicate with userspace applications. Dedicated kernel worker threads are created to handle +//! IO requests from the guest. +//! +//! Later Vhost-user protocol is introduced to complement the ioctl interface used to control the +//! vhost implementation in the Linux kernel. It implements the control plane needed to establish +//! virtqueues sharing with a user space process on the same host. It uses communication over a +//! Unix domain socket to share file descriptors in the ancillary data of the message. +//! The protocol defines 2 sides of the communication, master and slave. Master is the application +//! that shares its virtqueues. Slave is the consumer of the virtqueues. Master and slave can be +//! either a client (i.e. connecting) or server (listening) in the socket communication. + +#![deny(missing_docs)] + +#[cfg_attr(feature = "vhost-user", macro_use)] +extern crate bitflags; +#[cfg_attr(feature = "vhost-kern", macro_use)] +extern crate sys_util; + +mod backend; +pub use backend::*; + +#[cfg(feature = "vhost-kern")] +pub mod vhost_kern; +#[cfg(feature = "vhost-user")] +pub mod vhost_user; +#[cfg(feature = "vhost-vsock")] +pub mod vsock; + +/// Error codes for vhost operations +#[derive(Debug)] +pub enum Error { + /// Invalid operations. + InvalidOperation, + /// Invalid guest memory. + InvalidGuestMemory, + /// Invalid guest memory region. + InvalidGuestMemoryRegion, + /// Invalid queue. + InvalidQueue, + /// Invalid descriptor table address. 
+ DescriptorTableAddress, + /// Invalid used address. + UsedAddress, + /// Invalid available address. + AvailAddress, + /// Invalid log address. + LogAddress, + #[cfg(feature = "vhost-kern")] + /// Error opening the vhost backend driver. + VhostOpen(std::io::Error), + #[cfg(feature = "vhost-kern")] + /// Error while running ioctl. + IoctlError(std::io::Error), + /// Error from IO subsystem. + IOError(std::io::Error), + #[cfg(feature = "vhost-user")] + /// Error from the vhost-user subsystem. + VhostUserProtocol(vhost_user::Error), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::InvalidOperation => write!(f, "invalid vhost operations"), + Error::InvalidGuestMemory => write!(f, "invalid guest memory object"), + Error::InvalidGuestMemoryRegion => write!(f, "invalid guest memory region"), + Error::InvalidQueue => write!(f, "invalid virtque"), + Error::DescriptorTableAddress => write!(f, "invalid virtque descriptor talbe address"), + Error::UsedAddress => write!(f, "invalid virtque used talbe address"), + Error::AvailAddress => write!(f, "invalid virtque available table address"), + Error::LogAddress => write!(f, "invalid virtque log address"), + Error::IOError(e) => write!(f, "IO error: {}", e), + #[cfg(feature = "vhost-kern")] + Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e), + #[cfg(feature = "vhost-kern")] + Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e), + #[cfg(feature = "vhost-user")] + Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e), + } + } +} + +impl std::error::Error for Error {} + +#[cfg(feature = "vhost-user")] +impl std::convert::From<vhost_user::Error> for Error { + fn from(err: vhost_user::Error) -> Self { + Error::VhostUserProtocol(err) + } +} + +/// Result of vhost operations +pub type Result<T> = std::result::Result<T, Error>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error() { + 
assert_eq!( + format!("{}", Error::AvailAddress), + "invalid virtque available table address" + ); + assert_eq!( + format!("{}", Error::InvalidOperation), + "invalid vhost operations" + ); + assert_eq!( + format!("{}", Error::InvalidGuestMemory), + "invalid guest memory object" + ); + assert_eq!( + format!("{}", Error::InvalidGuestMemoryRegion), + "invalid guest memory region" + ); + assert_eq!(format!("{}", Error::InvalidQueue), "invalid virtque"); + assert_eq!( + format!("{}", Error::DescriptorTableAddress), + "invalid virtque descriptor talbe address" + ); + assert_eq!( + format!("{}", Error::UsedAddress), + "invalid virtque used talbe address" + ); + assert_eq!( + format!("{}", Error::LogAddress), + "invalid virtque log address" + ); + + assert_eq!(format!("{:?}", Error::AvailAddress), "AvailAddress"); + } + + #[cfg(feature = "vhost-user")] + #[test] + fn test_convert_from_vhost_user_error() { + let e: Error = vhost_user::Error::OversizedMsg.into(); + + assert_eq!(format!("{}", e), "vhost-user: oversized message"); + } +} diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs new file mode 100644 index 0000000..5daca51 --- /dev/null +++ b/src/vhost_kern/mod.rs @@ -0,0 +1,283 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Traits and structs to control Linux in-kernel vhost drivers. +//! +//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to +//! communicate with userspace applications. This sub module provides ioctl based interfaces to +//! control the in-kernel net, scsi, vsock vhost drivers. 
+ +use std::os::unix::io::{AsRawFd, RawFd}; + +use sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref}; +use sys_util::EventFd; +use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize}; + +use super::{ + Error, Result, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData, + VHOST_MAX_MEMORY_REGIONS, +}; + +pub mod vhost_binding; +use self::vhost_binding::*; + +#[cfg(feature = "vhost-vsock")] +pub mod vsock; + +#[inline] +fn ioctl_result<T>(rc: i32, res: T) -> Result<T> { + if rc < 0 { + Err(Error::IoctlError(std::io::Error::last_os_error())) + } else { + Ok(res) + } +} + +/// Represent an in-kernel vhost device backend. +pub trait VhostKernBackend: AsRawFd { + /// Associated type to access guest memory. + type AS: GuestAddressSpace; + + /// Get the object to access the guest's memory. + fn mem(&self) -> &Self::AS; + + /// Check whether the ring configuration is valid. + fn is_valid(&self, config_data: &VringConfigData) -> bool { + let queue_size = config_data.queue_size; + if queue_size > config_data.queue_max_size + || queue_size == 0 + || (queue_size & (queue_size - 1)) != 0 + { + return false; + } + + let m = self.mem().memory(); + let desc_table_size = 16 * u64::from(queue_size) as GuestUsize; + let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize; + let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize; + if GuestAddress(config_data.desc_table_addr) + .checked_add(desc_table_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + if GuestAddress(config_data.avail_ring_addr) + .checked_add(avail_ring_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + if GuestAddress(config_data.used_ring_addr) + .checked_add(used_ring_size) + .map_or(true, |v| !m.address_in_range(v)) + { + return false; + } + + config_data.is_log_addr_valid() + } +} + +impl<T: VhostKernBackend> VhostBackend for T { + /// Get a bitmask of supported virtio/vhost features. 
+ fn get_features(&self) -> Result<u64> { + let mut avail_features: u64 = 0; + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) }; + ioctl_result(ret, avail_features) + } + + /// Inform the vhost subsystem which features to enable. This should be a subset of + /// supported features from VHOST_GET_FEATURES. + /// + /// # Arguments + /// * `features` - Bitmask of features to set. + fn set_features(&self, features: u64) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) }; + ioctl_result(ret, ()) + } + + /// Set the current process as the owner of this file descriptor. + /// This must be run before any other vhost ioctls. + fn set_owner(&self) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) }; + ioctl_result(ret, ()) + } + + fn reset_owner(&self) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) }; + ioctl_result(ret, ()) + } + + /// Set the guest memory mappings for vhost to use. + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS { + return Err(Error::InvalidGuestMemory); + } + + let mut vhost_memory = VhostMemory::new(regions.len() as u16); + for (index, region) in regions.iter().enumerate() { + vhost_memory.set_region( + index as u32, + &vhost_memory_region { + guest_phys_addr: region.guest_phys_addr, + memory_size: region.memory_size, + userspace_addr: region.userspace_addr, + flags_padding: 0u64, + }, + )?; + } + + // This ioctl is called with a pointer that is valid for the lifetime + // of this function. 
The kernel will make its own copy of the memory + // tables. As always, check the return value. + let ret = unsafe { ioctl_with_ptr(self, VHOST_SET_MEM_TABLE(), vhost_memory.as_ptr()) }; + ioctl_result(ret, ()) + } + + /// Set base address for page modification logging. + /// + /// # Arguments + /// * `base` - Base address for page modification logging. + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + if fd.is_some() { + return Err(Error::LogAddress); + } + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_BASE(), &base) }; + ioctl_result(ret, ()) + } + + /// Specify an eventfd file descriptor to signal on log write. + fn set_log_fd(&self, fd: RawFd) -> Result<()> { + // This ioctl is called on a valid vhost fd and has its return value checked. + let val: i32 = fd; + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) }; + ioctl_result(ret, ()) + } + + /// Set the number of descriptors in the vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set descriptor count for. + /// * `num` - Number of descriptors in the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: u32::from(num), + }; + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_NUM(), &vring_state) }; + ioctl_result(ret, ()) + } + + /// Set the addresses for a given vring. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to set addresses for. + /// * `config_data` - Vring config data. 
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + if !self.is_valid(config_data) { + return Err(Error::InvalidQueue); + } + + let vring_addr = vhost_vring_addr { + index: queue_index as u32, + flags: config_data.flags, + desc_user_addr: config_data.desc_table_addr, + used_user_addr: config_data.used_ring_addr, + avail_user_addr: config_data.avail_ring_addr, + log_guest_addr: config_data.get_log_addr(), + }; + + // This ioctl is called on a valid vhost fd and has its + // return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) }; + ioctl_result(ret, ()) + } + + /// Set the first index to look for available descriptors. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `num` - Index where available descriptors start. + fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: u32::from(base), + }; + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_BASE(), &vring_state) }; + ioctl_result(ret, ()) + } + + /// Get a bitmask of supported virtio/vhost features. + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let vring_state = vhost_vring_state { + index: queue_index as u32, + num: 0, + }; + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_GET_VRING_BASE(), &vring_state) }; + ioctl_result(ret, vring_state.num) + } + + /// Set the eventfd to trigger when buffers have been used by the host. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd to trigger. 
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_CALL(), &vring_file) }; + ioctl_result(ret, ()) + } + + /// Set the eventfd that will be signaled by the guest when buffers are + /// available for the host to process. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from guest. + fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_KICK(), &vring_file) }; + ioctl_result(ret, ()) + } + + /// Set the eventfd to signal an error from the vhost backend. + /// + /// # Arguments + /// * `queue_index` - Index of the queue to modify. + /// * `fd` - EventFd that will be signaled from the backend. + fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let vring_file = vhost_vring_file { + index: queue_index as u32, + fd: fd.as_raw_fd(), + }; + + // This ioctl is called on a valid vhost fd and has its return value checked. + let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ERR(), &vring_file) }; + ioctl_result(ret, ()) + } +} diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs new file mode 100644 index 0000000..57ae698 --- /dev/null +++ b/src/vhost_kern/vhost_binding.rs @@ -0,0 +1,406 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +/* Auto-generated by bindgen then manually edited for simplicity */ + +#![allow(non_upper_case_globals)] +#![allow(non_camel_case_types)] +#![allow(non_snake_case)] +#![allow(missing_docs)] +#![allow(clippy::missing_safety_doc)] + +use crate::{Error, Result}; +use std::os::raw; + +pub const VHOST: raw::c_uint = 0xaf; +pub const VHOST_VRING_F_LOG: raw::c_uint = 0; +pub const VHOST_ACCESS_RO: raw::c_uint = 1; +pub const VHOST_ACCESS_WO: raw::c_uint = 2; +pub const VHOST_ACCESS_RW: raw::c_uint = 3; +pub const VHOST_IOTLB_MISS: raw::c_uint = 1; +pub const VHOST_IOTLB_UPDATE: raw::c_uint = 2; +pub const VHOST_IOTLB_INVALIDATE: raw::c_uint = 3; +pub const VHOST_IOTLB_ACCESS_FAIL: raw::c_uint = 4; +pub const VHOST_IOTLB_MSG: raw::c_uint = 1; +pub const VHOST_PAGE_SIZE: raw::c_uint = 4096; +pub const VHOST_VIRTIO: raw::c_uint = 175; +pub const VHOST_VRING_LITTLE_ENDIAN: raw::c_uint = 0; +pub const VHOST_VRING_BIG_ENDIAN: raw::c_uint = 1; +pub const VHOST_F_LOG_ALL: raw::c_uint = 26; +pub const VHOST_NET_F_VIRTIO_NET_HDR: raw::c_uint = 27; +pub const VHOST_SCSI_ABI_VERSION: raw::c_uint = 1; + +ioctl_ior_nr!(VHOST_GET_FEATURES, VHOST, 0x00, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_SET_FEATURES, VHOST, 0x00, raw::c_ulonglong); +ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01); +ioctl_io_nr!(VHOST_RESET_OWNER, VHOST, 0x02); +ioctl_iow_nr!(VHOST_SET_MEM_TABLE, VHOST, 0x03, vhost_memory); +ioctl_iow_nr!(VHOST_SET_LOG_BASE, VHOST, 0x04, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_SET_LOG_FD, VHOST, 0x07, raw::c_int); +ioctl_iow_nr!(VHOST_SET_VRING_NUM, VHOST, 0x10, vhost_vring_state); +ioctl_iow_nr!(VHOST_SET_VRING_ADDR, VHOST, 0x11, vhost_vring_addr); +ioctl_iow_nr!(VHOST_SET_VRING_BASE, VHOST, 0x12, vhost_vring_state); +ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, vhost_vring_state); 
+ioctl_iow_nr!(VHOST_SET_VRING_KICK, VHOST, 0x20, vhost_vring_file); +ioctl_iow_nr!(VHOST_SET_VRING_CALL, VHOST, 0x21, vhost_vring_file); +ioctl_iow_nr!(VHOST_SET_VRING_ERR, VHOST, 0x22, vhost_vring_file); +ioctl_iow_nr!(VHOST_NET_SET_BACKEND, VHOST, 0x30, vhost_vring_file); +ioctl_iow_nr!(VHOST_SCSI_SET_ENDPOINT, VHOST, 0x40, vhost_scsi_target); +ioctl_iow_nr!(VHOST_SCSI_CLEAR_ENDPOINT, VHOST, 0x41, vhost_scsi_target); +ioctl_iow_nr!(VHOST_SCSI_GET_ABI_VERSION, VHOST, 0x42, raw::c_int); +ioctl_iow_nr!(VHOST_SCSI_SET_EVENTS_MISSED, VHOST, 0x43, raw::c_uint); +ioctl_iow_nr!(VHOST_SCSI_GET_EVENTS_MISSED, VHOST, 0x44, raw::c_uint); +ioctl_iow_nr!(VHOST_VSOCK_SET_GUEST_CID, VHOST, 0x60, raw::c_ulonglong); +ioctl_iow_nr!(VHOST_VSOCK_SET_RUNNING, VHOST, 0x61, raw::c_int); + +#[repr(C)] +#[derive(Default)] +pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>); + +impl<T> __IncompleteArrayField<T> { + #[inline] + pub fn new() -> Self { + __IncompleteArrayField(::std::marker::PhantomData) + } + + #[inline] + #[allow(clippy::trivially_copy_pass_by_ref)] + #[allow(clippy::useless_transmute)] + pub unsafe fn as_ptr(&self) -> *const T { + ::std::mem::transmute(self) + } + + #[inline] + #[allow(clippy::useless_transmute)] + pub unsafe fn as_mut_ptr(&mut self) -> *mut T { + ::std::mem::transmute(self) + } + + #[inline] + pub unsafe fn as_slice(&self, len: usize) -> &[T] { + ::std::slice::from_raw_parts(self.as_ptr(), len) + } + + #[inline] + pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] { + ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len) + } +} + +impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> { + fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + fmt.write_str("__IncompleteArrayField") + } +} + +impl<T> ::std::clone::Clone for __IncompleteArrayField<T> { + #[inline] + fn clone(&self) -> Self { + Self::new() + } +} + +impl<T> ::std::marker::Copy for __IncompleteArrayField<T> {} + +#[repr(C)] +#[derive(Debug, 
Default, Copy, Clone)] +pub struct vhost_vring_state { + pub index: raw::c_uint, + pub num: raw::c_uint, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_vring_file { + pub index: raw::c_uint, + pub fd: raw::c_int, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_vring_addr { + pub index: raw::c_uint, + pub flags: raw::c_uint, + pub desc_user_addr: raw::c_ulonglong, + pub used_user_addr: raw::c_ulonglong, + pub avail_user_addr: raw::c_ulonglong, + pub log_guest_addr: raw::c_ulonglong, +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_iotlb_msg { + pub iova: raw::c_ulonglong, + pub size: raw::c_ulonglong, + pub uaddr: raw::c_ulonglong, + pub perm: raw::c_uchar, + pub type_: raw::c_uchar, +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct vhost_msg { + pub type_: raw::c_int, + pub __bindgen_anon_1: vhost_msg__bindgen_ty_1, +} + +impl Default for vhost_msg { + fn default() -> Self { + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub union vhost_msg__bindgen_ty_1 { + pub iotlb: vhost_iotlb_msg, + pub padding: [raw::c_uchar; 64usize], + _bindgen_union_align: [u64; 8usize], +} + +impl Default for vhost_msg__bindgen_ty_1 { + fn default() -> Self { + unsafe { ::std::mem::zeroed() } + } +} + +#[repr(C)] +#[derive(Debug, Default, Copy, Clone)] +pub struct vhost_memory_region { + pub guest_phys_addr: raw::c_ulonglong, + pub memory_size: raw::c_ulonglong, + pub userspace_addr: raw::c_ulonglong, + pub flags_padding: raw::c_ulonglong, +} + +#[repr(C)] +#[derive(Debug, Default, Clone)] +pub struct vhost_memory { + pub nregions: raw::c_uint, + pub padding: raw::c_uint, + pub regions: __IncompleteArrayField<vhost_memory_region>, + __force_alignment: [u64; 0], +} + +#[repr(C)] +#[derive(Copy, Clone)] +pub struct vhost_scsi_target { + pub abi_version: raw::c_int, + pub vhost_wwpn: [raw::c_char; 224usize], + pub vhost_tpgt: raw::c_ushort, + pub reserved: raw::c_ushort, +} + 
+impl Default for vhost_scsi_target { + fn default() -> Self { + unsafe { ::std::mem::zeroed() } + } +} + +/// Helper to support vhost::set_mem_table() +pub struct VhostMemory { + buf: Vec<vhost_memory>, +} + +impl VhostMemory { + // Limit number of regions to u16 to simplify error handling + pub fn new(entries: u16) -> Self { + let size = std::mem::size_of::<vhost_memory_region>() * entries as usize; + let count = (size + 2 * std::mem::size_of::<vhost_memory>() - 1) + / std::mem::size_of::<vhost_memory>(); + let mut buf: Vec<vhost_memory> = vec![Default::default(); count]; + buf[0].nregions = u32::from(entries); + VhostMemory { buf } + } + + pub fn as_ptr(&self) -> *const char { + &self.buf[0] as *const vhost_memory as *const char + } + + pub fn get_header(&self) -> &vhost_memory { + &self.buf[0] + } + + pub fn get_region(&self, index: u32) -> Option<&vhost_memory_region> { + if index >= self.buf[0].nregions { + return None; + } + // Safe because we have allocated enough space nregions + let regions = unsafe { self.buf[0].regions.as_slice(self.buf[0].nregions as usize) }; + Some(®ions[index as usize]) + } + + pub fn set_region(&mut self, index: u32, region: &vhost_memory_region) -> Result<()> { + if index >= self.buf[0].nregions { + return Err(Error::InvalidGuestMemory); + } + // Safe because we have allocated enough space nregions and checked the index. 
+ let regions = unsafe { self.buf[0].regions.as_mut_slice(index as usize + 1) }; + regions[index as usize] = *region; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn bindgen_test_layout_vhost_vring_state() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_state>(), + 8usize, + concat!("Size of: ", stringify!(vhost_vring_state)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_state>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_vring_state)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vring_file() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_file>(), + 8usize, + concat!("Size of: ", stringify!(vhost_vring_file)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_file>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_vring_file)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_vring_addr() { + assert_eq!( + ::std::mem::size_of::<vhost_vring_addr>(), + 40usize, + concat!("Size of: ", stringify!(vhost_vring_addr)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_vring_addr>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_vring_addr)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_msg__bindgen_ty_1() { + assert_eq!( + ::std::mem::size_of::<vhost_msg__bindgen_ty_1>(), + 64usize, + concat!("Size of: ", stringify!(vhost_msg__bindgen_ty_1)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg__bindgen_ty_1>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg__bindgen_ty_1)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_msg() { + assert_eq!( + ::std::mem::size_of::<vhost_msg>(), + 72usize, + concat!("Size of: ", stringify!(vhost_msg)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_msg>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_msg)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_memory_region() { + assert_eq!( + ::std::mem::size_of::<vhost_memory_region>(), + 32usize, + concat!("Size of: ", stringify!(vhost_memory_region)) + ); + 
assert_eq!( + ::std::mem::align_of::<vhost_memory_region>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_memory_region)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_memory() { + assert_eq!( + ::std::mem::size_of::<vhost_memory>(), + 8usize, + concat!("Size of: ", stringify!(vhost_memory)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_memory>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_memory)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_iotlb_msg() { + assert_eq!( + ::std::mem::size_of::<vhost_iotlb_msg>(), + 32usize, + concat!("Size of: ", stringify!(vhost_iotlb_msg)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_iotlb_msg>(), + 8usize, + concat!("Alignment of ", stringify!(vhost_iotlb_msg)) + ); + } + + #[test] + fn bindgen_test_layout_vhost_scsi_target() { + assert_eq!( + ::std::mem::size_of::<vhost_scsi_target>(), + 232usize, + concat!("Size of: ", stringify!(vhost_scsi_target)) + ); + assert_eq!( + ::std::mem::align_of::<vhost_scsi_target>(), + 4usize, + concat!("Alignment of ", stringify!(vhost_scsi_target)) + ); + } + + #[test] + fn test_vhostmemory() { + let mut obj = VhostMemory::new(2); + let region = vhost_memory_region { + guest_phys_addr: 0x1000u64, + memory_size: 0x2000u64, + userspace_addr: 0x300000u64, + flags_padding: 0u64, + }; + assert!(obj.get_region(2).is_none()); + + { + let header = obj.get_header(); + assert_eq!(header.nregions, 2u32); + } + { + assert!(obj.set_region(0, ®ion).is_ok()); + assert!(obj.set_region(1, ®ion).is_ok()); + assert!(obj.set_region(2, ®ion).is_err()); + } + + let region1 = obj.get_region(1).unwrap(); + assert_eq!(region1.guest_phys_addr, 0x1000u64); + assert_eq!(region1.memory_size, 0x2000u64); + assert_eq!(region1.userspace_addr, 0x300000u64); + } +} diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs new file mode 100644 index 0000000..388d500 --- /dev/null +++ b/src/vhost_kern/vsock.rs @@ -0,0 +1,184 @@ +// Copyright (C) 2019 Alibaba Cloud. 
All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Kernel-based vhost-vsock backend. + +use std::fs::{File, OpenOptions}; +use std::os::unix::fs::OpenOptionsExt; +use std::os::unix::io::{AsRawFd, RawFd}; + +use sys_util::ioctl_with_ref; +use vm_memory::GuestAddressSpace; + +use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; +use super::{ioctl_result, Error, Result, VhostKernBackend}; +use crate::vsock::VhostVsock; + +const VHOST_PATH: &str = "/dev/vhost-vsock"; + +/// Handle for running VHOST_VSOCK ioctls. +pub struct Vsock<AS: GuestAddressSpace> { + fd: File, + mem: AS, +} + +impl<AS: GuestAddressSpace> Vsock<AS> { + /// Open a handle to a new VHOST-VSOCK instance. + pub fn new(mem: AS) -> Result<Self> { + Ok(Vsock { + fd: OpenOptions::new() + .read(true) + .write(true) + .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) + .open(VHOST_PATH) + .map_err(Error::VhostOpen)?, + mem, + }) + } + + fn set_running(&self, running: bool) -> Result<()> { + let on: ::std::os::raw::c_int = if running { 1 } else { 0 }; + let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) }; + ioctl_result(ret, ()) + } +} + +impl<AS: GuestAddressSpace> VhostVsock for Vsock<AS> { + fn set_guest_cid(&self, cid: u64) -> Result<()> { + let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) }; + ioctl_result(ret, ()) + } + + fn start(&self) -> Result<()> { + self.set_running(true) + } + + fn stop(&self) -> Result<()> { + self.set_running(false) + } +} + +impl<AS: GuestAddressSpace> VhostKernBackend for Vsock<AS> { + type AS = AS; + + fn mem(&self) -> &Self::AS { + &self.mem + } +} + +impl<AS: GuestAddressSpace> AsRawFd for Vsock<AS> { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + 
+#[cfg(test)] +mod tests { + use sys_util::EventFd; + use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap}; + + use super::*; + use crate::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; + + // Ignore all tests because /dev/vhost-vsock is unavailable in Chrome OS chroot. + #[test] + #[ignore] + fn test_vsock_new_device() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + assert!(vsock.as_raw_fd() >= 0); + assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some()); + assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none()); + } + + #[test] + #[ignore] + fn test_vsock_is_valid() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let mut config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + assert_eq!(vsock.is_valid(&config), true); + + config.queue_size = 0; + assert_eq!(vsock.is_valid(&config), false); + config.queue_size = 31; + assert_eq!(vsock.is_valid(&config), false); + config.queue_size = 33; + assert_eq!(vsock.is_valid(&config), false); + } + + #[test] + #[ignore] + fn test_vsock_ioctls() { + let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap(); + let vsock = Vsock::new(&m).unwrap(); + + let features = vsock.get_features().unwrap(); + vsock.set_features(features).unwrap(); + + vsock.set_owner().unwrap(); + + vsock.set_mem_table(&[]).unwrap_err(); + + /* + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: -1, + }; + vsock.set_mem_table(&[region]).unwrap_err(); + */ + + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x0, + memory_size: 0x10_0000, + userspace_addr: m.get_host_address(GuestAddress(0x0)).unwrap() as u64, 
+ mmap_offset: 0, + mmap_handle: -1, + }; + vsock.set_mem_table(&[region]).unwrap(); + + vsock.set_log_base(0x4000, Some(1)).unwrap_err(); + vsock.set_log_base(0x4000, None).unwrap(); + + let eventfd = EventFd::new().unwrap(); + vsock.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + vsock.set_vring_num(0, 32).unwrap(); + + let config = VringConfigData { + queue_max_size: 32, + queue_size: 32, + flags: 0, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: None, + }; + vsock.set_vring_addr(0, &config).unwrap(); + vsock.set_vring_base(0, 1).unwrap(); + vsock.set_vring_call(0, &eventfd).unwrap(); + vsock.set_vring_kick(0, &eventfd).unwrap(); + vsock.set_vring_err(0, &eventfd).unwrap(); + assert_eq!(vsock.get_vring_base(0).unwrap(), 1); + vsock.set_guest_cid(0xdead).unwrap(); + //vsock.start().unwrap(); + //vsock.stop().unwrap(); + } +} diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs new file mode 100644 index 0000000..f92db45 --- /dev/null +++ b/src/vhost_user/connection.rs @@ -0,0 +1,858 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Structs for Unix Domain Socket listener and endpoint. + +#![allow(dead_code)] + +use std::io::ErrorKind; +use std::marker::PhantomData; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::{UnixListener, UnixStream}; +use std::path::{Path, PathBuf}; +use std::{mem, slice}; + +use libc::{c_void, iovec}; +use sys_util::ScmSocket; + +use super::message::*; +use super::{Error, Result}; + +/// Unix domain socket listener for accepting incoming connections. +pub struct Listener { + fd: UnixListener, + path: PathBuf, +} + +impl Listener { + /// Create a unix domain socket listener. + /// + /// # Return: + /// * - the new Listener object on success. + /// * - SocketError: failed to create listener socket. 
+ pub fn new<P: AsRef<Path>>(path: P, unlink: bool) -> Result<Self> { + if unlink { + let _ = std::fs::remove_file(&path); + } + let fd = UnixListener::bind(&path).map_err(Error::SocketError)?; + Ok(Listener { + fd, + path: path.as_ref().to_owned(), + }) + } + + /// Accept an incoming connection. + /// + /// # Return: + /// * - Some(UnixStream): new UnixStream object if new incoming connection is available. + /// * - None: no incoming connection available. + /// * - SocketError: errors from accept(). + pub fn accept(&self) -> Result<Option<UnixStream>> { + loop { + match self.fd.accept() { + Ok((socket, _addr)) => return Ok(Some(socket)), + Err(e) => { + match e.kind() { + // No incoming connection available. + ErrorKind::WouldBlock => return Ok(None), + // New connection closed by peer. + ErrorKind::ConnectionAborted => return Ok(None), + // Interrupted by signals, retry + ErrorKind::Interrupted => continue, + _ => return Err(Error::SocketError(e)), + } + } + } + } + } + + /// Change blocking status on the listener. + /// + /// # Return: + /// * - () on success. + /// * - SocketError: failure from set_nonblocking(). + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.fd.set_nonblocking(block).map_err(Error::SocketError) + } +} + +impl AsRawFd for Listener { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } +} + +impl Drop for Listener { + fn drop(&mut self) { + let _ = std::fs::remove_file(&self.path); + } +} + +/// Unix domain socket endpoint for vhost-user connection. +pub(super) struct Endpoint<R: Req> { + sock: UnixStream, + _r: PhantomData<R>, +} + +impl<R: Req> Endpoint<R> { + /// Create a new stream by connecting to server at `str`. + /// + /// # Return: + /// * - the new Endpoint object on success. + /// * - SocketConnect: failed to connect to peer. 
+ pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> { + let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?; + Ok(Self::from_stream(sock)) + } + + /// Create an endpoint from a stream object. + pub fn from_stream(sock: UnixStream) -> Self { + Endpoint { + sock, + _r: PhantomData, + } + } + + /// Sends bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let rfds = match fds { + Some(rfds) => rfds, + _ => &[], + }; + self.sock.send_bufs_with_fds(iovs, rfds).map_err(Into::into) + } + + /// Sends all bytes from scatter-gather vectors over the socket with optional attached file + /// descriptors. Will loop until all data has been transfered. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. 
+ pub fn send_iovec_all(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> { + let mut data_sent = 0; + let mut data_total = 0; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.len()).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_sent) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_sent); + let iov = &iovs[nr_skip][offset..]; + + let data = &[&[iov], &iovs[(nr_skip + 1)..]].concat(); + let sfds = if data_sent == 0 { fds } else { None }; + + let sent = self.send_iovec(data, sfds); + match sent { + Ok(0) => return Ok(data_sent), + Ok(n) => data_sent += n, + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok(data_sent) + } + + /// Sends bytes from a slice over the socket with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> { + self.send_iovec(&[data], fds) + } + + /// Sends a header-only message with optional attached file descriptors. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_header( + &mut self, + hdr: &VhostUserMsgHeader<R>, + fds: Option<&[RawFd]>, + ) -> Result<()> { + // Safe because there can't be other mutable referance to hdr. 
+ let iovs = unsafe { + [slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + )] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header and body. Optional file descriptors may be attached to + /// the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + pub fn send_message<T: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + fds: Option<&[RawFd]>, + ) -> Result<()> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + // Safe because there can't be other mutable referance to hdr and body. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + ] + }; + let bytes = self.send_iovec_all(&iovs[..], fds)?; + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Send a message with header, body and payload. Optional file descriptors + /// may also be attached to the message. + /// + /// # Return: + /// * - number of bytes sent on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - OversizedMsg: message size is too big. + /// * - PartialMessage: received a partial message. + /// * - IncorrectFds: wrong number of attached fds. 
+ pub fn send_message_with_payload<T: Sized>( + &mut self, + hdr: &VhostUserMsgHeader<R>, + body: &T, + payload: &[u8], + fds: Option<&[RawFd]>, + ) -> Result<()> { + let len = payload.len(); + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::OversizedMsg); + } + if len > MAX_MSG_SIZE - mem::size_of::<T>() { + return Err(Error::OversizedMsg); + } + if let Some(fd_arr) = fds { + if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES { + return Err(Error::IncorrectFds); + } + } + + // Safe because there can't be other mutable reference to hdr, body and payload. + let iovs = unsafe { + [ + slice::from_raw_parts( + hdr as *const VhostUserMsgHeader<R> as *const u8, + mem::size_of::<VhostUserMsgHeader<R>>(), + ), + slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()), + slice::from_raw_parts(payload.as_ptr() as *const u8, len), + ] + }; + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len; + let len = self.send_iovec_all(&iovs, fds)?; + if len != total { + return Err(Error::PartialMessage); + } + Ok(()) + } + + /// Reads bytes from the socket into the given scatter/gather vectors. + /// + /// # Return: + /// * - (number of bytes received, buf) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> { + let mut rbuf = vec![0u8; len]; + let (bytes, _) = self.sock.recv_with_fds(&mut rbuf[..], &mut [])?; + Ok((bytes, rbuf)) + } + + /// Reads bytes from the socket into the given scatter/gather vectors with optional attached + /// file descriptors. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. 
To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> { + let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES]; + let (bytes, fds) = self.sock.recv_iovecs_with_fds(iovs, &mut fd_array)?; + let rfds = match fds { + 0 => None, + n => { + let mut fds = Vec::with_capacity(n); + fds.extend_from_slice(&fd_array[0..n]); + Some(fds) + } + }; + + Ok((bytes, rfds)) + } + + /// Reads all bytes from the socket into the given scatter/gather vectors with optional + /// attached file descriptors. Will loop until all data has been transfered. + /// + /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little + /// tricky to pass file descriptors through such a communication channel. Let's assume that a + /// sender sending a message with some file descriptors attached. To successfully receive those + /// attached file descriptors, the receiver must obey following rules: + /// 1) file descriptors are attached to a message. + /// 2) message(packet) boundaries must be respected on the receive side. + /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the + /// attached file descriptors will get lost. + /// + /// # Return: + /// * - (number of bytes received, [received fds]) on success + /// * - SocketBroken: the underline socket is broken. 
+ /// * - SocketError: other socket related errors. + pub fn recv_into_iovec_all( + &mut self, + iovs: &mut [iovec], + ) -> Result<(usize, Option<Vec<RawFd>>)> { + let mut data_read = 0; + let mut data_total = 0; + let mut rfds = None; + let iov_lens: Vec<usize> = iovs.iter().map(|iov| iov.iov_len).collect(); + for len in &iov_lens { + data_total += len; + } + + while (data_total - data_read) > 0 { + let (nr_skip, offset) = get_sub_iovs_offset(&iov_lens, data_read); + let iov = &mut iovs[nr_skip]; + + let mut data = [ + &[iovec { + iov_base: (iov.iov_base as usize + offset) as *mut c_void, + iov_len: iov.iov_len - offset, + }], + &iovs[(nr_skip + 1)..], + ] + .concat(); + + let res = self.recv_into_iovec(&mut data); + match res { + Ok((0, _)) => return Ok((data_read, rfds)), + Ok((n, fds)) => { + if data_read == 0 { + rfds = fds; + } + data_read += n; + } + Err(e) => match e { + Error::SocketRetry(_) => {} + _ => return Err(e), + }, + } + } + Ok((data_read, rfds)) + } + + /// Reads bytes from the socket into a new buffer with optional attached + /// file descriptors. Received file descriptors are set close-on-exec. + /// + /// # Return: + /// * - (number of bytes received, buf, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + pub fn recv_into_buf( + &mut self, + buf_size: usize, + ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> { + let mut buf = vec![0u8; buf_size]; + let (bytes, rfds) = { + let mut iovs = [iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf_size, + }]; + self.recv_into_iovec(&mut iovs)? + }; + Ok((bytes, buf, rfds)) + } + + /// Receive a header-only message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. 
+ /// + /// # Return: + /// * - (message header, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. + pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + if bytes != mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, rfds)) + } + + /// Receive a message with optional attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes != total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, rfds)) + } + + /// Receive a message with header and optional content. Callers need to + /// pre-allocate a big enough buffer to receive the message body and + /// optional payload. If there are attached file descriptor associated + /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors + /// will be accepted and all other file descriptor will be discard + /// silently. + /// + /// # Return: + /// * - (message header, message size, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ pub fn recv_body_into_buf( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + if bytes < mem::size_of::<VhostUserMsgHeader<R>>() { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds)) + } + + /// Receive a message with optional payload and attached file descriptors. + /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be + /// accepted and all other file descriptor will be discard silently. + /// + /// # Return: + /// * - (message header, message body, size of payload, [received fds]) on success. + /// * - SocketRetry: temporary error caused by signals or short of resources. + /// * - SocketBroken: the underline socket is broken. + /// * - SocketError: other socket related errors. + /// * - PartialMessage: received a partial message. + /// * - InvalidMessage: received a invalid message. 
+ #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))] + pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>( + &mut self, + buf: &mut [u8], + ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> { + let mut hdr = VhostUserMsgHeader::default(); + let mut body: T = Default::default(); + let mut iovs = [ + iovec { + iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void, + iov_len: mem::size_of::<VhostUserMsgHeader<R>>(), + }, + iovec { + iov_base: (&mut body as *mut T) as *mut c_void, + iov_len: mem::size_of::<T>(), + }, + iovec { + iov_base: buf.as_mut_ptr() as *mut c_void, + iov_len: buf.len(), + }, + ]; + let (bytes, rfds) = self.recv_into_iovec_all(&mut iovs[..])?; + + let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>(); + if bytes < total { + return Err(Error::PartialMessage); + } else if !hdr.is_valid() || !body.is_valid() { + return Err(Error::InvalidMessage); + } + + Ok((hdr, body, bytes - total, rfds)) + } + + /// Close all raw file descriptors. + pub fn close_rfds(rfds: Option<Vec<RawFd>>) { + if let Some(fds) = rfds { + for fd in fds { + // safe because the rawfds are valid and we don't care about the result. + let _ = unsafe { libc::close(fd) }; + } + } + } +} + +impl<T: Req> AsRawFd for Endpoint<T> { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +// Given a slice of sizes and the `skip_size`, return the offset of `skip_size` in the slice. 
+// For example: +// let iov_lens = vec![4, 4, 5]; +// let size = 6; +// assert_eq!(get_sub_iovs_offset(&iov_len, size), (1, 2)); +fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) { + let mut size = skip_size; + let mut nr_skip = 0; + + for len in iov_lens { + if size >= *len { + size -= *len; + nr_skip += 1; + } else { + break; + } + } + (nr_skip, size) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::File; + use std::io::{Read, Seek, SeekFrom, Write}; + use std::os::unix::io::FromRawFd; + use tempfile::{tempfile, Builder, TempDir}; + + fn temp_dir() -> TempDir { + Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + #[test] + fn create_listener() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + + assert!(listener.as_raw_fd() > 0); + } + + #[test] + fn accept_connection() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + + // accept on a fd without incoming connection + let conn = listener.accept().unwrap(); + assert!(conn.is_none()); + } + + #[test] + fn send_data() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let mut len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..bytes]); + + len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let (bytes, buf2, _) = 
slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + } + + #[test] + fn send_fd() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut fd = tempfile().unwrap(); + write!(fd, "test").unwrap(); + + // Normal case for sending/receiving file descriptors + let buf1 = vec![0x1, 0x2, 0x3, 0x4]; + let len = master + .send_slice(&buf1[..], Some(&[fd.as_raw_fd()])) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 1); + let mut file = unsafe { File::from_raw_fd(fds[0]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + + // Following communication pattern should work: + // Sending side: data(header, body) with fds + // Receiving side: data(header) with fds, data(body) + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + 
assert_eq!(content, "test"); + } + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data(header, body) with fds + // Receiving side: data(header), data(body) with fds + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf4) = slave.recv_data(2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf4[..]); + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should work: + // Sending side: data, data with fds + // Receiving side: data, data with fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert_eq!(&buf1[..], &buf2[..]); + assert!(rfds.is_none()); + + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[..2], &buf2[..]); + assert!(rfds.is_some()); + let fds = rfds.unwrap(); + { + assert_eq!(fds.len(), 3); + let mut file = unsafe { File::from_raw_fd(fds[1]) }; + let mut content = String::new(); + file.seek(SeekFrom::Start(0)).unwrap(); + file.read_to_string(&mut content).unwrap(); + assert_eq!(content, "test"); + } + let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap(); + assert_eq!(bytes, 2); + assert_eq!(&buf1[2..], &buf2[..]); + assert!(rfds.is_none()); + + // Following communication pattern should not work: + // Sending side: data1, data2 with fds + // Receiving side: data + partial of data2, left of data2 with 
fds + let len = master.send_slice(&buf1[..], None).unwrap(); + assert_eq!(len, 4); + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _) = slave.recv_data(5).unwrap(); + assert_eq!(bytes, 5); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 3); + assert!(rfds.is_none()); + + // If the target fd array is too small, extra file descriptors will get lost. + let len = master + .send_slice( + &buf1[..], + Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]), + ) + .unwrap(); + assert_eq!(len, 4); + + let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap(); + assert_eq!(bytes, 4); + assert!(rfds.is_some()); + + Endpoint::<MasterReq>::close_rfds(rfds); + Endpoint::<MasterReq>::close_rfds(None); + } + + #[test] + fn send_recv() { + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let listener = Listener::new(&path, true).unwrap(); + listener.set_nonblocking(true).unwrap(); + let mut master = Endpoint::<MasterReq>::connect(&path).unwrap(); + let sock = listener.accept().unwrap().unwrap(); + let mut slave = Endpoint::<MasterReq>::from_stream(sock); + + let mut hdr1 = + VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32); + hdr1.set_need_reply(true); + let features1 = 0x1u64; + master.send_message(&hdr1, &features1, None).unwrap(); + + let mut features2 = 0u64; + let slice = unsafe { + slice::from_raw_parts_mut( + (&mut features2 as *mut u64) as *mut u8, + mem::size_of::<u64>(), + ) + }; + let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap(); + assert_eq!(hdr1, hdr2); + assert_eq!(bytes, 8); + assert_eq!(features1, features2); + assert!(rfds.is_none()); + + master.send_header(&hdr1, None).unwrap(); + let (hdr2, rfds) = slave.recv_header().unwrap(); + assert_eq!(hdr1, hdr2); + assert!(rfds.is_none()); + } +} diff --git a/src/vhost_user/dummy_slave.rs 
b/src/vhost_user/dummy_slave.rs new file mode 100644 index 0000000..b2b83d2 --- /dev/null +++ b/src/vhost_user/dummy_slave.rs @@ -0,0 +1,259 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::os::unix::io::RawFd; + +use super::message::*; +use super::*; + +pub const MAX_QUEUE_NUM: usize = 2; +pub const MAX_VRING_NUM: usize = 256; +pub const MAX_MEM_SLOTS: usize = 32; +pub const VIRTIO_FEATURES: u64 = 0x40000003; + +#[derive(Default)] +pub struct DummySlaveReqHandler { + pub owned: bool, + pub features_acked: bool, + pub acked_features: u64, + pub acked_protocol_features: u64, + pub queue_num: usize, + pub vring_num: [u32; MAX_QUEUE_NUM], + pub vring_base: [u32; MAX_QUEUE_NUM], + pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM], + pub vring_started: [bool; MAX_QUEUE_NUM], + pub vring_enabled: [bool; MAX_QUEUE_NUM], +} + +impl DummySlaveReqHandler { + pub fn new() -> Self { + DummySlaveReqHandler { + queue_num: MAX_QUEUE_NUM, + ..Default::default() + } + } +} + +impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { + fn set_owner(&mut self) -> Result<()> { + if self.owned { + return Err(Error::InvalidOperation); + } + self.owned = true; + Ok(()) + } + + fn reset_owner(&mut self) -> Result<()> { + self.owned = false; + self.features_acked = false; + self.acked_features = 0; + self.acked_protocol_features = 0; + Ok(()) + } + + fn get_features(&mut self) -> Result<u64> { + Ok(VIRTIO_FEATURES) + } + + fn set_features(&mut self, features: u64) -> Result<()> { + if !self.owned || self.features_acked { + return Err(Error::InvalidOperation); + } else if (features & !VIRTIO_FEATURES) != 0 { + return Err(Error::InvalidParam); + } + + self.acked_features = features; + self.features_acked = true; + + // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, + // the ring is initialized in an enabled state. 
+ // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, + // the ring is initialized in a disabled state. Client must not + // pass data to/from the backend until ring is enabled by + // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has + // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + let vring_enabled = + self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0; + for enabled in &mut self.vring_enabled { + *enabled = vring_enabled; + } + + Ok(()) + } + + fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> { + Ok(()) + } + + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> { + if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_num[index as usize] = num; + Ok(()) + } + + fn set_vring_addr( + &mut self, + index: u32, + _flags: VhostUserVringAddrFlags, + _descriptor: u64, + _used: u64, + _available: u64, + _log: u64, + ) -> Result<()> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> { + if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM { + return Err(Error::InvalidParam); + } + self.vring_base[index as usize] = base; + Ok(()) + } + + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> { + if index as usize >= self.queue_num { + return Err(Error::InvalidParam); + } + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. 
+ self.vring_started[index as usize] = false; + Ok(VhostUserVringState::new( + index, + self.vring_base[index as usize], + )) + } + + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.kick_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) }; + } + self.kick_fd[index as usize] = fd; + + // Quotation from vhost-user spec: + // Client must start ring upon receiving a kick (that is, detecting + // that file descriptor is readable) on the descriptor specified by + // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving + // VHOST_USER_GET_VRING_BASE. + // + // So we should add fd to event monitor(select, poll, epoll) here. + self.vring_started[index as usize] = true; + Ok(()) + } + + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.call_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. + let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) }; + } + self.call_fd[index as usize] = fd; + Ok(()) + } + + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> { + if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + if self.err_fd[index as usize].is_some() { + // Close file descriptor set by previous operations. 
+ let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) }; + } + self.err_fd[index as usize] = fd; + Ok(()) + } + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> { + Ok(VhostUserProtocolFeatures::all()) + } + + fn set_protocol_features(&mut self, features: u64) -> Result<()> { + // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must + // support this message even before VHOST_USER_SET_FEATURES was + // called. + // What happens if the master calls set_features() with + // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this + // interface? + self.acked_protocol_features = features; + Ok(()) + } + + fn get_queue_num(&mut self) -> Result<u64> { + Ok(MAX_QUEUE_NUM as u64) + } + + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { + // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES + // has been negotiated. + if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { + return Err(Error::InvalidOperation); + } else if index as usize >= self.queue_num || index as usize > self.queue_num { + return Err(Error::InvalidParam); + } + + // Slave must not pass data to/from the backend until ring is + // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1, + // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE + // with parameter 0. 
+ self.vring_enabled[index as usize] = enable; + Ok(()) + } + + fn get_config( + &mut self, + offset: u32, + size: u32, + _flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>> { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + Ok(vec![0xa5; size as usize]) + } + + fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { + let size = buf.len() as u32; + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET + || size + offset > VHOST_USER_CONFIG_SIZE + { + return Err(Error::InvalidParam); + } + Ok(()) + } + + fn get_max_mem_slots(&mut self) -> Result<u64> { + Ok(MAX_MEM_SLOTS as u64) + } + + fn add_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion, _fd: RawFd) -> Result<()> { + Ok(()) + } + + fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> Result<()> { + Ok(()) + } +} diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs new file mode 100644 index 0000000..cc79871 --- /dev/null +++ b/src/vhost_user/master.rs @@ -0,0 +1,1071 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Struct for vhost-user master. 
+ +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::path::Path; +use std::sync::{Arc, Mutex, MutexGuard}; + +use sys_util::EventFd; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error as VhostUserError, Result as VhostUserResult}; +use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData}; +use crate::{Error, Result}; + +/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet. +pub trait VhostUserMaster: VhostBackend { + /// Get the protocol feature bitmask from the underlying vhost implementation. + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + + /// Enable protocol features in the underlying vhost implementation. + fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>; + + /// Query how many queues the backend supports. + fn get_queue_num(&mut self) -> Result<u64>; + + /// Signal slave to enable or disable corresponding vring. + /// + /// Slave must not pass data to/from the backend until ring is enabled by + /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been + /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0. + fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>; + + /// Fetch the contents of the virtio device configuration space. + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + buf: &[u8], + ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>; + + /// Change the virtio device configuration space. It also can be used for live migration on the + /// destination host to set readonly configuration space fields. + fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>; + + /// Setup slave communication channel. 
+ fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>; + + /// Query the maximum amount of memory slots supported by the backend. + fn get_max_mem_slots(&mut self) -> Result<u64>; + + /// Add a new guest memory mapping for vhost to use. + fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; + + /// Remove a guest memory mapping from vhost. + fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; +} + +fn error_code<T>(err: VhostUserError) -> Result<T> { + Err(Error::VhostUserProtocol(err)) +} + +/// Struct for the vhost-user master endpoint. +#[derive(Clone)] +pub struct Master { + node: Arc<Mutex<MasterInternal>>, +} + +impl Master { + /// Create a new instance. + fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self { + Master { + node: Arc::new(Mutex::new(MasterInternal { + main_sock: ep, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: 0, + acked_protocol_features: 0, + protocol_features_ready: false, + max_queue_num, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard<MasterInternal> { + self.node.lock().unwrap() + } + + /// Create a new instance from a Unix stream socket. + pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self { + Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num) + } + + /// Create a new vhost-user master endpoint. + /// + /// Will retry as the backend may not be ready to accept the connection. 
+ /// + /// # Arguments + /// * `path` - path of Unix domain socket listener to connect to + pub fn connect<P: AsRef<Path>>(path: P, max_queue_num: u64) -> Result<Self> { + let mut retry_count = 5; + let endpoint = loop { + match Endpoint::<MasterReq>::connect(&path) { + Ok(endpoint) => break Ok(endpoint), + Err(e) => match &e { + VhostUserError::SocketConnect(why) => { + if why.kind() == std::io::ErrorKind::ConnectionRefused && retry_count > 0 { + std::thread::sleep(std::time::Duration::from_millis(100)); + retry_count -= 1; + continue; + } else { + break Err(e); + } + } + _ => break Err(e), + }, + } + }?; + + Ok(Self::new(endpoint, max_queue_num)) + } +} + +impl VhostBackend for Master { + /// Get from the underlying vhost implementation the feature bitmask. + fn get_features(&self) -> Result<u64> { + let mut node = self.node(); + let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?; + let val = node.recv_reply::<VhostUserU64>(&hdr)?; + node.virtio_features = val.value; + Ok(node.virtio_features) + } + + /// Enable features in the underlying vhost implementation using a bitmask. + fn set_features(&self, features: u64) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(features); + let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + node.acked_virtio_features = features & node.virtio_features; + Ok(()) + } + + /// Set the current Master as an owner of the session. + fn set_owner(&self) -> Result<()> { + // We unwrap() the return value to assert that we are not expecting threads to ever fail + // while holding the lock. + let mut node = self.node(); + let _ = node.send_request_header(MasterReq::SET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. 
+ Ok(()) + } + + fn reset_owner(&self) -> Result<()> { + let mut node = self.node(); + let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?; + // Don't wait for ACK here because the protocol feature negotiation process hasn't been + // completed yet. + Ok(()) + } + + /// Set the memory map regions on the slave so it can translate the vring + /// addresses. In the ancillary data there is an array of file descriptors + fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> { + if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES { + return error_code(VhostUserError::InvalidParam); + } + + let mut ctx = VhostUserMemoryContext::new(); + for region in regions.iter() { + if region.memory_size == 0 || region.mmap_handle < 0 { + return error_code(VhostUserError::InvalidParam); + } + let reg = VhostUserMemoryRegion { + guest_phys_addr: region.guest_phys_addr, + memory_size: region.memory_size, + user_addr: region.userspace_addr, + mmap_offset: region.mmap_offset, + }; + ctx.append(&reg, region.mmap_handle); + } + + let mut node = self.node(); + let body = VhostUserMemory::new(ctx.regions.len() as u32); + let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() }; + let hdr = node.send_request_with_payload( + MasterReq::SET_MEM_TABLE, + &body, + payload, + Some(ctx.fds.as_slice()), + )?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + // Clippy doesn't seem to know that if let with && is still experimental + #[allow(clippy::unnecessary_unwrap)] + fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> { + let mut node = self.node(); + let val = VhostUserU64::new(base); + + if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0 + && fd.is_some() + { + let fds = [fd.unwrap()]; + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?; + } else { + let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?; + } + Ok(()) + } + + fn set_log_fd(&self, 
fd: RawFd) -> Result<()> { + let mut node = self.node(); + let fds = [fd]; + node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?; + Ok(()) + } + + /// Set the size of the queue. + fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, num.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the addresses of the different aspects of the vring. + fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num + || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 + { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + /// Sets the base offset in the available vring. 
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let val = VhostUserVringState::new(queue_index as u32, base.into()); + let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?; + node.wait_for_ack(&hdr).map_err(|e| e.into()) + } + + fn get_vring_base(&self, queue_index: usize) -> Result<u32> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + + let req = VhostUserVringState::new(queue_index as u32, 0); + let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?; + let reply = node.recv_reply::<VhostUserVringState>(&hdr)?; + Ok(reply.num) + } + + /// Set the event file descriptor to signal when buffers are used. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// will be used instead of waiting for the call. + fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> { + let mut node = self.node(); + if queue_index as u64 >= node.max_queue_num { + return error_code(VhostUserError::InvalidParam); + } + node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?; + Ok(()) + } + + /// Set the event file descriptor for adding buffers to the vring. + /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag + /// is set when there is no file descriptor in the ancillary data. This signals that polling + /// should be used instead of waiting for a kick. 
fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
    let mut node = self.node();
    if queue_index as u64 >= node.max_queue_num {
        return error_code(VhostUserError::InvalidParam);
    }
    let raw_fd = fd.as_raw_fd();
    node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, raw_fd)?;
    Ok(())
}

/// Set the event file descriptor to signal when error occurs.
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data.
fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
    let mut node = self.node();
    if queue_index as u64 >= node.max_queue_num {
        return error_code(VhostUserError::InvalidParam);
    }
    let raw_fd = fd.as_raw_fd();
    node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, raw_fd)?;
    Ok(())
}
}

impl VhostUserMaster for Master {
    /// Query the slave for its vhost-user protocol features.
    ///
    /// Fails with `VhostUserError::InvalidOperation` unless the `PROTOCOL_FEATURES` virtio
    /// feature bit is set in both the cached and the acked virtio features. The raw reply is
    /// cached in `protocol_features`; a reply containing bits unknown to
    /// `VhostUserProtocolFeatures` is reported as `VhostUserError::InvalidMessage`.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
        let mut node = self.node();
        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let offered = node.virtio_features & flag != 0;
        let acked = node.acked_virtio_features & flag != 0;
        if !(offered && acked) {
            return error_code(VhostUserError::InvalidOperation);
        }
        let request = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
        let reply = node.recv_reply::<VhostUserU64>(&request)?;
        node.protocol_features = reply.value;
        // Should we support forward compatibility?
        // If so just mask out unrecognized flags instead of return errors.
        if let Some(features) = VhostUserProtocolFeatures::from_bits(node.protocol_features) {
            Ok(features)
        } else {
            error_code(VhostUserError::InvalidMessage)
        }
    }

    /// Tell the slave which vhost-user protocol features to enable.
    ///
    /// Subject to the same `PROTOCOL_FEATURES` precondition as `get_protocol_features()`.
    /// On success the bits are cached in `acked_protocol_features` and
    /// `protocol_features_ready` is set.
    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
        let mut node = self.node();
        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let offered = node.virtio_features & flag != 0;
        let acked = node.acked_virtio_features & flag != 0;
        if !(offered && acked) {
            return error_code(VhostUserError::InvalidOperation);
        }
        let bits = features.bits();
        let body = VhostUserU64::new(bits);
        let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &body, None)?;
        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
        // completed yet.
        node.acked_protocol_features = bits;
        node.protocol_features_ready = true;
        Ok(())
    }

    /// Query how many queues the slave supports and cache the answer in `max_queue_num`.
    ///
    /// Requires the `MQ` protocol feature to be acked (`InvalidOperation` otherwise). A reply
    /// above `VHOST_USER_MAX_VRINGS` is rejected with `InvalidMessage` and is not cached.
    fn get_queue_num(&mut self) -> Result<u64> {
        let mut node = self.node();
        if !node.is_feature_mq_available() {
            return error_code(VhostUserError::InvalidOperation);
        }

        let request = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
        let reply = node.recv_reply::<VhostUserU64>(&request)?;
        let queue_num = reply.value;
        if queue_num > VHOST_USER_MAX_VRINGS {
            return error_code(VhostUserError::InvalidMessage);
        }
        node.max_queue_num = queue_num;
        Ok(queue_num)
    }

    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
        let mut node = self.node();
        // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
    return error_code(VhostUserError::InvalidOperation);
} else if queue_index as u64 >= node.max_queue_num {
    return error_code(VhostUserError::InvalidParam);
}

let flag = if enable { 1 } else { 0 };
let val = VhostUserVringState::new(queue_index as u32, flag);
let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
node.wait_for_ack(&hdr).map_err(|e| e.into())
}

/// Read part of the virtio device configuration space from the slave.
///
/// `buf` is sent as the request payload. The reply is accepted only when it carries no file
/// descriptors, reports a non-zero `size`, and repeats the requested `size` (which must also
/// equal `buf.len()`) and `offset`. Requires the `CONFIG` protocol feature to be acked.
fn get_config(
    &mut self,
    offset: u32,
    size: u32,
    flags: VhostUserConfigFlags,
    buf: &[u8],
) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
    let request_body = VhostUserConfig::new(offset, size, flags);
    if !request_body.is_valid() {
        return error_code(VhostUserError::InvalidParam);
    }

    let mut node = self.node();
    // depends on VhostUserProtocolFeatures::CONFIG
    let config_acked =
        node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() != 0;
    if !config_acked {
        return error_code(VhostUserError::InvalidOperation);
    }

    // vhost-user spec states that:
    // "Master payload: virtio device config space"
    // "Slave payload: virtio device config space"
    let request =
        node.send_request_with_payload(MasterReq::GET_CONFIG, &request_body, buf, None)?;
    let (reply_body, reply_payload, rfds) =
        node.recv_reply_with_payload::<VhostUserConfig>(&request)?;
    if rfds.is_some() {
        Endpoint::<MasterReq>::close_rfds(rfds);
        return error_code(VhostUserError::InvalidMessage);
    }
    if reply_body.size == 0 {
        return error_code(VhostUserError::SlaveInternalError);
    }
    let echoes_request = reply_body.size == request_body.size
        && reply_body.size as usize == buf.len()
        && reply_body.offset == request_body.offset;
    if !echoes_request {
        return error_code(VhostUserError::InvalidMessage);
    }

    Ok((reply_body, reply_payload))
}

/// Write `buf` into the virtio device configuration space at `offset`.
///
/// Payloads longer than `MAX_MSG_SIZE` and configurations that fail
/// `VhostUserConfig::is_valid()` are rejected with `InvalidParam`. Requires the `CONFIG`
/// protocol feature to be acked.
fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
    let payload_len = buf.len();
    if payload_len > MAX_MSG_SIZE {
        return error_code(VhostUserError::InvalidParam);
    }
    let body = VhostUserConfig::new(offset, payload_len as u32, flags);
    if !body.is_valid() {
        return error_code(VhostUserError::InvalidParam);
    }

    let mut node = self.node();
    // depends on VhostUserProtocolFeatures::CONFIG
    let config_acked =
        node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() != 0;
    if !config_acked {
        return error_code(VhostUserError::InvalidOperation);
    }

    let request = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
    node.wait_for_ack(&request).map_err(Into::into)
}

/// Send `fd` to the slave with a `SET_SLAVE_REQ_FD` request.
///
/// Requires the `SLAVE_REQ` protocol feature to be acked.
fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
    let mut node = self.node();
    let slave_req_acked =
        node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() != 0;
    if !slave_req_acked {
        return error_code(VhostUserError::InvalidOperation);
    }

    let fds = [fd];
    node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
    Ok(())
}

/// Ask the slave for its `GET_MAX_MEM_SLOTS` value.
///
/// Requires the `CONFIGURE_MEM_SLOTS` protocol feature to be acked.
fn get_max_mem_slots(&mut self) -> Result<u64> {
    let mut node = self.node();
    let mem_slots_acked = node.acked_protocol_features
        & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
        != 0;
    if !mem_slots_acked {
        return error_code(VhostUserError::InvalidOperation);
    }

    let request = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?;
    let reply = node.recv_reply::<VhostUserU64>(&request)?;

    Ok(reply.value)
}

/// Register one memory region with the slave (`ADD_MEM_REG`).
///
/// The region's `mmap_handle` is attached to the request as a file descriptor. Regions with
/// a zero `memory_size` or a negative `mmap_handle` are rejected with `InvalidParam`.
/// Requires the `CONFIGURE_MEM_SLOTS` protocol feature to be acked.
fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
    let mut node = self.node();
    let mem_slots_acked = node.acked_protocol_features
        & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits()
        != 0;
    if !mem_slots_acked {
        return error_code(VhostUserError::InvalidOperation);
    }
    if region.memory_size == 0 || region.mmap_handle < 0 {
        return error_code(VhostUserError::InvalidParam);
    }

    let single_region = VhostUserSingleMemoryRegion::new(
        region.guest_phys_addr,
        region.memory_size,
        region.userspace_addr,
        region.mmap_offset,
    );
    let fds = [region.mmap_handle];
    let request =
        node.send_request_with_body(MasterReq::ADD_MEM_REG, &single_region, Some(&fds))?;
    node.wait_for_ack(&request).map_err(Into::into)
}

fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
let mut node = self.node();
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
{
    return error_code(VhostUserError::InvalidOperation);
}
if region.memory_size == 0 {
    return error_code(VhostUserError::InvalidParam);
}

let body = VhostUserSingleMemoryRegion::new(
    region.guest_phys_addr,
    region.memory_size,
    region.userspace_addr,
    region.mmap_offset,
);
let hdr = node.send_request_with_body(MasterReq::REM_MEM_REG, &body, None)?;
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
}

impl AsRawFd for Master {
    /// Expose the raw fd of the main socket used to talk to the slave.
    fn as_raw_fd(&self) -> RawFd {
        self.node().main_sock.as_raw_fd()
    }
}

/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
struct VhostUserMemoryContext {
    regions: VhostUserMemoryPayload,
    fds: Vec<RawFd>,
}

impl VhostUserMemoryContext {
    /// Create a context object holding no regions and no file descriptors.
    pub fn new() -> Self {
        let regions = VhostUserMemoryPayload::new();
        let fds = Vec::new();
        Self { regions, fds }
    }

    /// Append a user memory region and corresponding RawFd into the context object.
    ///
    /// `region` is copied into the payload; `fd` is stored at the same position in `fds`.
    pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
        self.regions.push(*region);
        self.fds.push(fd);
    }
}

struct MasterInternal {
    // Used to send requests to the slave.
    main_sock: Endpoint<MasterReq>,
    // Cached virtio features from the slave.
    virtio_features: u64,
    // Cached acked virtio features from the driver.
    acked_virtio_features: u64,
    // Cached vhost-user protocol features from the slave.
    protocol_features: u64,
    // Cached vhost-user protocol features.
    acked_protocol_features: u64,
    // Cached vhost-user protocol features are ready to use.
    protocol_features_ready: bool,
    // Cached maximum number of queues supported from the slave.
    max_queue_num: u64,
    // Internal flag to mark failure state.
error: Option<i32>,
}

impl MasterInternal {
    /// Send a request that consists of a header only (size 0), optionally with attached fds.
    ///
    /// Returns the header that was sent so the caller can match the reply against it.
    fn send_request_header(
        &mut self,
        code: MasterReq,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        self.check_state()?;
        let request = Self::new_request_header(code, 0);
        self.main_sock.send_header(&request, fds)?;
        Ok(request)
    }

    /// Send a request whose body is the fixed-size message `msg`.
    ///
    /// Fails with `InvalidParam` when `size_of::<T>()` exceeds `MAX_MSG_SIZE`; this check is
    /// made before the endpoint state is consulted.
    fn send_request_with_body<T: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        let body_size = mem::size_of::<T>();
        if body_size > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        let request = Self::new_request_header(code, body_size as u32);
        self.main_sock.send_message(&request, msg, fds)?;
        Ok(request)
    }

    /// Send a request made of the fixed-size message `msg` followed by `payload`.
    ///
    /// Fails with `InvalidParam` when body plus payload exceed `MAX_MSG_SIZE`, or when more
    /// than `MAX_ATTACHED_FD_ENTRIES` file descriptors are supplied.
    fn send_request_with_payload<T: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        payload: &[u8],
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        let total_len = mem::size_of::<T>() + payload.len();
        if total_len > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        let too_many_fds = fds.map_or(false, |fd_arr| fd_arr.len() > MAX_ATTACHED_FD_ENTRIES);
        if too_many_fds {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        let request = Self::new_request_header(code, total_len as u32);
        self.main_sock
            .send_message_with_payload(&request, msg, payload, fds)?;
        Ok(request)
    }

    fn send_fd_for_vring(
        &mut self,
        code: MasterReq,
        queue_index: usize,
        fd: RawFd,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        if queue_index as u64 >= self.max_queue_num {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;

        // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
        // This flag is set when there is no file descriptor in the ancillary data. This signals
        // that polling will be used instead of waiting for the call.
let msg = VhostUserU64::new(queue_index as u64);
let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
Ok(hdr)
}

/// Receive the fixed-size reply body for the request described by `hdr`.
///
/// `hdr` must be a request header (not a reply) and `T` must fit in `MAX_MSG_SIZE`,
/// otherwise `InvalidParam` is returned. A reply that does not match `hdr`, carries file
/// descriptors, or fails `is_valid()` is reported as `InvalidMessage`; any received
/// descriptors are closed first.
fn recv_reply<T: Sized + Default + VhostUserMsgValidator>(
    &mut self,
    hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<T> {
    if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
        return Err(VhostUserError::InvalidParam);
    }
    self.check_state()?;

    let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
    let acceptable = reply.is_reply_for(hdr) && rfds.is_none() && body.is_valid();
    if acceptable {
        return Ok(body);
    }
    Endpoint::<MasterReq>::close_rfds(rfds);
    Err(VhostUserError::InvalidMessage)
}

/// Receive a reply made of a fixed-size body `T` followed by a variable payload.
///
/// The payload buffer is sized from the request header: `hdr.get_size()` minus
/// `size_of::<T>()`. The reply must match `hdr`, announce exactly body plus received payload
/// bytes, carry no file descriptors, pass `is_valid()`, and fill the whole buffer.
fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
    &mut self,
    hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
    let body_size = mem::size_of::<T>();
    let request_size = hdr.get_size() as usize;
    if body_size > MAX_MSG_SIZE
        || request_size <= body_size
        || request_size > MAX_MSG_SIZE
        || hdr.is_reply()
    {
        return Err(VhostUserError::InvalidParam);
    }
    self.check_state()?;

    let mut buf: Vec<u8> = vec![0; request_size - body_size];
    let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
    if !reply.is_reply_for(hdr)
        || reply.get_size() as usize != body_size + bytes
        || rfds.is_some()
        || !body.is_valid()
    {
        Endpoint::<MasterReq>::close_rfds(rfds);
        return Err(VhostUserError::InvalidMessage);
    }
    if bytes != buf.len() {
        return Err(VhostUserError::InvalidMessage);
    }
    Ok((body, buf, rfds))
}

/// Wait for the slave's acknowledgement of the request described by `hdr`.
///
/// Returns `Ok(())` immediately unless the `REPLY_ACK` protocol feature is acked and `hdr`
/// asks for a reply. Otherwise a `VhostUserU64` reply is read; a non-zero value is reported
/// as `SlaveInternalError`.
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
    let reply_ack_acked =
        self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() != 0;
    if !reply_ack_acked || !hdr.is_need_reply() {
        return Ok(());
    }
    self.check_state()?;

    let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
    if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
        Endpoint::<MasterReq>::close_rfds(rfds);
        return Err(VhostUserError::InvalidMessage);
    }
    if body.value != 0 {
        return Err(VhostUserError::SlaveInternalError);
    }
    Ok(())
}

/// Whether the `MQ` protocol feature has been acked.
fn is_feature_mq_available(&self) -> bool {
    let mq_bit = VhostUserProtocolFeatures::MQ.bits();
    self.acked_protocol_features & mq_bit != 0
}

/// Return `SocketBroken` (built from the stored raw OS error) if a failure was recorded.
fn check_state(&self) -> VhostUserResult<()> {
    if let Some(errno) = self.error {
        return Err(VhostUserError::SocketBroken(
            std::io::Error::from_raw_os_error(errno),
        ));
    }
    Ok(())
}

/// Build a request header for `request` with version 0x1 and the given body `size`.
#[inline]
fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
    // TODO: handle NEED_REPLY flag
    VhostUserMsgHeader::new(request, 0x1, size)
}
}

#[cfg(test)]
mod tests {
    use super::super::connection::Listener;
    use super::*;
    use tempfile::{Builder, TempDir};

    /// Create the temporary directory that hosts a test socket.
    fn temp_dir() -> TempDir {
        let mut builder = Builder::new();
        builder.prefix("/tmp/vhost_test").tempdir().unwrap()
    }

    /// Bind a listener on `path`, connect a `Master` to it and return both connection ends.
    fn create_pair<P: AsRef<Path>>(path: P) -> (Master, Endpoint<MasterReq>) {
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let master = Master::connect(path, 2).unwrap();
        let slave_stream = listener.accept().unwrap().unwrap();
        (master, Endpoint::from_stream(slave_stream))
    }

    #[test]
    fn create_master() {
        let dir = temp_dir();
        let mut path = dir.path().to_owned();
        path.push("sock");
        let listener = Listener::new(&path, true).unwrap();
        listener.set_nonblocking(true).unwrap();

        let master = Master::connect(&path, 1).unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());

        assert!(master.as_raw_fd() > 0);
        // Send two messages continuously
        master.set_owner().unwrap();
        master.reset_owner().unwrap();

        let (hdr, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());

let (hdr, rfds) = slave.recv_header().unwrap();
assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
assert_eq!(hdr.get_size(), 0);
assert_eq!(hdr.get_version(), 0x1);
assert!(rfds.is_none());
}

/// Connecting must fail while no listener is alive, and a second
/// `Listener::new(&path, false)` must be refused while a listener exists.
#[test]
fn test_create_failure() {
    let dir = temp_dir();
    let mut sock_path = dir.path().to_owned();
    sock_path.push("sock");
    // `let _` drops this listener immediately.
    let _ = Listener::new(&sock_path, true).unwrap();
    // NOTE(review): the result of `is_err()` is discarded, so this line asserts nothing.
    let _ = Listener::new(&sock_path, false).is_err();
    assert!(Master::connect(&sock_path, 1).is_err());

    let listener = Listener::new(&sock_path, true).unwrap();
    assert!(Listener::new(&sock_path, false).is_err());
    listener.set_nonblocking(true).unwrap();

    let _master = Master::connect(&sock_path, 1).unwrap();
    let _slave = listener.accept().unwrap().unwrap();
}

/// Exercise `get_features()` / `set_features()` against a hand-driven peer endpoint.
#[test]
fn test_features() {
    let dir = temp_dir();
    let mut sock_path = dir.path().to_owned();
    sock_path.push("sock");
    let (master, mut peer) = create_pair(&sock_path);

    master.set_owner().unwrap();
    let (request, rfds) = peer.recv_header().unwrap();
    assert_eq!(request.get_code(), MasterReq::SET_OWNER);
    assert_eq!(request.get_size(), 0);
    assert_eq!(request.get_version(), 0x1);
    assert!(rfds.is_none());

    // The reply is queued on the socket before the master issues its request.
    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
    let reply_body = VhostUserU64::new(0x15);
    peer.send_message(&reply_hdr, &reply_body, None).unwrap();
    let features = master.get_features().unwrap();
    assert_eq!(features, 0x15u64);
    let (_request, rfds) = peer.recv_header().unwrap();
    assert!(rfds.is_none());

    let reply_hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
    let reply_body = VhostUserU64::new(0x15);
    peer.send_message(&reply_hdr, &reply_body, None).unwrap();
    master.set_features(0x15).unwrap();
    let (_request, received, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
    assert!(rfds.is_none());
    let received_value = received.value;
    assert_eq!(received_value, 0x15);

    // A reply whose body is only a u32 must make `get_features()` fail.
    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
    let short_body = 0x15u32;
    peer.send_message(&reply_hdr, &short_body, None).unwrap();
    assert!(master.get_features().is_err());
}

/// Protocol feature calls fail before `PROTOCOL_FEATURES` is acked and work afterwards.
#[test]
fn test_protocol_features() {
    let dir = temp_dir();
    let mut sock_path = dir.path().to_owned();
    sock_path.push("sock");
    let (mut master, mut peer) = create_pair(&sock_path);

    master.set_owner().unwrap();
    let (request, rfds) = peer.recv_header().unwrap();
    assert_eq!(request.get_code(), MasterReq::SET_OWNER);
    assert!(rfds.is_none());

    assert!(master.get_protocol_features().is_err());
    assert!(master
        .set_protocol_features(VhostUserProtocolFeatures::all())
        .is_err());

    let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
    let reply_body = VhostUserU64::new(vfeatures);
    peer.send_message(&reply_hdr, &reply_body, None).unwrap();
    let features = master.get_features().unwrap();
    assert_eq!(features, vfeatures);
    let (_request, rfds) = peer.recv_header().unwrap();
    assert!(rfds.is_none());

    master.set_features(vfeatures).unwrap();
    let (_request, received, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
    assert!(rfds.is_none());
    let received_value = received.value;
    assert_eq!(received_value, vfeatures);

    let pfeatures = VhostUserProtocolFeatures::all();
    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
    let reply_body = VhostUserU64::new(pfeatures.bits());
    peer.send_message(&reply_hdr, &reply_body, None).unwrap();
    let features = master.get_protocol_features().unwrap();
    assert_eq!(features, pfeatures);
    let (_request, rfds) = peer.recv_header().unwrap();
    assert!(rfds.is_none());

    master.set_protocol_features(pfeatures).unwrap();
    let (_request, received, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
    assert!(rfds.is_none());
    let received_value = received.value;
    assert_eq!(received_value, pfeatures.bits());

    // A reply carrying the wrong request code must be rejected.
    let reply_hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
    let reply_body = VhostUserU64::new(pfeatures.bits());
    peer.send_message(&reply_hdr, &reply_body, None).unwrap();
    assert!(master.get_protocol_features().is_err());
}

/// `set_config()` must reject calls made before features are acked and invalid arguments.
#[test]
fn test_master_set_config_negative() {
    let dir = temp_dir();
    let mut sock_path = dir.path().to_owned();
    sock_path.push("sock");
    let (mut master, _peer) = create_pair(&sock_path);
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    // No protocol feature has been acked yet.
    master
        .set_config(0x100, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .unwrap_err();

    {
        let mut node = master.node();
        node.virtio_features = 0xffff_ffff;
        node.acked_virtio_features = 0xffff_ffff;
        node.protocol_features = 0xffff_ffff;
        node.acked_protocol_features = 0xffff_ffff;
    }

    master
        .set_config(0x100, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .unwrap();
    master
        .set_config(0x0, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .unwrap_err();
    master
        .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .unwrap_err();
    master
        .set_config(
            0x100,
            unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
            &payload[..4],
        )
        .unwrap_err();
    master
        .set_config(0x100, VhostUserConfigFlags::WRITABLE, &payload)
        .unwrap_err();
    master
        .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[])
        .unwrap_err();
}

/// Like `create_pair()`, but with every cached feature field forced to `0xffff_ffff`.
fn create_pair2() -> (Master, Endpoint<MasterReq>) {
    let dir = temp_dir();
    let mut sock_path = dir.path().to_owned();
    sock_path.push("sock");
    let (master, peer) = create_pair(&sock_path);

    {
        let mut node = master.node();
        node.virtio_features = 0xffff_ffff;
        node.acked_virtio_features = 0xffff_ffff;
        node.protocol_features = 0xffff_ffff;
        node.acked_protocol_features = 0xffff_ffff;
    }

    (master, peer)
}

/// A reply with the wrong request code must be rejected by `get_config()`.
#[test]
fn test_master_get_config_negative0() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let mut reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());

    reply_hdr.set_code(MasterReq::GET_FEATURES);
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_err());
    reply_hdr.set_code(MasterReq::GET_CONFIG);
}

/// A message whose reply flag is cleared must be rejected by `get_config()`.
#[test]
fn test_master_get_config_negative1() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let mut reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());

    reply_hdr.set_reply(false);
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_err());
}

/// Only the accepted-reply path of `get_config()` is exercised here.
#[test]
fn test_master_get_config_negative2() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());
}

/// A reply whose offset is 0 must be rejected by `get_config()`.
#[test]
fn test_master_get_config_negative3() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let mut reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());

    reply_cfg.offset = 0;
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_err());
}

/// A reply whose offset differs from the requested one must be rejected.
#[test]
fn test_master_get_config_negative4() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let mut reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());

    reply_cfg.offset = 0x101;
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_err());
}

/// A reply whose offset is `MAX_MSG_SIZE + 1` must be rejected.
#[test]
fn test_master_get_config_negative5() {
    let (mut master, mut peer) = create_pair2();
    let payload = vec![0x0; MAX_MSG_SIZE + 1];

    let reply_hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let mut reply_cfg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_ok());

    reply_cfg.offset = (MAX_MSG_SIZE + 1) as u32;
    peer.send_message_with_payload(&reply_hdr, &reply_cfg, &payload[..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &payload[..4])
        .is_err());
}

#[test]
fn test_master_get_config_negative6() {
    let (mut master, mut peer) = create_pair2();
    let buf = vec![0x0; MAX_MSG_SIZE + 1];

    let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
    let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
    peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
        .is_ok());

    msg.size = 6;
    peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
        .unwrap();
    assert!(master
        .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
        .is_err());
+ } + + #[test] + fn test_maset_set_mem_table_failure() { + let (master, _peer) = create_pair2(); + + master.set_mem_table(&[]).unwrap_err(); + let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; + master.set_mem_table(&tables).unwrap_err(); + } +} diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs new file mode 100644 index 0000000..8cba188 --- /dev/null +++ b/src/vhost_user/master_req_handler.rs @@ -0,0 +1,477 @@ +// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::mem; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result}; + +/// Define services provided by masters for the slave communication channel. +/// +/// The vhost-user specification defines a slave communication channel, by which slaves could +/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided +/// by masters, and it's used both on the master side and slave side. +/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy +/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder. +/// - on the master side, the [MasterReqHandler] will forward service requests to a handler +/// implementing [VhostUserMasterReqHandler]. +/// +/// The [VhostUserMasterReqHandler] trait is designed with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +pub trait VhostUserMasterReqHandler { + /// Handle device configuration change notifications.
+ fn handle_config_change(&self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. + fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. +/// +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub trait VhostUserMasterReqHandlerMut { + /// Handle device configuration change notifications. + fn handle_config_change(&mut self) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs map file requests. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs unmap file requests. 
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs sync file requests. + fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + /// Handle virtio-fs file IO requests. + fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + + // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); + // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); +} + +impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { + fn handle_config_change(&self) -> HandlerResult<u64> { + self.lock().unwrap().handle_config_change() + } + + fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_map(fs, fd) + } + + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_unmap(fs) + } + + fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_sync(fs) + } + + fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.lock().unwrap().fs_slave_io(fs, fd) + } +} + +/// Server to handle service requests from slaves from the slave communication channel. +/// +/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from +/// slaves on the slave communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserMasterReqHandler] to do the real work. 
+/// +/// [MasterReqHandler]: struct.MasterReqHandler.html +/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html +pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { + // underlying Unix domain socket for communication + sub_sock: Endpoint<SlaveReq>, + tx_sock: UnixStream, + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, + // the VirtIO backend device object + backend: Arc<S>, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { + /// Create a server to handle service requests from slaves on the slave communication channel. + /// + /// This opens a pair of connected anonymous sockets to form the slave communication channel. + /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by + /// [VhostUserMaster::set_slave_request_fd()]. + /// + /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn new(backend: Arc<S>) -> Result<Self> { + let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; + + Ok(MasterReqHandler { + sub_sock: Endpoint::<SlaveReq>::from_stream(rx), + tx_sock: tx, + reply_ack_negotiated: false, + backend, + error: None, + }) + } + + /// Get the socket fd for the slave to communication with the master. + /// + /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. + /// + /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd + pub fn get_tx_raw_fd(&self) -> RawFd { + self.tx_sock.as_raw_fd() + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 
+ /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&mut self, enable: bool) { + self.reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed or in normal state. + pub fn set_failed(&mut self, error: i32) { + if error == 0 { + self.error = None; + } else { + self.error = Some(error); + } + } + + /// Main entrance to server slave request from the slave communication channel. + /// + /// The caller needs to: + /// - serialize calls to this function + /// - decide what to do when errer happens + /// - optional recover from failure + pub fn handle_request(&mut self) -> Result<u64> { + // Return error if the endpoint is already in failed state. + self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv messsage header and optional attached file + // . validate message header + // . recv optional message body and payload according size field in + // message header + // . 
validate message body and optional payload + let (hdr, rfds) = self.sub_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + if len as usize > MAX_MSG_SIZE { + return Err(Error::InvalidMessage); + } + let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + let res = match hdr.get_code() { + SlaveReq::CONFIG_CHANGE_MSG => { + self.check_msg_size(&hdr, size, 0)?; + self.backend + .handle_config_change() + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_MAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds + self.backend + .fs_slave_map(&msg, rfds.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_UNMAP => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .fs_slave_unmap(&msg) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_SYNC => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + self.backend + .fs_slave_sync(&msg) + .map_err(Error::ReqHandlerError) + } + SlaveReq::FS_IO => { + let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; + // check_attached_rfds() has validated rfds + self.backend + .fs_slave_io(&msg, rfds.unwrap()[0]) + .map_err(Error::ReqHandlerError) + } + _ => Err(Error::InvalidMessage), + }; + + self.send_ack_message(&hdr, &res)?; + + res + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_msg_size( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return 
Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + SlaveReq::FS_MAP | SlaveReq::FS_IO => { + // Expect an fd set with a single fd. + match rfds { + None => Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != 1 { + Endpoint::<SlaveReq>::close_rfds(Some(fds)); + Err(Error::InvalidMessage) + } else { + Ok(Some(fds)) + } + } + } + } + _ => { + if rfds.is_some() { + Endpoint::<SlaveReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + /// Decode the message body in `buf` as a `T` and validate it. + /// + /// Fails with `InvalidMessage` when the header or `size` does not match `size_of::<T>()`, + /// when `buf` is too short to hold a `T`, or when the decoded message is not valid. + fn extract_msg_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<SlaveReq>, + size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_msg_size(hdr, size, mem::size_of::<T>())?; + // check_msg_size() only vouches for the reported `size`; make sure the buffer itself + // really holds a whole `T` before reading from it. + if buf.len() < mem::size_of::<T>() { + return Err(Error::InvalidMessage); + } + // SAFETY: `buf` holds at least `size_of::<T>()` bytes (checked above) and + // read_unaligned() has no alignment requirement. NOTE(review): this presumes every + // bit pattern is a valid `T` - confirm for all message types decoded here. + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<SlaveReq>, + ) -> Result<VhostUserMsgHeader<SlaveReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + mem::size_of::<T>() as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<SlaveReq>, + res: &Result<u64>, + ) -> Result<()> { + if self.reply_ack_negotiated && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req)?; + let def_err = libc::EINVAL; + let val = match res { + Ok(n) => *n, + Err(e) => match &*e { + Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { + Some(rawerr) => -rawerr as u64, + None => -def_err as u64, + }, + _ => -def_err as u64, + }, + }; + let msg = VhostUserU64::new(val); + self.sub_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } +} + +impl<S: VhostUserMasterReqHandler> AsRawFd
for MasterReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.sub_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "vhost-user-slave")] + use crate::vhost_user::SlaveFsCacheReq; + #[cfg(feature = "vhost-user-slave")] + use std::os::unix::io::FromRawFd; + + struct MockMasterReqHandler {} + + impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { + /// Handle virtio-fs map file requests from the slave. + fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + // Safe because we have just received the rawfd from kernel. + unsafe { libc::close(fd) }; + Ok(0) + } + + /// Handle virtio-fs unmap file requests from the slave. + fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) + } + } + + #[test] + fn test_new_master_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + assert!(handler.get_tx_raw_fd() >= 0); + assert!(handler.as_raw_fd() >= 0); + handler.check_state().unwrap(); + + assert_eq!(handler.error, None); + handler.set_failed(libc::EAGAIN); + assert_eq!(handler.error, Some(libc::EAGAIN)); + handler.check_state().unwrap_err(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + // When REPLY_ACK 
has not been negotiated, the master has no way to detect failure from + // slave side. + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap(); + } + + #[cfg(feature = "vhost-user-slave")] + #[test] + fn test_master_slave_req_handler_with_ack() { + let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); + let mut handler = MasterReqHandler::new(backend).unwrap(); + handler.set_reply_ack_flag(true); + + let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; + if fd < 0 { + panic!("failed to duplicated tx fd!"); + } + let stream = unsafe { UnixStream::from_raw_fd(fd) }; + let fs_cache = SlaveFsCacheReq::from_stream(stream); + + std::thread::spawn(move || { + let res = handler.handle_request().unwrap(); + assert_eq!(res, 0); + handler.handle_request().unwrap_err(); + }); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } +} diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs new file mode 100644 index 0000000..ea2df4e --- /dev/null +++ b/src/vhost_user/message.rs @@ -0,0 +1,1042 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Define communication messages for the vhost-user protocol. +//! +//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt). + +#![allow(dead_code)] +#![allow(non_camel_case_types)] + +use std::fmt::Debug; +use std::marker::PhantomData; + +use crate::VringConfigData; + +/// The vhost-user specification uses a field of u32 to store message length. +/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain +/// socket. To preallocating a 4GB buffer for each vhost-user message is really just an overhead. 
+/// Among all defined vhost-user messages, only VhostUserConfig and VhostUserMemory have a variable +/// message size. For the VhostUserConfig, a maximum size of 4K is enough because the user +/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory, +/// 4K should be enough too because it can support 255 memory regions at most. +pub const MAX_MSG_SIZE: usize = 0x1000; + +/// The VhostUserMemory message has a variable message size and a variable number of attached file +/// descriptors. Each user memory region entry in the message payload occupies 32 bytes, +/// so the maximum number of attached file descriptors could be derived from the maximum message +/// size. But Rust only implements the Default and AsMut traits for arrays with 0 - 32 entries, so +/// the maximum number is reduced further to 32. +// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32; +pub const MAX_ATTACHED_FD_ENTRIES: usize = 32; + +/// Starting position (inclusive) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100; + +/// Ending position (exclusive) of the device configuration space in virtio devices. +pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000; + +/// Maximum number of vrings supported. +pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64; + +pub(super) trait Req: + Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32> +{ + fn is_valid(&self) -> bool; +} + +/// Type of requests sent from masters to slaves. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum MasterReq { + /// Null operation. + NOOP = 0, + /// Get from the underlying vhost implementation the features bit mask. + GET_FEATURES = 1, + /// Enable features in the underlying vhost implementation using a bit mask. + SET_FEATURES = 2, + /// Set the current Master as an owner of the session. + SET_OWNER = 3, + /// No longer used.
+ RESET_OWNER = 4, + /// Set the memory map regions on the slave so it can translate the vring addresses. + SET_MEM_TABLE = 5, + /// Set logging shared memory space. + SET_LOG_BASE = 6, + /// Set the logging file descriptor, which is passed as ancillary data. + SET_LOG_FD = 7, + /// Set the size of the queue. + SET_VRING_NUM = 8, + /// Set the addresses of the different aspects of the vring. + SET_VRING_ADDR = 9, + /// Set the base offset in the available vring. + SET_VRING_BASE = 10, + /// Get the available vring base offset. + GET_VRING_BASE = 11, + /// Set the event file descriptor for adding buffers to the vring. + SET_VRING_KICK = 12, + /// Set the event file descriptor to signal when buffers are used. + SET_VRING_CALL = 13, + /// Set the event file descriptor to signal when error occurs. + SET_VRING_ERR = 14, + /// Get the protocol feature bit mask from the underlying vhost implementation. + GET_PROTOCOL_FEATURES = 15, + /// Enable protocol features in the underlying vhost implementation. + SET_PROTOCOL_FEATURES = 16, + /// Query how many queues the backend supports. + GET_QUEUE_NUM = 17, + /// Signal slave to enable or disable corresponding vring. + SET_VRING_ENABLE = 18, + /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated + /// for guest that does not support GUEST_ANNOUNCE. + SEND_RARP = 19, + /// Set host MTU value exposed to the guest. + NET_SET_MTU = 20, + /// Set the socket file descriptor for slave initiated requests. + SET_SLAVE_REQ_FD = 21, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 22, + /// Set the endianness of a VQ for legacy devices. + SET_VRING_ENDIAN = 23, + /// Fetch the contents of the virtio device configuration space. + GET_CONFIG = 24, + /// Change the contents of the virtio device configuration space. + SET_CONFIG = 25, + /// Create a session for crypto operation. + CREATE_CRYPTO_SESSION = 26, + /// Close a session for crypto operation. 
+ CLOSE_CRYPTO_SESSION = 27, + /// Advise slave that a migration with postcopy enabled is underway. + POSTCOPY_ADVISE = 28, + /// Advise slave that a transition to postcopy mode has happened. + POSTCOPY_LISTEN = 29, + /// Advise that postcopy migration has now completed. + POSTCOPY_END = 30, + /// Get a shared buffer from slave. + GET_INFLIGHT_FD = 31, + /// Send the shared inflight buffer back to slave. + SET_INFLIGHT_FD = 32, + /// Sets the GPU protocol socket file descriptor. + GPU_SET_SOCKET = 33, + /// Ask the vhost user backend to disable all rings and reset all internal + /// device state to the initial state. + RESET_DEVICE = 34, + /// Indicate that a buffer was added to the vring instead of signalling it + /// using the vring’s kick file descriptor. + VRING_KICK = 35, + /// Return a u64 payload containing the maximum number of memory slots. + GET_MAX_MEM_SLOTS = 36, + /// Update the memory tables by adding the region described. + ADD_MEM_REG = 37, + /// Update the memory tables by removing the region described. + REM_MEM_REG = 38, + /// Notify the backend with updated device status as defined in the VIRTIO + /// specification. + SET_STATUS = 39, + /// Query the backend for its device status as defined in the VIRTIO + /// specification. + GET_STATUS = 40, + /// Upper bound of valid commands. + MAX_CMD = 41, +} + +// Implement `From` rather than `Into`: the standard blanket impl then provides +// `Into<u32> for MasterReq`, which is what the `Req` trait bound requires. +impl From<MasterReq> for u32 { + fn from(req: MasterReq) -> u32 { + req as u32 + } +} + +impl Req for MasterReq { + /// A request code is valid when it lies strictly between `NOOP` and `MAX_CMD`. + fn is_valid(&self) -> bool { + (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + } +} + +/// Type of requests sent from slaves to masters. +#[repr(u32)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +pub enum SlaveReq { + /// Null operation. + NOOP = 0, + /// Send IOTLB messages with struct vhost_iotlb_msg as payload. + IOTLB_MSG = 1, + /// Notify that the virtio device's configuration space has changed. + CONFIG_CHANGE_MSG = 2, + /// Set host notifier for a specified queue.
+ VRING_HOST_NOTIFIER_MSG = 3, + /// Indicate that a buffer was used from the vring. + VRING_CALL = 4, + /// Indicate that an error occurred on the specific vring. + VRING_ERR = 5, + /// Virtio-fs draft: map file content into the window. + FS_MAP = 6, + /// Virtio-fs draft: unmap file content from the window. + FS_UNMAP = 7, + /// Virtio-fs draft: sync file content. + FS_SYNC = 8, + /// Virtio-fs draft: perform a read/write from an fd directly to GPA. + FS_IO = 9, + /// Upper bound of valid commands. + MAX_CMD = 10, +} + +// Implement `From` rather than `Into`: the standard blanket impl then provides +// `Into<u32> for SlaveReq`, which is what the `Req` trait bound requires. +impl From<SlaveReq> for u32 { + fn from(req: SlaveReq) -> u32 { + req as u32 + } +} + +impl Req for SlaveReq { + /// A request code is valid when it lies strictly between `NOOP` and `MAX_CMD`. + fn is_valid(&self) -> bool { + (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD) + } +} + +/// Vhost message Validator. +pub trait VhostUserMsgValidator { + /// Validate message syntax only. + /// It doesn't validate message semantics such as protocol version number and dependency + /// on feature flags etc. + fn is_valid(&self) -> bool { + true + } +} + +// Bit mask for common message flags. +bitflags! { + /// Common message flags for vhost-user requests and replies. + pub struct VhostUserHeaderFlag: u32 { + /// Bits[0..2] is message version number. + const VERSION = 0x3; + /// Mark message as reply. + const REPLY = 0x4; + /// Sender anticipates a reply message from the peer. + const NEED_REPLY = 0x8; + /// All valid bits. + const ALL_FLAGS = 0xc; + /// All reserved bits. + const RESERVED_BITS = !0xf; + } +} + +/// Common message header for vhost-user requests and replies. +/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the +/// machine native byte order. +#[allow(safe_packed_borrows)] +#[repr(packed)] +#[derive(Debug, Clone, Copy, PartialEq)] +pub(super) struct VhostUserMsgHeader<R: Req> { + request: u32, + flags: u32, + size: u32, + _r: PhantomData<R>, +} + +impl<R: Req> VhostUserMsgHeader<R> { + /// Create a new instance of `VhostUserMsgHeader`.
+ pub fn new(request: R, flags: u32, size: u32) -> Self { + // Default to protocol version 1 + let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1; + VhostUserMsgHeader { + request: request.into(), + flags: fl, + size, + _r: PhantomData, + } + } + + /// Get message type. + pub fn get_code(&self) -> R { + // Copy the field out first: the struct is `repr(packed)`, so taking a reference + // directly to `self.request` could produce an unaligned reference. + let request = self.request; + // R is marked as repr(u32), so the sizes match. NOTE(review): that alone does not make + // every u32 a valid R; is_valid() itself goes through get_code(), so a raw value that + // is not a declared discriminant would reach this transmute - confirm how headers are + // filled in and consider validating the raw u32 instead. + unsafe { std::mem::transmute_copy::<u32, R>(&request) } + } + + /// Set message type. + pub fn set_code(&mut self, request: R) { + self.request = request.into(); + } + + /// Get message version number. + pub fn get_version(&self) -> u32 { + self.flags & 0x3 + } + + /// Set message version number. + pub fn set_version(&mut self, ver: u32) { + self.flags &= !0x3; + self.flags |= ver & 0x3; + } + + /// Check whether it's a reply message. + pub fn is_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0 + } + + /// Mark message as reply. + pub fn set_reply(&mut self, is_reply: bool) { + if is_reply { + self.flags |= VhostUserHeaderFlag::REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::REPLY.bits(); + } + } + + /// Check whether reply for this message is requested. + pub fn is_need_reply(&self) -> bool { + (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0 + } + + /// Mark that reply for this message is needed. + pub fn set_need_reply(&mut self, need_reply: bool) { + if need_reply { + self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits(); + } else { + self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits(); + } + } + + /// Check whether it's the reply message for the request `req`. + pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool { + self.is_reply() && !req.is_reply() && self.get_code() == req.get_code() + } + + /// Get message size. + pub fn get_size(&self) -> u32 { + self.size + } + + /// Set message size.
+ pub fn set_size(&mut self, size: u32) { + self.size = size; + } +} + +impl<R: Req> Default for VhostUserMsgHeader<R> { + fn default() -> Self { + VhostUserMsgHeader { + request: 0, + flags: 0x1, + size: 0, + _r: PhantomData, + } + } +} + +impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if !self.get_code().is_valid() { + return false; + } else if self.size as usize > MAX_MSG_SIZE { + return false; + } else if self.get_version() != 0x1 { + return false; + } else if (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) != 0 { + return false; + } + true + } +} + +// Bit mask for transport specific flags in VirtIO feature set defined by vhost-user. +bitflags! { + /// Transport specific flags in VirtIO feature set defined by vhost-user. + pub struct VhostUserVirtioFeatures: u64 { + /// Feature flag for the protocol feature. + const PROTOCOL_FEATURES = 0x4000_0000; + } +} + +// Bit mask for vhost-user protocol feature flags. +bitflags! { + /// Vhost-user protocol feature flags. + pub struct VhostUserProtocolFeatures: u64 { + /// Support multiple queues. + const MQ = 0x0000_0001; + /// Support logging through shared memory fd. + const LOG_SHMFD = 0x0000_0002; + /// Support broadcasting fake RARP packet. + const RARP = 0x0000_0004; + /// Support sending reply messages for requests with NEED_REPLY flag set. + const REPLY_ACK = 0x0000_0008; + /// Support setting MTU for virtio-net devices. + const MTU = 0x0000_0010; + /// Allow the slave to send requests to the master by an optional communication channel. + const SLAVE_REQ = 0x0000_0020; + /// Support setting slave endian by SET_VRING_ENDIAN. + const CROSS_ENDIAN = 0x0000_0040; + /// Support crypto operations. + const CRYPTO_SESSION = 0x0000_0080; + /// Support sending userfault_fd from slaves to masters. + const PAGEFAULT = 0x0000_0100; + /// Support Virtio device configuration. 
+ const CONFIG = 0x0000_0200; + /// Allow the slave to send fds (at most 8 descriptors in each message) to the master. + const SLAVE_SEND_FD = 0x0000_0400; + /// Allow the slave to register a host notifier. + const HOST_NOTIFIER = 0x0000_0800; + /// Support inflight shmfd. + const INFLIGHT_SHMFD = 0x0000_1000; + /// Support resetting the device. + const RESET_DEVICE = 0x0000_2000; + /// Support inband notifications. + const INBAND_NOTIFICATIONS = 0x0000_4000; + /// Support configuring memory slots. + const CONFIGURE_MEM_SLOTS = 0x0000_8000; + /// Support reporting status. + const STATUS = 0x0001_0000; + } +} + +/// A generic message to encapsulate a 64-bit value. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserU64 { + /// The encapsulated 64-bit common value. + pub value: u64, +} + +impl VhostUserU64 { + /// Create a new instance. + pub fn new(value: u64) -> Self { + VhostUserU64 { value } + } +} + +impl VhostUserMsgValidator for VhostUserU64 {} + +/// Memory region descriptor for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserMemory { + /// Number of memory regions in the payload. + pub num_regions: u32, + /// Padding for alignment. + pub padding1: u32, +} + +impl VhostUserMemory { + /// Create a new instance. + pub fn new(cnt: u32) -> Self { + VhostUserMemory { + num_regions: cnt, + padding1: 0, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemory { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if self.padding1 != 0 { + return false; + } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 { + return false; + } + true + } +} + +/// Memory region descriptors as payload for the SET_MEM_TABLE request. +#[repr(packed)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserMemoryRegion { + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. 
+ pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, +} + +impl VhostUserMemoryRegion { + /// Create a new instance. + pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserMemoryRegion { + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + } + } +} + +impl VhostUserMsgValidator for VhostUserMemoryRegion { + fn is_valid(&self) -> bool { + if self.memory_size == 0 + || self.guest_phys_addr.checked_add(self.memory_size).is_none() + || self.user_addr.checked_add(self.memory_size).is_none() + || self.mmap_offset.checked_add(self.memory_size).is_none() + { + return false; + } + true + } +} + +/// Payload of the VhostUserMemory message. +pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>; + +/// Single memory region descriptor as payload for ADD_MEM_REG and REM_MEM_REG +/// requests. +#[repr(C)] +#[derive(Default, Clone, Copy)] +pub struct VhostUserSingleMemoryRegion { + /// Padding for correct alignment + padding: u64, + /// Guest physical address of the memory region. + pub guest_phys_addr: u64, + /// Size of the memory region. + pub memory_size: u64, + /// Virtual address in the current process. + pub user_addr: u64, + /// Offset where region starts in the mapped memory. + pub mmap_offset: u64, +} + +impl VhostUserSingleMemoryRegion { + /// Create a new instance. 
+ pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self { + VhostUserSingleMemoryRegion { + padding: 0, + guest_phys_addr, + memory_size, + user_addr, + mmap_offset, + } + } +} + +impl VhostUserMsgValidator for VhostUserSingleMemoryRegion { + fn is_valid(&self) -> bool { + if self.memory_size == 0 + || self.guest_phys_addr.checked_add(self.memory_size).is_none() + || self.user_addr.checked_add(self.memory_size).is_none() + || self.mmap_offset.checked_add(self.memory_size).is_none() + { + return false; + } + true + } +} + +/// Vring state descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringState { + /// Vring index. + pub index: u32, + /// A common 32bit value to encapsulate vring state etc. + pub num: u32, +} + +impl VhostUserVringState { + /// Create a new instance. + pub fn new(index: u32, num: u32) -> Self { + VhostUserVringState { index, num } + } +} + +impl VhostUserMsgValidator for VhostUserVringState {} + +// Bit mask for vring address flags. +bitflags! { + /// Flags for vring address. + pub struct VhostUserVringAddrFlags: u32 { + /// Support log of vring operations. + /// Modifications to "used" vring should be logged. + const VHOST_VRING_F_LOG = 0x1; + } +} + +/// Vring address descriptor. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserVringAddr { + /// Vring index. + pub index: u32, + /// Vring flags defined by VhostUserVringAddrFlags. + pub flags: u32, + /// Ring address of the vring descriptor table. + pub descriptor: u64, + /// Ring address of the vring used ring. + pub used: u64, + /// Ring address of the vring available ring. + pub available: u64, + /// Guest address for logging. + pub log: u64, +} + +impl VhostUserVringAddr { + /// Create a new instance. 
+ pub fn new( + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Self { + VhostUserVringAddr { + index, + flags: flags.bits(), + descriptor, + used, + available, + log, + } + } + + /// Create a new instance from `VringConfigData`. + #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))] + pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self { + let log_addr = config_data.log_addr.unwrap_or(0); + VhostUserVringAddr { + index, + flags: config_data.flags, + descriptor: config_data.desc_table_addr, + used: config_data.used_ring_addr, + available: config_data.avail_ring_addr, + log: log_addr, + } + } +} + +impl VhostUserMsgValidator for VhostUserVringAddr { + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 { + return false; + } else if self.descriptor & 0xf != 0 { + return false; + } else if self.available & 0x1 != 0 { + return false; + } else if self.used & 0x3 != 0 { + return false; + } + true + } +} + +// Bit mask for the vhost-user device configuration message. +bitflags! { + /// Flags for the device configuration message. + pub struct VhostUserConfigFlags: u32 { + /// Vhost master messages used for writeable fields. + const WRITABLE = 0x1; + /// Vhost master messages used for live migration. + const LIVE_MIGRATION = 0x2; + } +} + +/// Message to read/write device configuration space. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserConfig { + /// Offset of virtio device's configuration space. + pub offset: u32, + /// Configuration space access size in bytes. + pub size: u32, + /// Flags for the device configuration operation. + pub flags: u32, +} + +impl VhostUserConfig { + /// Create a new instance. 
+ pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self { + VhostUserConfig { + offset, + size, + flags: flags.bits(), + } + } +} + +impl VhostUserMsgValidator for VhostUserConfig { + /// Accept the message only when no unknown flag bit is set, `offset` is at least 0x100, + /// `size` is non-zero, and `offset + size` does not exceed `VHOST_USER_CONFIG_SIZE`. + #[allow(clippy::if_same_then_else)] + fn is_valid(&self) -> bool { + if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 { + return false; + } else if self.offset < 0x100 { + return false; + } else if self.size == 0 + || self.size > VHOST_USER_CONFIG_SIZE + // `offset` has no upper bound at this point: use checked_add so that a huge + // offset cannot wrap around and slip past the upper-bound check. + || self + .size + .checked_add(self.offset) + .map_or(true, |end| end > VHOST_USER_CONFIG_SIZE) + { + return false; + } + true + } +} + +/// Payload for the VhostUserConfig message. +pub type VhostUserConfigPayload = Vec<u8>; + +/* + * TODO: support dirty log, live migration and IOTLB operations. +#[repr(packed)] +pub struct VhostUserVringArea { + pub index: u32, + pub flags: u32, + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserLog { + pub size: u64, + pub offset: u64, +} + +#[repr(packed)] +pub struct VhostUserIotlb { + pub iova: u64, + pub size: u64, + pub user_addr: u64, + pub permission: u8, + pub optype: u8, +} +*/ + +// Bit mask for flags in virtio-fs slave messages +bitflags! { + #[derive(Default)] + /// Flags for virtio-fs slave messages. + pub struct VhostUserFSSlaveMsgFlags: u64 { + /// Empty permission. + const EMPTY = 0x0; + /// Read permission. + const MAP_R = 0x1; + /// Write permission. + const MAP_W = 0x2; + } +} + +/// Max entries in one virtio-fs slave request. +pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8; + +/// Slave request message to update the MMIO window. +#[repr(packed)] +#[derive(Default)] +pub struct VhostUserFSSlaveMsg { + /// File offset. + pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Offset into the DAX window. + pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Size of region to map.
+ pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES], + /// Flags for the mmap operation + pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES], +} + +impl VhostUserMsgValidator for VhostUserFSSlaveMsg { + fn is_valid(&self) -> bool { + for i in 0..VHOST_USER_FS_SLAVE_ENTRIES { + if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0 + || self.fd_offset[i].checked_add(self.len[i]).is_none() + || self.cache_offset[i].checked_add(self.len[i]).is_none() + { + return false; + } + } + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::mem; + + #[test] + fn check_master_request_code() { + let code = MasterReq::NOOP; + assert!(!code.is_valid()); + let code = MasterReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > MasterReq::NOOP); + let code = MasterReq::GET_FEATURES; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn check_slave_request_code() { + let code = SlaveReq::NOOP; + assert!(!code.is_valid()); + let code = SlaveReq::MAX_CMD; + assert!(!code.is_valid()); + assert!(code > SlaveReq::NOOP); + let code = SlaveReq::CONFIG_CHANGE_MSG; + assert!(code.is_valid()); + assert_eq!(code, code.clone()); + let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) }; + assert!(!code.is_valid()); + } + + #[test] + fn msg_header_ops() { + let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); + assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + hdr.set_code(MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + + assert_eq!(hdr.get_version(), 0x1); + + assert_eq!(hdr.is_reply(), false); + hdr.set_reply(true); + assert_eq!(hdr.is_reply(), true); + hdr.set_reply(false); + + assert_eq!(hdr.is_need_reply(), false); + hdr.set_need_reply(true); + assert_eq!(hdr.is_need_reply(), true); + hdr.set_need_reply(false); + + 
assert_eq!(hdr.get_size(), 0x100); + hdr.set_size(0x200); + assert_eq!(hdr.get_size(), 0x200); + + assert_eq!(hdr.is_need_reply(), false); + assert_eq!(hdr.is_reply(), false); + assert_eq!(hdr.get_version(), 0x1); + + // Check message length + assert!(hdr.is_valid()); + hdr.set_size(0x2000); + assert!(!hdr.is_valid()); + hdr.set_size(0x100); + assert_eq!(hdr.get_size(), 0x100); + assert!(hdr.is_valid()); + hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32); + assert!(hdr.is_valid()); + hdr.set_size(0x0); + assert!(hdr.is_valid()); + + // Check version + hdr.set_version(0x0); + assert!(!hdr.is_valid()); + hdr.set_version(0x2); + assert!(!hdr.is_valid()); + hdr.set_version(0x1); + assert!(hdr.is_valid()); + + assert_eq!(hdr, hdr.clone()); + } + + #[test] + fn test_vhost_user_message_u64() { + let val = VhostUserU64::default(); + let val1 = VhostUserU64::new(0); + + let a = val.value; + let b = val1.value; + assert_eq!(a, b); + let a = VhostUserU64::new(1).value; + assert_eq!(a, 1); + } + + #[test] + fn check_user_memory() { + let mut msg = VhostUserMemory::new(1); + assert!(msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + assert!(msg.is_valid()); + + msg.num_regions += 1; + assert!(!msg.is_valid()); + msg.num_regions = 0xFFFFFFFF; + assert!(!msg.is_valid()); + msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32; + msg.padding1 = 1; + assert!(!msg.is_valid()); + } + + #[test] + fn check_user_memory_region() { + let mut msg = VhostUserMemoryRegion { + guest_phys_addr: 0, + memory_size: 0x1000, + user_addr: 0, + mmap_offset: 0, + }; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF; + assert!(msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFFF000; + assert!(!msg.is_valid()); + msg.guest_phys_addr = 0xFFFFFFFFFFFF0000; + msg.memory_size = 0; + assert!(!msg.is_valid()); + let a = msg.guest_phys_addr; + let b = msg.guest_phys_addr; + assert_eq!(a, b); + + let msg = 
VhostUserMemoryRegion::default(); + let a = msg.guest_phys_addr; + assert_eq!(a, 0); + let a = msg.memory_size; + assert_eq!(a, 0); + let a = msg.user_addr; + assert_eq!(a, 0); + let a = msg.mmap_offset; + assert_eq!(a, 0); + } + + #[test] + fn test_vhost_user_state() { + let state = VhostUserVringState::new(5, 8); + + let a = state.index; + assert_eq!(a, 5); + let a = state.num; + assert_eq!(a, 8); + assert_eq!(state.is_valid(), true); + + let state = VhostUserVringState::default(); + let a = state.index; + assert_eq!(a, 0); + let a = state.num; + assert_eq!(a, 0); + assert_eq!(state.is_valid(), true); + } + + #[test] + fn test_vhost_user_addr() { + let mut addr = VhostUserVringAddr::new( + 2, + VhostUserVringAddrFlags::VHOST_VRING_F_LOG, + 0x1000, + 0x2000, + 0x3000, + 0x4000, + ); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + + addr.descriptor = 0x1001; + assert_eq!(addr.is_valid(), false); + addr.descriptor = 0x1000; + + addr.available = 0x3001; + assert_eq!(addr.is_valid(), false); + addr.available = 0x3000; + + addr.used = 0x2001; + assert_eq!(addr.is_valid(), false); + addr.used = 0x2000; + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn test_vhost_user_state_from_config() { + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits, + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + let addr = VhostUserVringAddr::from_config_data(2, &config); + + let a = addr.index; + assert_eq!(a, 2); + let a = addr.flags; + assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits()); + let a = addr.descriptor; + 
assert_eq!(a, 0x1000); + let a = addr.used; + assert_eq!(a, 0x2000); + let a = addr.available; + assert_eq!(a, 0x3000); + let a = addr.log; + assert_eq!(a, 0x4000); + assert_eq!(addr.is_valid(), true); + } + + #[test] + fn check_user_vring_addr() { + let mut msg = + VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0); + assert!(msg.is_valid()); + + msg.descriptor = 1; + assert!(!msg.is_valid()); + msg.descriptor = 0; + + msg.available = 1; + assert!(!msg.is_valid()); + msg.available = 0; + + msg.used = 1; + assert!(!msg.is_valid()); + msg.used = 0; + + msg.flags |= 0x80000000; + assert!(!msg.is_valid()); + msg.flags &= !0x80000000; + } + + #[test] + fn check_user_config_msg() { + let mut msg = VhostUserConfig::new( + VHOST_USER_CONFIG_OFFSET, + VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET, + VhostUserConfigFlags::WRITABLE, + ); + + assert!(msg.is_valid()); + msg.size = 0; + assert!(!msg.is_valid()); + msg.size = 1; + assert!(msg.is_valid()); + msg.offset = 0; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE; + assert!(!msg.is_valid()); + msg.offset = VHOST_USER_CONFIG_SIZE - 1; + assert!(msg.is_valid()); + msg.size = 2; + assert!(!msg.is_valid()); + msg.size = 1; + msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits(); + assert!(msg.is_valid()); + msg.flags |= 0x4; + assert!(!msg.is_valid()); + } + + #[test] + fn test_vhost_user_fs_slave() { + let mut fs_slave = VhostUserFSSlaveMsg::default(); + + assert_eq!(fs_slave.is_valid(), true); + + fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff; + fs_slave.len[0] = 0x1; + assert_eq!(fs_slave.is_valid(), false); + + assert_ne!( + VhostUserFSSlaveMsgFlags::MAP_R, + VhostUserFSSlaveMsgFlags::MAP_W + ); + assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0); + } +} diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs new file mode 100644 index 0000000..9ef6453 --- /dev/null +++ b/src/vhost_user/mod.rs @@ -0,0 +1,456 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. 
All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux +//! Kernel. The protocol defines two sides of the communication, master and slave. Master is +//! the application that shares its virtqueues. Slave is the consumer of the virtqueues. +//! +//! The communication channel between the master and the slave includes two sub channels. One is +//! used to send requests from the master to the slave and optional replies from the slave to the +//! master. This sub channel is created on master startup by connecting to the slave service +//! endpoint. The other is used to send requests from the slave to the master and optional replies +//! from the master to the slave. This sub channel is created by the master issuing a +//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor. +//! +//! Unix domain socket is used as the underlying communication channel because the master needs to +//! send file descriptors to the slave. +//! +//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an +//! equivalent ioctl to the kernel implementation. 
+ +use std::io::Error as IOError; + +pub mod message; + +mod connection; +pub use self::connection::Listener; + +#[cfg(feature = "vhost-user-master")] +mod master; +#[cfg(feature = "vhost-user-master")] +pub use self::master::{Master, VhostUserMaster}; +#[cfg(feature = "vhost-user")] +mod master_req_handler; +#[cfg(feature = "vhost-user")] +pub use self::master_req_handler::{ + MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut, +}; + +#[cfg(feature = "vhost-user-slave")] +mod slave; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave::SlaveListener; +#[cfg(feature = "vhost-user-slave")] +mod slave_req_handler; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_req_handler::{ + SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut, +}; +#[cfg(feature = "vhost-user-slave")] +mod slave_fs_cache; +#[cfg(feature = "vhost-user-slave")] +pub use self::slave_fs_cache::SlaveFsCacheReq; + +/// Errors for vhost-user operations +#[derive(Debug)] +pub enum Error { + /// Invalid parameters. + InvalidParam, + /// Unsupported operation because the protocol feature hasn't been negotiated. + InvalidOperation, + /// Invalid message format, flag or content. + InvalidMessage, + /// Only part of a message has been sent or received successfully + PartialMessage, + /// Message is too large + OversizedMsg, + /// Fd array in question is too big or too small + IncorrectFds, + /// Can't connect to peer. + SocketConnect(std::io::Error), + /// Generic socket errors. + SocketError(std::io::Error), + /// The socket is broken or has been closed. + SocketBroken(std::io::Error), + /// Should retry the socket operation again. + SocketRetry(std::io::Error), + /// Failure from the slave side. + SlaveInternalError, + /// Failure from the master side. + MasterInternalError, + /// Virtio/protocol features mismatch. 
+ FeatureMismatch, + /// Error from request handler + ReqHandlerError(IOError), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Error::InvalidParam => write!(f, "invalid parameters"), + Error::InvalidOperation => write!(f, "invalid operation"), + Error::InvalidMessage => write!(f, "invalid message"), + Error::PartialMessage => write!(f, "partial message"), + Error::OversizedMsg => write!(f, "oversized message"), + Error::IncorrectFds => write!(f, "wrong number of attached fds"), + Error::SocketError(e) => write!(f, "socket error: {}", e), + Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e), + Error::SocketBroken(e) => write!(f, "socket is broken: {}", e), + Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e), + Error::SlaveInternalError => write!(f, "slave internal error"), + Error::MasterInternalError => write!(f, "Master internal error"), + Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"), + Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e), + } + } +} + +impl std::error::Error for Error {} + +impl Error { + /// Determine whether to rebuild the underlying communication channel. + pub fn should_reconnect(&self) -> bool { + match *self { + // Should reconnect because it may be caused by temporary network errors. + Error::PartialMessage => true, + // Should reconnect because the underlying socket is broken. + Error::SocketBroken(_) => true, + // Slave internal error, hope it recovers on reconnect. + Error::SlaveInternalError => true, + // Master internal error, hope it recovers on reconnect. + Error::MasterInternalError => true, + // Should just retry the IO operation instead of rebuilding the underlying connection. 
+ Error::SocketRetry(_) => false, + Error::InvalidParam | Error::InvalidOperation => false, + Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false, + Error::SocketError(_) | Error::SocketConnect(_) => false, + Error::FeatureMismatch => false, + Error::ReqHandlerError(_) => false, + } + } +} + +impl std::convert::From<sys_util::Error> for Error { + /// Convert raw socket errors into meaningful vhost-user errors. + /// + /// The sys_util::Error is a simple wrapper over the raw errno, which doesn't mean + /// much to the vhost-user connection manager. So convert it into meaningful errors to simplify + /// the connection manager logic. + /// + /// # Return: + /// * - Error::SocketRetry: temporary error caused by signals or shortage of resources. + /// * - Error::SocketBroken: the underlying socket is broken. + /// * - Error::SocketError: other socket related errors. + #[allow(unreachable_patterns)] // EWOULDBLOCK equals EAGAIN on Linux + fn from(err: sys_util::Error) -> Self { + match err.errno() { + // The socket is marked nonblocking and the requested operation would block. + libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)), + // The socket is marked nonblocking and the requested operation would block. + libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)), + // A signal occurred before any data was transmitted + libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)), + // The output queue for a network interface was full. This generally indicates + // that the interface has stopped sending, but may be caused by transient congestion. + libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)), + // No memory available. + libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)), + // Connection reset by peer. 
+ libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)), + // The local end has been shut down on a connection oriented socket. In this case the + // process will also receive a SIGPIPE unless MSG_NOSIGNAL is set. + libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)), + // Write permission is denied on the destination socket file, or search permission is + // denied for one of the directories the path prefix. + libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)), + // Catch all other errors + e => Error::SocketError(IOError::from_raw_os_error(e)), + } + } +} + +/// Result of vhost-user operations +pub type Result<T> = std::result::Result<T, Error>; + +/// Result of request handler. +pub type HandlerResult<T> = std::result::Result<T, IOError>; + +#[cfg(all(test, feature = "vhost-user-slave"))] +mod dummy_slave; + +#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))] +mod tests { + use std::os::unix::io::AsRawFd; + use std::path::Path; + use std::sync::{Arc, Barrier, Mutex}; + use std::thread; + + use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES}; + use super::message::*; + use super::*; + use crate::backend::VhostBackend; + use crate::{VhostUserMemoryRegionInfo, VringConfigData}; + use tempfile::{tempfile, Builder, TempDir}; + + fn temp_dir() -> TempDir { + Builder::new().prefix("/tmp/vhost_test").tempdir().unwrap() + } + + fn create_slave<P, S>(path: P, backend: Arc<S>) -> (Master, SlaveReqHandler<S>) + where + P: AsRef<Path>, + S: VhostUserSlaveReqHandler, + { + let listener = Listener::new(&path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + let master = Master::connect(&path, 1).unwrap(); + (master, slave_listener.accept().unwrap().unwrap()) + } + + #[test] + fn create_dummy_slave() { + let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + + slave.set_owner().unwrap(); + 
assert!(slave.set_owner().is_err()); + } + + #[test] + fn test_set_owner() { + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let (master, mut slave) = create_slave(&path, slave_be.clone()); + + assert_eq!(slave_be.lock().unwrap().owned, false); + master.set_owner().unwrap(); + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + master.set_owner().unwrap(); + assert!(slave.handle_request().is_err()); + assert_eq!(slave_be.lock().unwrap().owned, true); + } + + #[test] + fn test_set_features() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let dir = temp_dir(); + let mut path = dir.path().to_owned(); + path.push("sock"); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_master_slave_process() { + let mbar = Arc::new(Barrier::new(2)); + let sbar = mbar.clone(); + let dir = temp_dir(); + let 
mut path = dir.path().to_owned(); + path.push("sock"); + let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let (mut master, mut slave) = create_slave(&path, slave_be.clone()); + + thread::spawn(move || { + // set_own() + slave.handle_request().unwrap(); + assert_eq!(slave_be.lock().unwrap().owned, true); + + // get/set_features() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_features, + VIRTIO_FEATURES & !0x1 + ); + + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + assert_eq!( + slave_be.lock().unwrap().acked_protocol_features, + VhostUserProtocolFeatures::all().bits() + ); + + // get_queue_num() + slave.handle_request().unwrap(); + + // set_mem_table() + slave.handle_request().unwrap(); + + // get/set_config() + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // set_slave_request_fd + slave.handle_request().unwrap(); + + // set_vring_enable + slave.handle_request().unwrap(); + + // set_log_base,set_log_fd() + slave.handle_request().unwrap_err(); + slave.handle_request().unwrap_err(); + + // set_vring_xxx + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + slave.handle_request().unwrap(); + + // get_max_mem_slots() + slave.handle_request().unwrap(); + + // add_mem_region() + slave.handle_request().unwrap(); + + // remove_mem_region() + slave.handle_request().unwrap(); + + sbar.wait(); + }); + + master.set_owner().unwrap(); + + // set virtio features + let features = master.get_features().unwrap(); + assert_eq!(features, VIRTIO_FEATURES); + master.set_features(VIRTIO_FEATURES & !0x1).unwrap(); + + // set vhost protocol features + let features = master.get_protocol_features().unwrap(); + assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits()); + master.set_protocol_features(features).unwrap(); + + 
let num = master.get_queue_num().unwrap(); + assert_eq!(num, 2); + + let eventfd = sys_util::EventFd::new().unwrap(); + let mem = [VhostUserMemoryRegionInfo { + guest_phys_addr: 0, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: eventfd.as_raw_fd(), + }]; + master.set_mem_table(&mem).unwrap(); + + master + .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8]) + .unwrap(); + let buf = [0x0u8; 4]; + let (reply_body, reply_payload) = master + .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf) + .unwrap(); + let offset = reply_body.offset; + assert_eq!(offset, 0x100); + assert_eq!(reply_payload[0], 0xa5); + + master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap(); + master.set_vring_enable(0, true).unwrap(); + + // unimplemented yet + master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap(); + master.set_log_fd(eventfd.as_raw_fd()).unwrap(); + + master.set_vring_num(0, 256).unwrap(); + master.set_vring_base(0, 0).unwrap(); + let config = VringConfigData { + queue_max_size: 256, + queue_size: 128, + flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(), + desc_table_addr: 0x1000, + used_ring_addr: 0x2000, + avail_ring_addr: 0x3000, + log_addr: Some(0x4000), + }; + master.set_vring_addr(0, &config).unwrap(); + master.set_vring_call(0, &eventfd).unwrap(); + master.set_vring_kick(0, &eventfd).unwrap(); + master.set_vring_err(0, &eventfd).unwrap(); + + let max_mem_slots = master.get_max_mem_slots().unwrap(); + assert_eq!(max_mem_slots, 32); + + let region_file = tempfile().unwrap(); + let region = VhostUserMemoryRegionInfo { + guest_phys_addr: 0x10_0000, + memory_size: 0x10_0000, + userspace_addr: 0, + mmap_offset: 0, + mmap_handle: region_file.as_raw_fd(), + }; + master.add_mem_region(®ion).unwrap(); + + master.remove_mem_region(®ion).unwrap(); + + mbar.wait(); + } + + #[test] + fn test_error_display() { + assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters"); + assert_eq!(format!("{}", 
Error::InvalidOperation), "invalid operation"); + } + + #[test] + fn test_should_reconnect() { + assert_eq!(Error::PartialMessage.should_reconnect(), true); + assert_eq!(Error::SlaveInternalError.should_reconnect(), true); + assert_eq!(Error::MasterInternalError.should_reconnect(), true); + assert_eq!(Error::InvalidParam.should_reconnect(), false); + assert_eq!(Error::InvalidOperation.should_reconnect(), false); + assert_eq!(Error::InvalidMessage.should_reconnect(), false); + assert_eq!(Error::IncorrectFds.should_reconnect(), false); + assert_eq!(Error::OversizedMsg.should_reconnect(), false); + assert_eq!(Error::FeatureMismatch.should_reconnect(), false); + } + + #[test] + fn test_error_from_sys_util_error() { + let e: Error = sys_util::Error::new(libc::EAGAIN).into(); + if let Error::SocketRetry(e1) = e { + assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN); + } else { + panic!("invalid error code conversion!"); + } + } +} diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs new file mode 100644 index 0000000..fb65c41 --- /dev/null +++ b/src/vhost_user/slave.rs @@ -0,0 +1,86 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Traits and Structs for vhost-user slave. + +use std::sync::Arc; + +use super::connection::{Endpoint, Listener}; +use super::message::*; +use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler}; + +/// Vhost-user slave side connection listener. +pub struct SlaveListener<S: VhostUserSlaveReqHandler> { + listener: Listener, + backend: Option<Arc<S>>, +} + +/// Sets up a listener for incoming master connections, and handles construction +/// of a Slave on success. +impl<S: VhostUserSlaveReqHandler> SlaveListener<S> { + /// Create a unix domain socket for incoming master connections. 
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> { + Ok(SlaveListener { + listener, + backend: Some(backend), + }) + } + + /// Accept an incoming connection from the master, returning Some(Slave) on + /// success, or None if the socket is nonblocking and no incoming connection + /// was detected + pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> { + if let Some(fd) = self.listener.accept()? { + return Ok(Some(SlaveReqHandler::new( + Endpoint::<MasterReq>::from_stream(fd), + self.backend.take().unwrap(), + ))); + } + Ok(None) + } + + /// Change blocking status on the listener. + pub fn set_nonblocking(&self, block: bool) -> Result<()> { + self.listener.set_nonblocking(block) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_listener_set_nonblocking() { + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = + Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap(); + let slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(false).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + slave_listener.set_nonblocking(true).unwrap(); + } + + #[cfg(feature = "vhost-user-master")] + #[test] + fn test_slave_listener_accept() { + use super::super::Master; + + let path = "/tmp/vhost_user_lib_unit_test_slave_accept"; + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let listener = Listener::new(path, true).unwrap(); + let mut slave_listener = SlaveListener::new(listener, backend).unwrap(); + + slave_listener.set_nonblocking(true).unwrap(); + assert!(slave_listener.accept().unwrap().is_none()); + assert!(slave_listener.accept().unwrap().is_none()); + + let _master = Master::connect(path, 1).unwrap(); + let _slave 
= slave_listener.accept().unwrap().unwrap(); + } +} diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs new file mode 100644 index 0000000..a9c4ed2 --- /dev/null +++ b/src/vhost_user/slave_fs_cache.rs @@ -0,0 +1,226 @@ +// Copyright (C) 2020 Alibaba Cloud. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::io; +use std::mem; +use std::os::unix::io::RawFd; +use std::os::unix::net::UnixStream; +use std::sync::{Arc, Mutex, MutexGuard}; + +use super::connection::Endpoint; +use super::message::*; +use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler}; + +struct SlaveFsCacheReqInternal { + sock: Endpoint<SlaveReq>, + + // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. + reply_ack_negotiated: bool, + + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl SlaveFsCacheReqInternal { + fn check_state(&self) -> Result<u64> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(0), + } + } + + fn send_message( + &mut self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> Result<u64> { + self.check_state()?; + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32); + if self.reply_ack_negotiated { + hdr.set_need_reply(true); + } + self.sock.send_message(&hdr, fs, fds)?; + + self.wait_for_ack(&hdr) + } + + fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> { + self.check_state()?; + if !self.reply_ack_negotiated { + return Ok(0); + } + + let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?; + if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() { + Endpoint::<SlaveReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + if body.value != 0 { + return Err(Error::MasterInternalError); + } + + Ok(body.value) + } +} + +/// Request proxy to send 
vhost-user-fs slave requests to the master through the slave +/// communication channel. +/// +/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the +/// master through the vhost-user slave communication channel. The forwarded messages will be +/// handled by the [MasterReqHandler] server. +/// +/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html +/// [MasterReqHandler]: struct.MasterReqHandler.html +#[derive(Clone)] +pub struct SlaveFsCacheReq { + // underlying Unix domain socket for communication + node: Arc<Mutex<SlaveFsCacheReqInternal>>, +} + +impl SlaveFsCacheReq { + fn new(ep: Endpoint<SlaveReq>) -> Self { + SlaveFsCacheReq { + node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { + sock: ep, + reply_ack_negotiated: false, + error: None, + })), + } + } + + fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> { + self.node.lock().unwrap() + } + + fn send_message( + &self, + request: SlaveReq, + fs: &VhostUserFSSlaveMsg, + fds: Option<&[RawFd]>, + ) -> io::Result<u64> { + self.node() + .send_message(request, fs, fds) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e))) + } + + /// Create a new instance from a `UnixStream` object. + pub fn from_stream(sock: UnixStream) -> Self { + Self::new(Endpoint::<SlaveReq>::from_stream(sock)) + } + + /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. + /// + /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, + /// the "REPLY_ACK" flag will be set in the message header for every slave to master request + /// message. + pub fn set_reply_ack_flag(&self, enable: bool) { + self.node().reply_ack_negotiated = enable; + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&self, error: i32) { + self.node().error = Some(error); + } +} + +impl VhostUserMasterReqHandler for SlaveFsCacheReq { + /// Forward vhost-user-fs map file requests to the master. 
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd])) + } + + /// Forward vhost-user-fs unmap file requests to the master. + fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { + self.send_message(SlaveReq::FS_UNMAP, fs, None) + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + + #[test] + fn test_slave_fs_cache_req_set_failed() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + assert!(fs_cache.node().error.is_none()); + fs_cache.set_failed(libc::EAGAIN); + assert_eq!(fs_cache.node().error, Some(libc::EAGAIN)); + } + + #[test] + fn test_slave_fs_cache_send_failure() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + + fs_cache.set_failed(libc::ECONNRESET); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + fs_cache.node().error = None; + + drop(p2); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + fs_cache + .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) + .unwrap_err(); + } + + #[test] + fn test_slave_fs_cache_recv_negative() { + let (p1, p2) = UnixStream::pair().unwrap(); + let fd = p2.as_raw_fd(); + let fs_cache = SlaveFsCacheReq::from_stream(p1); + let mut master = Endpoint::<SlaveReq>::from_stream(p2); + + let len = mem::size_of::<VhostUserFSSlaveMsg>(); + let mut hdr = VhostUserMsgHeader::new( + SlaveReq::FS_MAP, + VhostUserHeaderFlag::REPLY.bits(), + len as u32, + ); + let body = VhostUserU64::new(0); + + master.send_message(&hdr, &body, Some(&[fd])).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + + fs_cache.set_reply_ack_flag(true); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + 
.unwrap_err(); + + hdr.set_code(SlaveReq::FS_UNMAP); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + hdr.set_code(SlaveReq::FS_MAP); + + let body = VhostUserU64::new(1); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap_err(); + + let body = VhostUserU64::new(0); + master.send_message(&hdr, &body, None).unwrap(); + fs_cache + .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd) + .unwrap(); + } +} diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs new file mode 100644 index 0000000..18459a2 --- /dev/null +++ b/src/vhost_user/slave_req_handler.rs @@ -0,0 +1,828 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::mem; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; +use std::os::unix::net::UnixStream; +use std::slice; +use std::sync::{Arc, Mutex}; + +use super::connection::Endpoint; +use super::message::*; +use super::slave_fs_cache::SlaveFsCacheReq; +use super::{Error, Result}; + +/// Services provided to the master by the slave with interior mutability. +/// +/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. +/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], +/// but without interior mutability. +/// The vhost-user specification defines a master communication channel, by which masters could +/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by +/// slaves, and it's used both on the master side and slave side. +/// +/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy +/// service requests to slaves. 
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler +/// implementing [VhostUserSlaveReqHandler]. +/// +/// The [VhostUserSlaveReqHandler] trait is designed with interior mutability to improve performance +/// for multi-threading. +/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandler { + fn set_owner(&self) -> Result<()>; + fn reset_owner(&self) -> Result<()>; + fn get_features(&self) -> Result<u64>; + fn set_features(&self, features: u64) -> Result<()>; + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&self, features: u64) -> Result<()>; + fn get_queue_num(&self) -> Result<u64>; + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {} + fn get_max_mem_slots(&self) -> Result<u64>; + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + 
fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +/// Services provided to the master by the slave without interior mutability. +/// +/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. +#[allow(missing_docs)] +pub trait VhostUserSlaveReqHandlerMut { + fn set_owner(&mut self) -> Result<()>; + fn reset_owner(&mut self) -> Result<()>; + fn get_features(&mut self) -> Result<u64>; + fn set_features(&mut self, features: u64) -> Result<()>; + fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>; + fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; + fn set_vring_addr( + &mut self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()>; + fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; + fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; + fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>; + + fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; + fn set_protocol_features(&mut self, features: u64) -> Result<()>; + fn get_queue_num(&mut self) -> Result<u64>; + fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; + fn get_config( + &mut self, + offset: u32, + size: u32, + flags: VhostUserConfigFlags, + ) -> Result<Vec<u8>>; + fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; + fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {} + fn get_max_mem_slots(&mut self) -> Result<u64>; + fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()>; + fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; +} + +impl<T: 
VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { + fn set_owner(&self) -> Result<()> { + self.lock().unwrap().set_owner() + } + + fn reset_owner(&self) -> Result<()> { + self.lock().unwrap().reset_owner() + } + + fn get_features(&self) -> Result<u64> { + self.lock().unwrap().get_features() + } + + fn set_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_features(features) + } + + fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> { + self.lock().unwrap().set_mem_table(ctx, fds) + } + + fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { + self.lock().unwrap().set_vring_num(index, num) + } + + fn set_vring_addr( + &self, + index: u32, + flags: VhostUserVringAddrFlags, + descriptor: u64, + used: u64, + available: u64, + log: u64, + ) -> Result<()> { + self.lock() + .unwrap() + .set_vring_addr(index, flags, descriptor, used, available, log) + } + + fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { + self.lock().unwrap().set_vring_base(index, base) + } + + fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { + self.lock().unwrap().get_vring_base(index) + } + + fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_kick(index, fd) + } + + fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_call(index, fd) + } + + fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> { + self.lock().unwrap().set_vring_err(index, fd) + } + + fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { + self.lock().unwrap().get_protocol_features() + } + + fn set_protocol_features(&self, features: u64) -> Result<()> { + self.lock().unwrap().set_protocol_features(features) + } + + fn get_queue_num(&self) -> Result<u64> { + self.lock().unwrap().get_queue_num() + } + + fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { + 
self.lock().unwrap().set_vring_enable(index, enable) + } + + fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { + self.lock().unwrap().get_config(offset, size, flags) + } + + fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { + self.lock().unwrap().set_config(offset, buf, flags) + } + + fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) { + self.lock().unwrap().set_slave_req_fd(vu_req) + } + + fn get_max_mem_slots(&self) -> Result<u64> { + self.lock().unwrap().get_max_mem_slots() + } + + fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: RawFd) -> Result<()> { + self.lock().unwrap().add_mem_region(region, fd) + } + + fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> { + self.lock().unwrap().remove_mem_region(region) + } +} + +/// Server to handle service requests from masters from the master communication channel. +/// +/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from +/// masters on the master communication channel. It's actually a proxy invoking the registered +/// handler implementing [VhostUserSlaveReqHandler] to do the real work. +/// +/// The lifetime of the SlaveReqHandler object should be the same as the underlying Unix Domain +/// Socket, so it gets simpler to recover from disconnect. 
+/// +/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html +/// [SlaveReqHandler]: struct.SlaveReqHandler.html +pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { + // underlying Unix domain socket for communication + main_sock: Endpoint<MasterReq>, + // the vhost-user backend device object + backend: Arc<S>, + + virtio_features: u64, + acked_virtio_features: u64, + protocol_features: VhostUserProtocolFeatures, + acked_protocol_features: u64, + + // sending ack for messages without payload + reply_ack_enabled: bool, + // whether the endpoint has encountered any failure + error: Option<i32>, +} + +impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { + /// Create a vhost-user slave endpoint. + pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self { + SlaveReqHandler { + main_sock, + backend, + virtio_features: 0, + acked_virtio_features: 0, + protocol_features: VhostUserProtocolFeatures::empty(), + acked_protocol_features: 0, + reply_ack_enabled: false, + error: None, + } + } + + /// Create a new vhost-user slave endpoint. + /// + /// # Arguments + /// * - `path` - path of Unix domain socket listener to connect to + /// * - `backend` - handler for requests from the master to the slave + pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { + Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) + } + + /// Mark endpoint as failed with specified error code. + pub fn set_failed(&mut self, error: i32) { + self.error = Some(error); + } + + /// Main entrance to serve requests from the master communication channel. + /// + /// Receive and handle one incoming request message from the master. The caller needs to: + /// - serialize calls to this function + /// - decide what to do when error happens + /// - optionally recover from failure + pub fn handle_request(&mut self) -> Result<()> { + // Return error if the endpoint is already in failed state. 
+ self.check_state()?; + + // The underlying communication channel is a Unix domain socket in + // stream mode, and recvmsg() is a little tricky here. To successfully + // receive attached file descriptors, we need to receive messages and + // corresponding attached file descriptors in this way: + // . recv message header and optional attached file + // . validate message header + // . recv optional message body and payload according to the size field in + // message header + // . validate message body and optional payload + let (hdr, rfds) = self.main_sock.recv_header()?; + let rfds = self.check_attached_rfds(&hdr, rfds)?; + let (size, buf) = match hdr.get_size() { + 0 => (0, vec![0u8; 0]), + len => { + let (size2, rbuf) = self.main_sock.recv_data(len as usize)?; + if size2 != len as usize { + return Err(Error::InvalidMessage); + } + (size2, rbuf) + } + }; + + match hdr.get_code() { + MasterReq::SET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.set_owner()?; + } + MasterReq::RESET_OWNER => { + self.check_request_size(&hdr, size, 0)?; + self.backend.reset_owner()?; + } + MasterReq::GET_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_features()?; + let msg = VhostUserU64::new(features); + self.send_reply_message(&hdr, &msg)?; + self.virtio_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_FEATURES => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; + self.backend.set_features(msg.value)?; + self.acked_virtio_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::SET_MEM_TABLE => { + let res = self.set_mem_table(&hdr, size, &buf, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_NUM => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_num(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ADDR => { + let msg = 
self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; + let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.set_vring_addr( + msg.index, + flags, + msg.descriptor, + msg.used, + msg.available, + msg.log, + ); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let res = self.backend.set_vring_base(msg.index, msg.num); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_VRING_BASE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + let reply = self.backend.get_vring_base(msg.index)?; + self.send_reply_message(&hdr, &reply)?; + } + MasterReq::SET_VRING_CALL => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_call(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_KICK => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_kick(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::SET_VRING_ERR => { + self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; + let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?; + let res = self.backend.set_vring_err(index, rfds); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_PROTOCOL_FEATURES => { + self.check_request_size(&hdr, size, 0)?; + let features = self.backend.get_protocol_features()?; + let msg = VhostUserU64::new(features.bits()); + self.send_reply_message(&hdr, &msg)?; + self.protocol_features = features; + self.update_reply_ack_flag(); + } + MasterReq::SET_PROTOCOL_FEATURES => { + let msg = self.extract_request_body::<VhostUserU64>(&hdr, 
size, &buf)?; + self.backend.set_protocol_features(msg.value)?; + self.acked_protocol_features = msg.value; + self.update_reply_ack_flag(); + } + MasterReq::GET_QUEUE_NUM => { + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_queue_num()?; + let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + MasterReq::SET_VRING_ENABLE => { + let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; + if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 + && msg.index > 0 + { + return Err(Error::InvalidOperation); + } + let enable = match msg.num { + 1 => true, + 0 => false, + _ => return Err(Error::InvalidParam), + }; + + let res = self.backend.set_vring_enable(msg.index, enable); + self.send_ack_message(&hdr, res)?; + } + MasterReq::GET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.get_config(&hdr, &buf)?; + } + MasterReq::SET_CONFIG => { + if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.set_config(&hdr, size, &buf)?; + } + MasterReq::SET_SLAVE_REQ_FD => { + if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + self.set_slave_req_fd(&hdr, rfds)?; + } + MasterReq::GET_MAX_MEM_SLOTS => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + self.check_request_size(&hdr, size, 0)?; + let num = self.backend.get_max_mem_slots()?; + 
let msg = VhostUserU64::new(num); + self.send_reply_message(&hdr, &msg)?; + } + MasterReq::ADD_MEM_REG => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + let fd = if let Some(fds) = &rfds { + if fds.len() != 1 { + return Err(Error::InvalidParam); + } + fds[0] + } else { + return Err(Error::InvalidParam); + }; + + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.add_mem_region(&msg, fd); + self.send_ack_message(&hdr, res)?; + } + MasterReq::REM_MEM_REG => { + if self.acked_protocol_features + & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() + == 0 + { + return Err(Error::InvalidOperation); + } + + let msg = + self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; + let res = self.backend.remove_mem_region(&msg); + self.send_ack_message(&hdr, res)?; + } + _ => { + return Err(Error::InvalidMessage); + } + } + Ok(()) + } + + fn set_mem_table( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<()> { + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; + + // check message size is consistent + let hdrsize = mem::size_of::<VhostUserMemory>(); + if size < hdrsize { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; + if !msg.is_valid() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { + Endpoint::<MasterReq>::close_rfds(rfds); + return Err(Error::InvalidMessage); + } + + // validate number of fds matching number of memory regions + let fds = match rfds { + None => return Err(Error::InvalidMessage), + Some(fds) => { + if fds.len() != msg.num_regions as usize { + 
Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + fds + } + }; + + // Validate memory regions + let regions = unsafe { + slice::from_raw_parts( + buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, + msg.num_regions as usize, + ) + }; + for region in regions.iter() { + if !region.is_valid() { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + + self.backend.set_mem_table(&regions, &fds) + } + + fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { + let payload_offset = mem::size_of::<VhostUserConfig>(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if buf.len() - payload_offset != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags = match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => val, + None => return Err(Error::InvalidMessage), + }; + let res = self.backend.get_config(msg.offset, msg.size, flags); + + // vhost-user slave's payload size MUST match master's request + // on success, uses zero length of payload to indicate an error + // to vhost-user master. 
+ match res { + Ok(ref buf) if buf.len() == msg.size as usize => { + let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); + self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?; + } + Ok(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + Err(_) => { + let reply = VhostUserConfig::new(msg.offset, 0, flags); + self.send_reply_message(&hdr, &reply)?; + } + } + Ok(()) + } + + fn set_config( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + buf: &[u8], + ) -> Result<()> { + if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { + return Err(Error::InvalidMessage); + } + let flags: VhostUserConfigFlags; + match VhostUserConfigFlags::from_bits(msg.flags) { + Some(val) => flags = val, + None => return Err(Error::InvalidMessage), + } + + let res = self.backend.set_config(msg.offset, buf, flags); + self.send_ack_message(&hdr, res)?; + Ok(()) + } + + fn set_slave_req_fd( + &mut self, + hdr: &VhostUserMsgHeader<MasterReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<()> { + if let Some(fds) = rfds { + if fds.len() == 1 { + let sock = unsafe { UnixStream::from_raw_fd(fds[0]) }; + let vu_req = SlaveFsCacheReq::from_stream(sock); + self.backend.set_slave_req_fd(vu_req); + self.send_ack_message(&hdr, Ok(())) + } else { + Err(Error::InvalidMessage) + } + } else { + Err(Error::InvalidMessage) + } + } + + fn handle_vring_fd_request( + &mut self, + buf: &[u8], + rfds: Option<Vec<RawFd>>, + ) -> Result<(u8, Option<RawFd>)> { + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const 
VhostUserU64) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + + // Bits (0-7) of the payload contain the vring index. Bit 8 is the + // invalid FD flag. This flag is set when there is no file descriptor + // in the ancillary data. This signals that polling will be used + // instead of waiting for the call. + let nofd = (msg.value & 0x100u64) == 0x100u64; + + let mut rfd = None; + match rfds { + Some(fds) => { + if !nofd && fds.len() == 1 { + rfd = Some(fds[0]); + } else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) { + Endpoint::<MasterReq>::close_rfds(Some(fds)); + return Err(Error::InvalidMessage); + } + } + None => { + if !nofd { + return Err(Error::InvalidMessage); + } + } + } + Ok((msg.value as u8, rfd)) + } + + fn check_state(&self) -> Result<()> { + match self.error { + Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), + None => Ok(()), + } + } + + fn check_request_size( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + size: usize, + expected: usize, + ) -> Result<()> { + if hdr.get_size() as usize != expected + || hdr.is_reply() + || hdr.get_version() != 0x1 + || size != expected + { + return Err(Error::InvalidMessage); + } + Ok(()) + } + + fn check_attached_rfds( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, + rfds: Option<Vec<RawFd>>, + ) -> Result<Option<Vec<RawFd>>> { + match hdr.get_code() { + MasterReq::SET_MEM_TABLE => Ok(rfds), + MasterReq::SET_VRING_CALL => Ok(rfds), + MasterReq::SET_VRING_KICK => Ok(rfds), + MasterReq::SET_VRING_ERR => Ok(rfds), + MasterReq::SET_LOG_BASE => Ok(rfds), + MasterReq::SET_LOG_FD => Ok(rfds), + MasterReq::SET_SLAVE_REQ_FD => Ok(rfds), + MasterReq::SET_INFLIGHT_FD => Ok(rfds), + MasterReq::ADD_MEM_REG => Ok(rfds), + _ => { + if rfds.is_some() { + Endpoint::<MasterReq>::close_rfds(rfds); + Err(Error::InvalidMessage) + } else { + Ok(rfds) + } + } + } + } + + fn extract_request_body<T: Sized + VhostUserMsgValidator>( + &self, + hdr: &VhostUserMsgHeader<MasterReq>, 
+ size: usize, + buf: &[u8], + ) -> Result<T> { + self.check_request_size(hdr, size, mem::size_of::<T>())?; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; + if !msg.is_valid() { + return Err(Error::InvalidMessage); + } + Ok(msg) + } + + fn update_reply_ack_flag(&mut self) { + let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); + let pflag = VhostUserProtocolFeatures::REPLY_ACK; + if (self.virtio_features & vflag) != 0 + && (self.acked_virtio_features & vflag) != 0 + && self.protocol_features.contains(pflag) + && (self.acked_protocol_features & pflag.bits()) != 0 + { + self.reply_ack_enabled = true; + } else { + self.reply_ack_enabled = false; + } + } + + fn new_reply_header<T: Sized>( + &self, + req: &VhostUserMsgHeader<MasterReq>, + payload_size: usize, + ) -> Result<VhostUserMsgHeader<MasterReq>> { + if mem::size_of::<T>() > MAX_MSG_SIZE + || payload_size > MAX_MSG_SIZE + || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE + { + return Err(Error::InvalidParam); + } + self.check_state()?; + Ok(VhostUserMsgHeader::new( + req.get_code(), + VhostUserHeaderFlag::REPLY.bits(), + (mem::size_of::<T>() + payload_size) as u32, + )) + } + + fn send_ack_message( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + res: Result<()>, + ) -> Result<()> { + if self.reply_ack_enabled && req.is_need_reply() { + let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; + let val = match res { + Ok(_) => 0, + Err(_) => 1, + }; + let msg = VhostUserU64::new(val); + self.main_sock.send_message(&hdr, &msg, None)?; + } + Ok(()) + } + + fn send_reply_message<T>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, 0)?; + self.main_sock.send_message(&hdr, msg, None)?; + Ok(()) + } + + fn send_reply_with_payload<T: Sized>( + &mut self, + req: &VhostUserMsgHeader<MasterReq>, + msg: &T, + payload: &[u8], + ) -> Result<()> { + let hdr = self.new_reply_header::<T>(req, 
payload.len())?; + self.main_sock + .send_message_with_payload(&hdr, msg, payload, None)?; + Ok(()) + } +} + +impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { + fn as_raw_fd(&self) -> RawFd { + self.main_sock.as_raw_fd() + } +} + +#[cfg(test)] +mod tests { + use std::os::unix::io::AsRawFd; + + use super::*; + use crate::vhost_user::dummy_slave::DummySlaveReqHandler; + + #[test] + fn test_slave_req_handler_new() { + let (p1, _p2) = UnixStream::pair().unwrap(); + let endpoint = Endpoint::<MasterReq>::from_stream(p1); + let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); + let mut handler = SlaveReqHandler::new(endpoint, backend); + + handler.check_state().unwrap(); + handler.set_failed(libc::EAGAIN); + handler.check_state().unwrap_err(); + assert!(handler.as_raw_fd() >= 0); + } +} diff --git a/src/vsock.rs b/src/vsock.rs new file mode 100644 index 0000000..1e1b0b9 --- /dev/null +++ b/src/vsock.rs @@ -0,0 +1,30 @@ +// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause +// +// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// +// Portions Copyright 2017 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE-BSD-Google file. + +//! Trait to control vhost-vsock backend drivers. + +use crate::backend::VhostBackend; +use crate::Result; + +/// Trait to control vhost-vsock backend drivers. +pub trait VhostVsock: VhostBackend { + /// Set the CID for the guest. + /// This number is used for routing all data destined for running in the guest. + /// Each guest on a hypervisor must have an unique CID. + /// + /// # Arguments + /// * `cid` - CID to assign to the guest + fn set_guest_cid(&self, cid: u64) -> Result<()>; + + /// Tell the VHOST driver to start performing data transfer. 
+ fn start(&self) -> Result<()>; + + /// Tell the VHOST driver to stop performing data transfer. + fn stop(&self) -> Result<()>; +} |