summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCommit Bot <commit-bot@chromium.org>2021-03-05 02:30:36 +0000
committerCommit Bot <commit-bot@chromium.org>2021-03-05 02:30:36 +0000
commit0cf1160edd4f0617a17609f64384e5f8cce9c0f1 (patch)
tree1aaafea87f312ab78a9e7eaa18ca2c8af04d41be
parentc3e06b6aa3c76850b292189bf880b80d4959833e (diff)
parent73a86a346836d0329d2862387bc8a81be8a338ca (diff)
downloadvmm_vhost-0cf1160edd4f0617a17609f64384e5f8cce9c0f1.tar.gz
Merge "Merge remote-tracking branch 'cros/upstream/master' into HEAD" into main
-rw-r--r--.gitignore3
-rw-r--r--CODEOWNERS2
-rw-r--r--Cargo.toml13
-rw-r--r--README.md12
-rw-r--r--coverage_config_x86_64.json2
-rw-r--r--docs/vhost_architecture.drawio171
-rw-r--r--docs/vhost_architecture.pngbin0 -> 146074 bytes
m---------rust-vmm-ci0
-rw-r--r--src/backend.rs330
-rw-r--r--src/lib.rs70
-rw-r--r--src/vhost_kern/mod.rs80
-rw-r--r--src/vhost_kern/vhost_binding.rs1
-rw-r--r--src/vhost_kern/vsock.rs141
-rw-r--r--src/vhost_user/connection.rs56
-rw-r--r--src/vhost_user/dummy_slave.rs58
-rw-r--r--src/vhost_user/master.rs352
-rw-r--r--src/vhost_user/master_req_handler.rs285
-rw-r--r--src/vhost_user/message.rs162
-rw-r--r--src/vhost_user/mod.rs200
-rw-r--r--src/vhost_user/slave.rs46
-rw-r--r--src/vhost_user/slave_fs_cache.rs210
-rw-r--r--src/vhost_user/slave_req_handler.rs303
-rw-r--r--src/vhost_user/sock_ctrl_msg.rs499
-rw-r--r--src/vsock.rs6
24 files changed, 2104 insertions, 898 deletions
diff --git a/.gitignore b/.gitignore
index 6936990..f738aa8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+/build
+/kcov_build
/target
+.idea
**/*.rs.bk
Cargo.lock
diff --git a/CODEOWNERS b/CODEOWNERS
index 4d96c3f..7174a1b 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,2 +1,2 @@
# Add the list of code owners here (using their GitHub username)
-* gatekeeper-PullAssigner
+* gatekeeper-PullAssigner @jiangliu @eryugey @sboeuf @slp
diff --git a/Cargo.toml b/Cargo.toml
index 8c676b3..0cd15f7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,16 +1,22 @@
[package]
name = "vhost"
version = "0.1.0"
+keywords = ["vhost", "vhost-user", "virtio", "vdpa"]
+description = "a pure rust library for vdpa, vhost and vhost-user"
authors = ["Liu Jiang <gerry@linux.alibaba.com>"]
repository = "https://github.com/rust-vmm/vhost"
+documentation = "https://docs.rs/vhost"
+readme = "README.md"
license = "Apache-2.0 or BSD-3-Clause"
+edition = "2018"
[features]
default = []
vhost-vsock = []
vhost-kern = ["vm-memory"]
-vhost-user-master = []
-vhost-user-slave = []
+vhost-user = []
+vhost-user-master = ["vhost-user"]
+vhost-user-slave = ["vhost-user"]
[dependencies]
bitflags = ">=1.0.1"
@@ -18,3 +24,6 @@ libc = ">=0.2.39"
vmm-sys-util = ">=0.3.1"
vm-memory = { version = "0.2.0", optional = true }
+
+[dev-dependencies]
+vm-memory = { version = "0.2.0", features=["backend-mmap"] }
diff --git a/README.md b/README.md
index c1c2ab6..b0f4dfa 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,14 @@
# vHost
-A crate to support vhost backend drivers for virtio devices.
+A pure rust library for vDPA, vhost and vhost-user.
+The `vhost` crate aims to help implementing dataplane for virtio backend drivers. It supports three different types of dataplane drivers:
+- vhost: the dataplane is implemented by linux kernel
+- vhost-user: the dataplane is implemented by dedicated vhost-user servers
+- vDPA(vhost DataPath Accelerator): the dataplane is implemented by hardwares
+
+The main relationship among Traits and Structs exported by the `vhost` crate is as below:
+
+![vhost Architecture](/docs/vhost_architecture.png)
## Kernel-based vHost Backend Drivers
The vhost drivers in Linux provide in-kernel virtio device emulation. Normally
the hypervisor userspace process emulates I/O accesses from the guest.
@@ -11,7 +19,7 @@ The hypervisor relies on ioctl based interfaces to control those in-kernel
vhost drivers, such as vhost-net, vhost-scsi and vhost-vsock etc.
## vHost-user Backend Drivers
-The vhost-user protocol is aiming to implement vhost backend drivers in
+The [vhost-user protocol](https://qemu.readthedocs.io/en/latest/interop/vhost-user.html#communication) aims to implement vhost backend drivers in
userspace, which complements the ioctl interface used to control the vhost
implementation in the Linux kernel. It implements the control plane needed
to establish virtqueue sharing with a user space process on the same host.
diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json
index ec91006..a4ed64f 100644
--- a/coverage_config_x86_64.json
+++ b/coverage_config_x86_64.json
@@ -1 +1 @@
-{"coverage_score": 40.2, "exclude_path": "", "crate_features": "vhost-vsock,vhost-kern,vhost-user-master,vhost-user-slave"}
+{"coverage_score": 81.3, "exclude_path": "src/vhost_kern/", "crate_features": "vhost-user-master,vhost-user-slave"} \ No newline at end of file
diff --git a/docs/vhost_architecture.drawio b/docs/vhost_architecture.drawio
new file mode 100644
index 0000000..5008d28
--- /dev/null
+++ b/docs/vhost_architecture.drawio
@@ -0,0 +1,171 @@
+<mxfile host="65bd71144e" modified="2021-02-22T05:37:26.833Z" agent="5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Code/1.53.0 Chrome/87.0.4280.141 Electron/11.2.1 Safari/537.36" etag="HWRXqybJYJqQhnlJWfmB" version="14.2.4" type="embed">
+ <diagram id="xCgrIAQPDQM0eynUYBOE" name="Page-1">
+ <mxGraphModel dx="3446" dy="1284" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="850" pageHeight="1100" math="0" shadow="0">
+ <root>
+ <mxCell id="0"/>
+ <mxCell id="1" parent="0"/>
+ <mxCell id="46" value="&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;&lt;br&gt;" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="1620" y="27" width="450" height="990" as="geometry"/>
+ </mxCell>
+ <mxCell id="47" value="" style="shape=hexagon;perimeter=hexagonPerimeter2;whiteSpace=wrap;html=1;fixedSize=1;rounded=0;labelBackgroundColor=none;sketch=0;fillColor=none;fontSize=25;dashed=1;strokeWidth=6;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="790" y="237" width="1260" height="750" as="geometry"/>
+ </mxCell>
+ <mxCell id="44" value="" style="rounded=0;whiteSpace=wrap;html=1;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;fillColor=none;strokeColor=#4D4D4D;strokeWidth=5;" vertex="1" parent="1">
+ <mxGeometry x="-10" y="37" width="1250" height="670" as="geometry"/>
+ </mxCell>
+ <mxCell id="2" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;MasterReqHandler&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="830" y="477" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="4" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandler&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="840" y="597" width="360" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="6" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;entryX=1;entryY=0.5;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="2">
+ <mxGeometry relative="1" as="geometry">
+ <Array as="points">
+ <mxPoint x="1280" y="792"/>
+ <mxPoint x="1280" y="502"/>
+ </Array>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="5" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveFsCacheReq&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" parent="1" vertex="1">
+ <mxGeometry x="1715" y="767" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="7" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostUserMasterReqHandlerMut&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1630" y="657" width="390" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="8" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="2" target="4">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="950" y="657" as="sourcePoint"/>
+ <mxPoint x="680" y="717" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="10" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveListener&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="472" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="11" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;SlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1712" y="387" width="210" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="14" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostUserSlaveReqHandler&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1652" y="537" width="330" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="15" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=1;exitDx=0;exitDy=0;entryX=0.5;entryY=0;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="11" target="14">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1202" y="567" as="sourcePoint"/>
+ <mxPoint x="1202" y="667" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="16" value="&lt;pre style=&quot;font-size: 16.5pt; font-weight: 700; font-family: &amp;quot;jetbrains mono&amp;quot;, monospace;&quot;&gt;VhostBackend&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;fontColor=#00994D;strokeColor=#009900;" vertex="1" parent="1">
+ <mxGeometry x="390" y="197" width="250" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="17" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostKernBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="18" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot;, monospace; font-size: 16.5pt;&quot;&gt;VhostVdpaBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="19" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Master&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="820" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="20" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostSoftBackend&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="387" width="220" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="21" value="Handle virtque in VMM" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="10" y="557" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="23" value="Handle virtque in hardware" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#808080;strokeColor=#808080;" vertex="1" parent="1">
+ <mxGeometry x="270" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="24" value="Handle virtque in kernel" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="530" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="25" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#0000CC;" edge="1" parent="1" source="24" target="17">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="930" y="647" as="sourcePoint"/>
+ <mxPoint x="930" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="26" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0;entryY=0.5;entryDx=0;entryDy=0;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="11">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="840" y="917" as="sourcePoint"/>
+ <mxPoint x="840" y="1017" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="27" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="23" target="18">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="420" y="807" as="sourcePoint"/>
+ <mxPoint x="420" y="907" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="28" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;fontColor=#808080;strokeColor=#808080;" edge="1" parent="1" source="21" target="20">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="240" y="857" as="sourcePoint"/>
+ <mxPoint x="240" y="957" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="30" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="20" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="910" y="647" as="sourcePoint"/>
+ <mxPoint x="910" y="747" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="31" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;strokeColor=#00994D;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" source="18" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1000" y="177" as="sourcePoint"/>
+ <mxPoint x="530" y="227" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="32" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="17" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1010" y="127" as="sourcePoint"/>
+ <mxPoint x="1505" y="-73" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="35" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Endpoint&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="552" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="36" value="&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;Message&lt;/pre&gt;&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="1360" y="632" width="190" height="50" as="geometry"/>
+ </mxCell>
+ <mxCell id="37" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;quot;jetbrains mono&amp;quot; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;quot;jetbrains mono&amp;quot; , monospace ; font-size: 16.5pt&quot;&gt;VhostUserMaster&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="980" y="257" width="230" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="38" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#00994D;" edge="1" parent="1" source="19" target="16">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1030" y="527" as="sourcePoint"/>
+ <mxPoint x="515" y="257" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="39" value="Handle virtque in remote process" style="shape=process;whiteSpace=wrap;html=1;backgroundOutline=1;rounded=0;labelBackgroundColor=none;sketch=0;fontSize=25;fontColor=#FF00FF;strokeColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="850" y="807" width="220" height="120" as="geometry"/>
+ </mxCell>
+ <mxCell id="41" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;labelBackgroundColor=none;fontColor=#FF00FF;strokeColor=#FF00FF;" edge="1" parent="1" source="5" target="7">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1860" y="187" as="sourcePoint"/>
+ <mxPoint x="1860" y="267" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="43" style="edgeStyle=orthogonalEdgeStyle;rounded=0;orthogonalLoop=1;jettySize=auto;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;labelBackgroundColor=none;entryX=0.5;entryY=1;entryDx=0;entryDy=0;strokeColor=#FF00FF;" edge="1" parent="1" source="19" target="37">
+ <mxGeometry relative="1" as="geometry">
+ <mxPoint x="1430" y="187" as="sourcePoint"/>
+ <mxPoint x="2102" y="187" as="targetPoint"/>
+ </mxGeometry>
+ </mxCell>
+ <mxCell id="49" value="&lt;pre style=&quot;font-size: 16.5pt ; font-weight: 700 ; font-family: &amp;#34;jetbrains mono&amp;#34; , monospace&quot;&gt;&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Trait&lt;/pre&gt;&lt;/pre&gt;" style="rounded=1;whiteSpace=wrap;html=1;labelBackgroundColor=none;strokeColor=#FF33FF;fontColor=#FF00FF;" vertex="1" parent="1">
+ <mxGeometry x="60" y="1017" width="130" height="60" as="geometry"/>
+ </mxCell>
+ <mxCell id="51" value="Vhost-user protocol" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fontColor=#FF00FF;fillColor=none;strokeColor=none;" vertex="1" parent="1">
+ <mxGeometry x="1220" y="817" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="52" value="Vhost-user server" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="1680" y="57" width="330" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="53" value="VMM" style="rounded=1;whiteSpace=wrap;html=1;dashed=1;labelBackgroundColor=none;sketch=0;strokeWidth=5;fontSize=67;fillColor=none;strokeColor=none;fontColor=#4D4D4D;" vertex="1" parent="1">
+ <mxGeometry x="20" y="47" width="240" height="150" as="geometry"/>
+ </mxCell>
+ <mxCell id="54" value="&lt;pre style=&quot;font-family: &amp;#34;jetbrains mono&amp;#34; , monospace ; font-size: 16.5pt&quot;&gt;Struct&lt;/pre&gt;" style="rounded=0;whiteSpace=wrap;html=1;fontStyle=1;labelBackgroundColor=none;strokeColor=#0000CC;fontColor=#0000CC;" vertex="1" parent="1">
+ <mxGeometry x="240" y="1022" width="140" height="55" as="geometry"/>
+ </mxCell>
+ </root>
+ </mxGraphModel>
+ </diagram>
+</mxfile> \ No newline at end of file
diff --git a/docs/vhost_architecture.png b/docs/vhost_architecture.png
new file mode 100644
index 0000000..4d1e2bc
--- /dev/null
+++ b/docs/vhost_architecture.png
Binary files differ
diff --git a/rust-vmm-ci b/rust-vmm-ci
-Subproject e58ea7445ace0cb984f8002ba2436c34cf592ef
+Subproject ebc701641fa57f78d03f3f5ecac617b7bf7470b
diff --git a/src/backend.rs b/src/backend.rs
index 2d1a4a2..89fde50 100644
--- a/src/backend.rs
+++ b/src/backend.rs
@@ -1,4 +1,4 @@
-// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
@@ -9,14 +9,18 @@
//! Common traits and structs for vhost-kern and vhost-user backend drivers.
-use super::Result;
+use std::cell::RefCell;
use std::os::unix::io::RawFd;
+use std::sync::RwLock;
+
use vmm_sys_util::eventfd::EventFd;
+use super::Result;
+
/// Maximum number of memory regions supported.
pub const VHOST_MAX_MEMORY_REGIONS: usize = 255;
-/// Vring/virtque configuration data.
+/// Vring configuration data.
pub struct VringConfigData {
/// Maximum queue size supported by the driver.
pub queue_max_size: u16,
@@ -65,22 +69,109 @@ pub struct VhostUserMemoryRegionInfo {
pub userspace_addr: u64,
/// Optional offset where region starts in the mapped memory.
pub mmap_offset: u64,
- /// Optional file diescriptor for mmap
+ /// Optional file descriptor for mmap.
pub mmap_handle: RawFd,
}
-/// An interface for setting up vhost-based backend drivers.
+/// An interface for setting up vhost-based backend drivers with interior mutability.
///
/// Vhost devices are subset of virtio devices, which improve virtio device's performance by
/// delegating data plane operations to dedicated IO service processes. Vhost devices use the
/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to
/// virtio devices.
+///
/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the
/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service
/// processes, and slow path for device configuration are still handled by the VMM process. It may
/// also be used to control access permissions of virtio backend devices.
pub trait VhostBackend: std::marker::Sized {
/// Get a bitmask of supported virtio/vhost features.
+ fn get_features(&self) -> Result<u64>;
+
+ /// Inform the vhost subsystem which features to enable.
+ /// This should be a subset of supported features from get_features().
+ ///
+ /// # Arguments
+ /// * `features` - Bitmask of features to set.
+ fn set_features(&self, features: u64) -> Result<()>;
+
+ /// Set the current process as the owner of the vhost backend.
+ /// This must be run before any other vhost commands.
+ fn set_owner(&self) -> Result<()>;
+
+ /// Used to be sent to request disabling all rings
+ /// This is no longer used.
+ fn reset_owner(&self) -> Result<()>;
+
+ /// Set the guest memory mappings for vhost to use.
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>;
+
+ /// Set base address for page modification logging.
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()>;
+
+ /// Specify an eventfd file descriptor to signal on log write.
+ fn set_log_fd(&self, fd: RawFd) -> Result<()>;
+
+ /// Set the number of descriptors in the vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set descriptor count for.
+ /// * `num` - Number of descriptors in the queue.
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>;
+
+ /// Set the addresses for a given vring.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to set addresses for.
+ /// * `config_data` - Configuration data for a vring.
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()>;
+
+ /// Set the first index to look for available descriptors.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `num` - Index where available descriptors start.
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()>;
+
+ /// Get the available vring base offset.
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32>;
+
+ /// Set the eventfd to trigger when buffers have been used by the host.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd to trigger.
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when buffers are
+ /// available for the host to process.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+
+ /// Set the eventfd that will be signaled by the guest when error happens.
+ ///
+ /// # Arguments
+ /// * `queue_index` - Index of the queue to modify.
+ /// * `fd` - EventFd that will be signaled from guest.
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()>;
+}
+
+/// An interface for setting up vhost-based backend drivers.
+///
+/// Vhost devices are subset of virtio devices, which improve virtio device's performance by
+/// delegating data plane operations to dedicated IO service processes. Vhost devices use the
+/// same virtqueue layout as virtio devices to allow vhost devices to be mapped directly to
+/// virtio devices.
+///
+/// The purpose of vhost is to implement a subset of a virtio device's functionality outside the
+/// VMM process. Typically fast paths for IO operations are delegated to the dedicated IO service
+/// processes, and slow path for device configuration are still handled by the VMM process. It may
+/// also be used to control access permissions of virtio backend devices.
+pub trait VhostBackendMut: std::marker::Sized {
+ /// Get a bitmask of supported virtio/vhost features.
fn get_features(&mut self) -> Result<u64>;
/// Inform the vhost subsystem which features to enable.
@@ -154,9 +245,236 @@ pub trait VhostBackend: std::marker::Sized {
fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;
}
+impl<T: VhostBackendMut> VhostBackend for RwLock<T> {
+ fn get_features(&self) -> Result<u64> {
+ self.write().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.write().unwrap().set_features(features)
+ }
+
+ fn set_owner(&self) -> Result<()> {
+ self.write().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.write().unwrap().reset_owner()
+ }
+
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ self.write().unwrap().set_mem_table(regions)
+ }
+
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ self.write().unwrap().set_log_base(base, fd)
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ self.write().unwrap().set_log_fd(fd)
+ }
+
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ self.write().unwrap().set_vring_num(queue_index, num)
+ }
+
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ self.write()
+ .unwrap()
+ .set_vring_addr(queue_index, config_data)
+ }
+
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ self.write().unwrap().set_vring_base(queue_index, base)
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ self.write().unwrap().get_vring_base(queue_index)
+ }
+
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_call(queue_index, fd)
+ }
+
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_kick(queue_index, fd)
+ }
+
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.write().unwrap().set_vring_err(queue_index, fd)
+ }
+}
+
+impl<T: VhostBackendMut> VhostBackend for RefCell<T> {
+ fn get_features(&self) -> Result<u64> {
+ self.borrow_mut().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.borrow_mut().set_features(features)
+ }
+
+ fn set_owner(&self) -> Result<()> {
+ self.borrow_mut().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.borrow_mut().reset_owner()
+ }
+
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ self.borrow_mut().set_mem_table(regions)
+ }
+
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ self.borrow_mut().set_log_base(base, fd)
+ }
+
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ self.borrow_mut().set_log_fd(fd)
+ }
+
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ self.borrow_mut().set_vring_num(queue_index, num)
+ }
+
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ self.borrow_mut().set_vring_addr(queue_index, config_data)
+ }
+
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ self.borrow_mut().set_vring_base(queue_index, base)
+ }
+
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ self.borrow_mut().get_vring_base(queue_index)
+ }
+
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_call(queue_index, fd)
+ }
+
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_kick(queue_index, fd)
+ }
+
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ self.borrow_mut().set_vring_err(queue_index, fd)
+ }
+}
#[cfg(test)]
mod tests {
- use VringConfigData;
+ use super::*;
+
+ struct MockBackend {}
+
+ impl VhostBackendMut for MockBackend {
+ fn get_features(&mut self) -> Result<u64> {
+ Ok(0x1)
+ }
+
+ fn set_features(&mut self, features: u64) -> Result<()> {
+ assert_eq!(features, 0x1);
+ Ok(())
+ }
+
+ fn set_owner(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn reset_owner(&mut self) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_mem_table(&mut self, _regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ Ok(())
+ }
+
+ fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ assert_eq!(base, 0x100);
+ assert_eq!(fd, Some(100));
+ Ok(())
+ }
+
+ fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
+ assert_eq!(fd, 100);
+ Ok(())
+ }
+
+ fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ assert_eq!(num, 256);
+ Ok(())
+ }
+
+ fn set_vring_addr(
+ &mut self,
+ queue_index: usize,
+ _config_data: &VringConfigData,
+ ) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ assert_eq!(base, 2);
+ Ok(())
+ }
+
+ fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
+ assert_eq!(queue_index, 1);
+ Ok(2)
+ }
+
+ fn set_vring_call(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_kick(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+
+ fn set_vring_err(&mut self, queue_index: usize, _fd: &EventFd) -> Result<()> {
+ assert_eq!(queue_index, 1);
+ Ok(())
+ }
+ }
+
+ #[test]
+ fn test_vring_backend_mut() {
+ let b = RwLock::new(MockBackend {});
+
+ assert_eq!(b.get_features().unwrap(), 0x1);
+ b.set_features(0x1).unwrap();
+ b.set_owner().unwrap();
+ b.reset_owner().unwrap();
+ b.set_mem_table(&[]).unwrap();
+ b.set_log_base(0x100, Some(100)).unwrap();
+ b.set_log_fd(100).unwrap();
+ b.set_vring_num(1, 256).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 0x1000,
+ queue_size: 0x2000,
+ flags: 0x0,
+ desc_table_addr: 0x4000,
+ used_ring_addr: 0x5000,
+ avail_ring_addr: 0x6000,
+ log_addr: None,
+ };
+ b.set_vring_addr(1, &config).unwrap();
+
+ b.set_vring_base(1, 2).unwrap();
+ assert_eq!(b.get_vring_base(1).unwrap(), 2);
+
+ let eventfd = EventFd::new(0).unwrap();
+ b.set_vring_call(1, &eventfd).unwrap();
+ b.set_vring_kick(1, &eventfd).unwrap();
+ b.set_vring_err(1, &eventfd).unwrap();
+ }
#[test]
fn test_vring_config_data() {
diff --git a/src/lib.rs b/src/lib.rs
index e0cb2b8..b7ed15c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//! Virtio Vhost Backend Drivers
@@ -32,14 +32,8 @@
#![deny(missing_docs)]
-#[cfg_attr(
- any(feature = "vhost-user-master", feature = "vhost-user-slave"),
- macro_use
-)]
+#[cfg_attr(feature = "vhost-user", macro_use)]
extern crate bitflags;
-extern crate libc;
-#[cfg(feature = "vhost-kern")]
-extern crate vm_memory;
#[cfg_attr(feature = "vhost-kern", macro_use)]
extern crate vmm_sys_util;
@@ -48,7 +42,7 @@ pub use backend::*;
#[cfg(feature = "vhost-kern")]
pub mod vhost_kern;
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(feature = "vhost-user")]
pub mod vhost_user;
#[cfg(feature = "vhost-vsock")]
pub mod vsock;
@@ -80,7 +74,7 @@ pub enum Error {
IoctlError(std::io::Error),
/// Error from IO subsystem.
IOError(std::io::Error),
- #[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+ #[cfg(feature = "vhost-user")]
/// Error from the vhost-user subsystem.
VhostUserProtocol(vhost_user::Error),
}
@@ -94,20 +88,22 @@ impl std::fmt::Display for Error {
Error::InvalidQueue => write!(f, "invalid virtque"),
Error::DescriptorTableAddress => write!(f, "invalid virtque descriptor talbe address"),
Error::UsedAddress => write!(f, "invalid virtque used talbe address"),
- Error::AvailAddress => write!(f, "invalid virtque available talbe address"),
+ Error::AvailAddress => write!(f, "invalid virtque available table address"),
Error::LogAddress => write!(f, "invalid virtque log address"),
Error::IOError(e) => write!(f, "IO error: {}", e),
#[cfg(feature = "vhost-kern")]
Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e),
#[cfg(feature = "vhost-kern")]
Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e),
- #[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+ #[cfg(feature = "vhost-user")]
Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e),
}
}
}
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+impl std::error::Error for Error {}
+
+#[cfg(feature = "vhost-user")]
impl std::convert::From<vhost_user::Error> for Error {
fn from(err: vhost_user::Error) -> Self {
Error::VhostUserProtocol(err)
@@ -116,3 +112,51 @@ impl std::convert::From<vhost_user::Error> for Error {
/// Result of vhost operations
pub type Result<T> = std::result::Result<T, Error>;
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_error() {
+ assert_eq!(
+ format!("{}", Error::AvailAddress),
+ "invalid virtque available table address"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidOperation),
+ "invalid vhost operations"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidGuestMemory),
+ "invalid guest memory object"
+ );
+ assert_eq!(
+ format!("{}", Error::InvalidGuestMemoryRegion),
+ "invalid guest memory region"
+ );
+ assert_eq!(format!("{}", Error::InvalidQueue), "invalid virtque");
+ assert_eq!(
+ format!("{}", Error::DescriptorTableAddress),
+ "invalid virtque descriptor talbe address"
+ );
+ assert_eq!(
+ format!("{}", Error::UsedAddress),
+ "invalid virtque used talbe address"
+ );
+ assert_eq!(
+ format!("{}", Error::LogAddress),
+ "invalid virtque log address"
+ );
+
+ assert_eq!(format!("{:?}", Error::AvailAddress), "AvailAddress");
+ }
+
+ #[cfg(feature = "vhost-user")]
+ #[test]
+ fn test_convert_from_vhost_user_error() {
+ let e: Error = vhost_user::Error::OversizedMsg.into();
+
+ assert_eq!(format!("{}", e), "vhost-user: oversized message");
+ }
+}
diff --git a/src/vhost_kern/mod.rs b/src/vhost_kern/mod.rs
index 350e134..f82cbfc 100644
--- a/src/vhost_kern/mod.rs
+++ b/src/vhost_kern/mod.rs
@@ -13,7 +13,7 @@
use std::os::unix::io::{AsRawFd, RawFd};
-use vm_memory::GuestAddressSpace;
+use vm_memory::{Address, GuestAddress, GuestAddressSpace, GuestMemory, GuestUsize};
use vmm_sys_util::eventfd::EventFd;
use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
@@ -39,7 +39,7 @@ fn ioctl_result<T>(rc: i32, res: T) -> Result<T> {
/// Represent an in-kernel vhost device backend.
pub trait VhostKernBackend: AsRawFd {
- /// Assoicated type to access guest memory.
+ /// Associated type to access guest memory.
type AS: GuestAddressSpace;
/// Get the object to access the guest's memory.
@@ -55,52 +55,36 @@ pub trait VhostKernBackend: AsRawFd {
return false;
}
- // TODO: the GuestMemory trait lacks of method to look up GPA by HVA,
- // so there's no way to validate HVAs. Please extend vm-memory crate
- // first.
- /*
+ let m = self.mem().memory();
let desc_table_size = 16 * u64::from(queue_size) as GuestUsize;
let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize;
let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize;
if GuestAddress(config_data.desc_table_addr)
.checked_add(desc_table_size)
- .map_or(true, |v| !self.mem().address_in_range(v))
+ .map_or(true, |v| !m.address_in_range(v))
{
- false
- } else if GuestAddress(config_data.avail_ring_addr)
+ return false;
+ }
+ if GuestAddress(config_data.avail_ring_addr)
.checked_add(avail_ring_size)
- .map_or(true, |v| !self.mem().address_in_range(v))
+ .map_or(true, |v| !m.address_in_range(v))
{
- false
- } else if GuestAddress(config_data.used_ring_addr)
+ return false;
+ }
+ if GuestAddress(config_data.used_ring_addr)
.checked_add(used_ring_size)
- .map_or(true, |v| !self.mem().address_in_range(v))
+ .map_or(true, |v| !m.address_in_range(v))
{
- false
+ return false;
}
- */
config_data.is_log_addr_valid()
}
}
impl<T: VhostKernBackend> VhostBackend for T {
- /// Set the current process as the owner of this file descriptor.
- /// This must be run before any other vhost ioctls.
- fn set_owner(&mut self) -> Result<()> {
- // This ioctl is called on a valid vhost fd and has its return value checked.
- let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) };
- ioctl_result(ret, ())
- }
-
- fn reset_owner(&mut self) -> Result<()> {
- // This ioctl is called on a valid vhost fd and has its return value checked.
- let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) };
- ioctl_result(ret, ())
- }
-
/// Get a bitmask of supported virtio/vhost features.
- fn get_features(&mut self) -> Result<u64> {
+ fn get_features(&self) -> Result<u64> {
let mut avail_features: u64 = 0;
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) };
@@ -112,14 +96,28 @@ impl<T: VhostKernBackend> VhostBackend for T {
///
/// # Arguments
/// * `features` - Bitmask of features to set.
- fn set_features(&mut self, features: u64) -> Result<()> {
+ fn set_features(&self, features: u64) -> Result<()> {
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) };
ioctl_result(ret, ())
}
+ /// Set the current process as the owner of this file descriptor.
+ /// This must be run before any other vhost ioctls.
+ fn set_owner(&self) -> Result<()> {
+ // This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) };
+ ioctl_result(ret, ())
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ // This ioctl is called on a valid vhost fd and has its return value checked.
+ let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) };
+ ioctl_result(ret, ())
+ }
+
/// Set the guest memory mappings for vhost to use.
- fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS {
return Err(Error::InvalidGuestMemory);
}
@@ -148,7 +146,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
///
/// # Arguments
/// * `base` - Base address for page modification logging.
- fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
if fd.is_some() {
return Err(Error::LogAddress);
}
@@ -159,7 +157,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
}
/// Specify an eventfd file descriptor to signal on log write.
- fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
// This ioctl is called on a valid vhost fd and has its return value checked.
let val: i32 = fd;
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) };
@@ -171,7 +169,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to set descriptor count for.
/// * `num` - Number of descriptors in the queue.
- fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
let vring_state = vhost_vring_state {
index: queue_index as u32,
num: u32::from(num),
@@ -187,7 +185,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to set addresses for.
/// * `config_data` - Vring config data.
- fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
if !self.is_valid(config_data) {
return Err(Error::InvalidQueue);
}
@@ -212,7 +210,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `num` - Index where available descriptors start.
- fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
let vring_state = vhost_vring_state {
index: queue_index as u32,
num: u32::from(base),
@@ -224,7 +222,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
}
/// Get a bitmask of supported virtio/vhost features.
- fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
let vring_state = vhost_vring_state {
index: queue_index as u32,
num: 0,
@@ -239,7 +237,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd to trigger.
- fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
@@ -256,7 +254,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd that will be signaled from guest.
- fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
@@ -272,7 +270,7 @@ impl<T: VhostKernBackend> VhostBackend for T {
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd that will be signaled from the backend.
- fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
diff --git a/src/vhost_kern/vhost_binding.rs b/src/vhost_kern/vhost_binding.rs
index fdc5225..57ae698 100644
--- a/src/vhost_kern/vhost_binding.rs
+++ b/src/vhost_kern/vhost_binding.rs
@@ -13,6 +13,7 @@
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(missing_docs)]
+#![allow(clippy::missing_safety_doc)]
use crate::{Error, Result};
use std::os::raw;
diff --git a/src/vhost_kern/vsock.rs b/src/vhost_kern/vsock.rs
index c4149bd..65f89e4 100644
--- a/src/vhost_kern/vsock.rs
+++ b/src/vhost_kern/vsock.rs
@@ -1,22 +1,23 @@
-// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2019 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD-Google file.
-//! Kernel-based vsock vhost backend.
+//! Kernel-based vhost-vsock backend.
use std::fs::{File, OpenOptions};
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::{AsRawFd, RawFd};
-use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
-use super::{ioctl_result, Error, Result, VhostKernBackend};
-use libc;
use vm_memory::GuestAddressSpace;
use vmm_sys_util::ioctl::ioctl_with_ref;
+use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
+use super::{ioctl_result, Error, Result, VhostKernBackend};
+use crate::vsock::VhostVsock;
+
const VHOST_PATH: &str = "/dev/vhost-vsock";
/// Handle for running VHOST_VSOCK ioctls.
@@ -39,31 +40,26 @@ impl<AS: GuestAddressSpace> Vsock<AS> {
})
}
- /// Set the CID for the guest. This number is used for routing all data destined for
- /// running in the guest. Each guest on a hypervisor must have an unique CID
- ///
- /// # Arguments
- /// * `cid` - CID to assign to the guest
- pub fn set_guest_cid(&self, cid: u64) -> Result<()> {
+ fn set_running(&self, running: bool) -> Result<()> {
+ let on: ::std::os::raw::c_int = if running { 1 } else { 0 };
+ let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) };
+ ioctl_result(ret, ())
+ }
+}
+
+impl<AS: GuestAddressSpace> VhostVsock for Vsock<AS> {
+ fn set_guest_cid(&self, cid: u64) -> Result<()> {
let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) };
ioctl_result(ret, ())
}
- /// Tell the VHOST driver to start performing data transfer.
- pub fn start(&self) -> Result<()> {
+ fn start(&self) -> Result<()> {
self.set_running(true)
}
- /// Tell the VHOST driver to stop performing data transfer.
- pub fn stop(&self) -> Result<()> {
+ fn stop(&self) -> Result<()> {
self.set_running(false)
}
-
- fn set_running(&self, running: bool) -> Result<()> {
- let on: ::std::os::raw::c_int = if running { 1 } else { 0 };
- let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) };
- ioctl_result(ret, ())
- }
}
impl<AS: GuestAddressSpace> VhostKernBackend for Vsock<AS> {
@@ -79,3 +75,106 @@ impl<AS: GuestAddressSpace> AsRawFd for Vsock<AS> {
self.fd.as_raw_fd()
}
}
+
+#[cfg(test)]
+mod tests {
+ use vm_memory::{GuestAddress, GuestMemory, GuestMemoryMmap};
+ use vmm_sys_util::eventfd::EventFd;
+
+ use super::*;
+ use crate::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
+
+ #[test]
+ fn test_vsock_new_device() {
+ let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ assert!(vsock.as_raw_fd() >= 0);
+ assert!(vsock.mem().find_region(GuestAddress(0x100)).is_some());
+ assert!(vsock.mem().find_region(GuestAddress(0x10_0000)).is_none());
+ }
+
+ #[test]
+ fn test_vsock_is_valid() {
+ let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ let mut config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ assert_eq!(vsock.is_valid(&config), true);
+
+ config.queue_size = 0;
+ assert_eq!(vsock.is_valid(&config), false);
+ config.queue_size = 31;
+ assert_eq!(vsock.is_valid(&config), false);
+ config.queue_size = 33;
+ assert_eq!(vsock.is_valid(&config), false);
+ }
+
+ #[test]
+ fn test_vsock_ioctls() {
+ let m = GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10_0000)]).unwrap();
+ let vsock = Vsock::new(&m).unwrap();
+
+ let features = vsock.get_features().unwrap();
+ vsock.set_features(features).unwrap();
+
+ vsock.set_owner().unwrap();
+
+ vsock.set_mem_table(&[]).unwrap_err();
+
+ /*
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0x0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: -1,
+ };
+ vsock.set_mem_table(&[region]).unwrap_err();
+ */
+
+ let region = VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0x0,
+ memory_size: 0x10_0000,
+ userspace_addr: m.get_host_address(GuestAddress(0x0)).unwrap() as u64,
+ mmap_offset: 0,
+ mmap_handle: -1,
+ };
+ vsock.set_mem_table(&[region]).unwrap();
+
+ vsock.set_log_base(0x4000, Some(1)).unwrap_err();
+ vsock.set_log_base(0x4000, None).unwrap();
+
+ let eventfd = EventFd::new(0).unwrap();
+ vsock.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ vsock.set_vring_num(0, 32).unwrap();
+
+ let config = VringConfigData {
+ queue_max_size: 32,
+ queue_size: 32,
+ flags: 0,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: None,
+ };
+ vsock.set_vring_addr(0, &config).unwrap();
+ vsock.set_vring_base(0, 1).unwrap();
+ vsock.set_vring_call(0, &eventfd).unwrap();
+ vsock.set_vring_kick(0, &eventfd).unwrap();
+ vsock.set_vring_err(0, &eventfd).unwrap();
+ assert_eq!(vsock.get_vring_base(0).unwrap(), 1);
+ vsock.set_guest_cid(0xdead).unwrap();
+ //vsock.start().unwrap();
+ //vsock.stop().unwrap();
+ }
+}
diff --git a/src/vhost_user/connection.rs b/src/vhost_user/connection.rs
index deafdeb..01bf124 100644
--- a/src/vhost_user/connection.rs
+++ b/src/vhost_user/connection.rs
@@ -5,15 +5,16 @@
#![allow(dead_code)]
-use libc::{c_void, iovec};
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::{mem, slice};
+use libc::{c_void, iovec};
+use vmm_sys_util::sock_ctrl_msg::ScmSocket;
+
use super::message::*;
-use super::sock_ctrl_msg::ScmSocket;
use super::{Error, Result};
/// Unix domain socket listener for accepting incoming connections.
@@ -215,6 +216,9 @@ impl<R: Req> Endpoint<R> {
body: &T,
fds: Option<&[RawFd]>,
) -> Result<()> {
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
// Safe because there can't be other mutable referance to hdr and body.
let iovs = unsafe {
[
@@ -243,14 +247,17 @@ impl<R: Req> Endpoint<R> {
/// * - OversizedMsg: message size is too big.
/// * - PartialMessage: received a partial message.
/// * - IncorrectFds: wrong number of attached fds.
- pub fn send_message_with_payload<T: Sized, P: Sized>(
+ pub fn send_message_with_payload<T: Sized>(
&mut self,
hdr: &VhostUserMsgHeader<R>,
body: &T,
- payload: &[P],
+ payload: &[u8],
fds: Option<&[RawFd]>,
) -> Result<()> {
- let len = payload.len() * mem::size_of::<P>();
+ let len = payload.len();
+ if mem::size_of::<T>() > MAX_MSG_SIZE {
+ return Err(Error::OversizedMsg);
+ }
if len > MAX_MSG_SIZE - mem::size_of::<T>() {
return Err(Error::OversizedMsg);
}
@@ -599,27 +606,32 @@ fn get_sub_iovs_offset(iov_lens: &[usize], skip_size: usize) -> (usize, usize) {
#[cfg(test)]
mod tests {
-
use super::*;
use std::fs::File;
use std::io::{Read, Seek, SeekFrom, Write};
use std::os::unix::io::FromRawFd;
+ use vmm_sys_util::rand::rand_alphanumerics;
use vmm_sys_util::tempfile::TempFile;
- const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener";
- const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection";
- const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data";
- const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd";
- const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send";
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
#[test]
fn create_listener() {
- let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
+
+ assert!(listener.as_raw_fd() > 0);
}
#[test]
fn accept_connection() {
- let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
// accept on a fd without incoming connection
@@ -628,11 +640,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_data() {
- let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -654,11 +666,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_fd() {
- let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
@@ -808,11 +820,11 @@ mod tests {
}
#[test]
- #[ignore]
fn send_recv() {
- let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap();
+ let mut master = Endpoint::<MasterReq>::connect(&path).unwrap();
let sock = listener.accept().unwrap().unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(sock);
diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs
index 53887e2..9eedcbb 100644
--- a/src/vhost_user/dummy_slave.rs
+++ b/src/vhost_user/dummy_slave.rs
@@ -1,9 +1,10 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
+use std::os::unix::io::RawFd;
+
use super::message::*;
use super::*;
-use std::os::unix::io::RawFd;
pub const MAX_QUEUE_NUM: usize = 2;
pub const MAX_VRING_NUM: usize = 256;
@@ -34,7 +35,7 @@ impl DummySlaveReqHandler {
}
}
-impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
+impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler {
fn set_owner(&mut self) -> Result<()> {
if self.owned {
return Err(Error::InvalidOperation);
@@ -56,9 +57,7 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
}
fn set_features(&mut self, features: u64) -> Result<()> {
- if !self.owned {
- return Err(Error::InvalidOperation);
- } else if self.features_acked {
+ if !self.owned || self.features_acked {
return Err(Error::InvalidOperation);
} else if (features & !VIRTIO_FEATURES) != 0 {
return Err(Error::InvalidParam);
@@ -83,30 +82,10 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
Ok(())
}
- fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
- Ok(VhostUserProtocolFeatures::all())
- }
-
- fn set_protocol_features(&mut self, features: u64) -> Result<()> {
- // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
- // support this message even before VHOST_USER_SET_FEATURES was
- // called.
- // What happens if the master calls set_features() with
- // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
- // interface?
- self.acked_protocol_features = features;
- Ok(())
- }
-
fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
- // TODO
Ok(())
}
- fn get_queue_num(&mut self) -> Result<u64> {
- Ok(MAX_QUEUE_NUM as u64)
- }
-
fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
return Err(Error::InvalidParam);
@@ -199,6 +178,25 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
Ok(())
}
+ fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
+ Ok(VhostUserProtocolFeatures::all())
+ }
+
+ fn set_protocol_features(&mut self, features: u64) -> Result<()> {
+ // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
+ // support this message even before VHOST_USER_SET_FEATURES was
+ // called.
+ // What happens if the master calls set_features() with
+ // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
+ // interface?
+ self.acked_protocol_features = features;
+ Ok(())
+ }
+
+ fn get_queue_num(&mut self) -> Result<u64> {
+ Ok(MAX_QUEUE_NUM as u64)
+ }
+
fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
// This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
// has been negotiated.
@@ -222,10 +220,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
size: u32,
_flags: VhostUserConfigFlags,
) -> Result<Vec<u8>> {
- if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
- } else if offset < VHOST_USER_CONFIG_OFFSET
- || offset >= VHOST_USER_CONFIG_SIZE
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|| size + offset > VHOST_USER_CONFIG_SIZE
{
@@ -236,10 +233,9 @@ impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
let size = buf.len() as u32;
- if self.acked_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
+ if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
- } else if offset < VHOST_USER_CONFIG_OFFSET
- || offset >= VHOST_USER_CONFIG_SIZE
+ } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset)
|| size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
|| size + offset > VHOST_USER_CONFIG_SIZE
{
diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs
index ffed909..35ca471 100644
--- a/src/vhost_user/master.rs
+++ b/src/vhost_user/master.rs
@@ -6,7 +6,7 @@
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, MutexGuard};
use vmm_sys_util::eventfd::EventFd;
@@ -78,6 +78,10 @@ impl Master {
}
}
+ fn node(&self) -> MutexGuard<MasterInternal> {
+ self.node.lock().unwrap()
+ }
+
/// Create a new instance from a Unix stream socket.
pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
@@ -115,8 +119,8 @@ impl Master {
impl VhostBackend for Master {
/// Get from the underlying vhost implementation the feature bitmask.
- fn get_features(&mut self) -> Result<u64> {
- let mut node = self.node.lock().unwrap();
+ fn get_features(&self) -> Result<u64> {
+ let mut node = self.node();
let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
let val = node.recv_reply::<VhostUserU64>(&hdr)?;
node.virtio_features = val.value;
@@ -124,8 +128,8 @@ impl VhostBackend for Master {
}
/// Enable features in the underlying vhost implementation using a bitmask.
- fn set_features(&mut self, features: u64) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_features(&self, features: u64) -> Result<()> {
+ let mut node = self.node();
let val = VhostUserU64::new(features);
let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
@@ -135,18 +139,18 @@ impl VhostBackend for Master {
}
/// Set the current Master as an owner of the session.
- fn set_owner(&mut self) -> Result<()> {
+ fn set_owner(&self) -> Result<()> {
// We unwrap() the return value to assert that we are not expecting threads to ever fail
// while holding the lock.
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
// completed yet.
Ok(())
}
- fn reset_owner(&mut self) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn reset_owner(&self) -> Result<()> {
+ let mut node = self.node();
let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
// Don't wait for ACK here because the protocol feature negotiation process hasn't been
// completed yet.
@@ -155,7 +159,7 @@ impl VhostBackend for Master {
/// Set the memory map regions on the slave so it can translate the vring
/// addresses. In the ancillary data there is an array of file descriptors
- fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
+ fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
return error_code(VhostUserError::InvalidParam);
}
@@ -174,12 +178,13 @@ impl VhostBackend for Master {
ctx.append(&reg, region.mmap_handle);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let body = VhostUserMemory::new(ctx.regions.len() as u32);
+ let (_, payload, _) = unsafe { ctx.regions.align_to::<u8>() };
let hdr = node.send_request_with_payload(
MasterReq::SET_MEM_TABLE,
&body,
- ctx.regions.as_slice(),
+ payload,
Some(ctx.fds.as_slice()),
)?;
node.wait_for_ack(&hdr).map_err(|e| e.into())
@@ -187,8 +192,8 @@ impl VhostBackend for Master {
// Clippy doesn't seem to know that if let with && is still experimental
#[allow(clippy::unnecessary_unwrap)]
- fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_log_base(&self, base: u64, fd: Option<RawFd>) -> Result<()> {
+ let mut node = self.node();
let val = VhostUserU64::new(base);
if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
@@ -202,16 +207,16 @@ impl VhostBackend for Master {
Ok(())
}
- fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_log_fd(&self, fd: RawFd) -> Result<()> {
+ let mut node = self.node();
let fds = [fd];
node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
Ok(())
}
/// Set the size of the queue.
- fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -222,8 +227,8 @@ impl VhostBackend for Master {
}
/// Sets the addresses of the different aspects of the vring.
- fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num
|| config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
{
@@ -236,8 +241,8 @@ impl VhostBackend for Master {
}
/// Sets the base offset in the available vring.
- fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -247,8 +252,8 @@ impl VhostBackend for Master {
node.wait_for_ack(&hdr).map_err(|e| e.into())
}
- fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
- let mut node = self.node.lock().unwrap();
+ fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -263,8 +268,8 @@ impl VhostBackend for Master {
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data. This signals that polling
/// will be used instead of waiting for the call.
- fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_call(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -276,8 +281,8 @@ impl VhostBackend for Master {
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data. This signals that polling
/// should be used instead of waiting for a kick.
- fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_kick(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -288,8 +293,8 @@ impl VhostBackend for Master {
/// Set the event file descriptor to signal when error occurs.
/// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
/// is set when there is no file descriptor in the ancillary data.
- fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ fn set_vring_err(&self, queue_index: usize, fd: &EventFd) -> Result<()> {
+ let mut node = self.node();
if queue_index as u64 >= node.max_queue_num {
return error_code(VhostUserError::InvalidParam);
}
@@ -300,7 +305,7 @@ impl VhostBackend for Master {
impl VhostUserMaster for Master {
fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -317,7 +322,7 @@ impl VhostUserMaster for Master {
}
fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -332,7 +337,7 @@ impl VhostUserMaster for Master {
}
fn get_queue_num(&mut self) -> Result<u64> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
if !node.is_feature_mq_available() {
return error_code(VhostUserError::InvalidOperation);
}
@@ -347,7 +352,7 @@ impl VhostUserMaster for Master {
}
fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -373,7 +378,7 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidParam);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// depends on VhostUserProtocolFeatures::CONFIG
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -390,9 +395,13 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidMessage);
} else if body_reply.size == 0 {
return error_code(VhostUserError::SlaveInternalError);
- } else if body_reply.size != body.size || body_reply.size as usize != buf.len() {
+ } else if body_reply.size != body.size
+ || body_reply.size as usize != buf.len()
+ || body_reply.offset != body.offset
+ {
return error_code(VhostUserError::InvalidMessage);
}
+
Ok((body_reply, buf_reply))
}
@@ -405,7 +414,7 @@ impl VhostUserMaster for Master {
return error_code(VhostUserError::InvalidParam);
}
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
// depends on VhostUserProtocolFeatures::CONFIG
if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
@@ -416,7 +425,7 @@ impl VhostUserMaster for Master {
}
fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
- let mut node = self.node.lock().unwrap();
+ let mut node = self.node();
if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return error_code(VhostUserError::InvalidOperation);
}
@@ -429,7 +438,7 @@ impl VhostUserMaster for Master {
impl AsRawFd for Master {
fn as_raw_fd(&self) -> RawFd {
- let node = self.node.lock().unwrap();
+ let node = self.node();
node.main_sock.as_raw_fd()
}
}
@@ -503,14 +512,14 @@ impl MasterInternal {
Ok(hdr)
}
- fn send_request_with_payload<T: Sized, P: Sized>(
+ fn send_request_with_payload<T: Sized>(
&mut self,
code: MasterReq,
msg: &T,
- payload: &[P],
+ payload: &[u8],
fds: Option<&[RawFd]>,
) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
- let len = mem::size_of::<T>() + payload.len() * mem::size_of::<P>();
+ let len = mem::size_of::<T>() + payload.len();
if len > MAX_MSG_SIZE {
return Err(VhostUserError::InvalidParam);
}
@@ -568,7 +577,11 @@ impl MasterInternal {
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
- if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || hdr.get_size() as usize <= mem::size_of::<T>()
+ || hdr.get_size() as usize > MAX_MSG_SIZE
+ || hdr.is_reply()
+ {
return Err(VhostUserError::InvalidParam);
}
self.check_state()?;
@@ -582,11 +595,8 @@ impl MasterInternal {
{
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
- } else if bytes > MAX_MSG_SIZE - mem::size_of::<T>() {
+ } else if bytes != buf.len() {
return Err(VhostUserError::InvalidMessage);
- } else if bytes < buf.len() {
- // It's safe because we have checked the buffer size
- unsafe { buf.set_len(bytes) };
}
Ok((body, buf, rfds))
}
@@ -634,11 +644,14 @@ impl MasterInternal {
mod tests {
use super::super::connection::Listener;
use super::*;
+ use vmm_sys_util::rand::rand_alphanumerics;
- const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master";
- const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2";
- const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3";
- const UNIX_SOCKET_MASTER4: &'static str = "/tmp/vhost_user_test_rust_master4";
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
fn create_pair(path: &str) -> (Master, Endpoint<MasterReq>) {
let listener = Listener::new(path, true).unwrap();
@@ -649,14 +662,15 @@ mod tests {
}
#[test]
- #[ignore]
fn create_master() {
- let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap();
+ let path = temp_path();
+ let listener = Listener::new(&path, true).unwrap();
listener.set_nonblocking(true).unwrap();
- let mut master = Master::connect(UNIX_SOCKET_MASTER, 1).unwrap();
+ let master = Master::connect(&path, 1).unwrap();
let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
+ assert!(master.as_raw_fd() > 0);
// Send two messages continuously
master.set_owner().unwrap();
master.reset_owner().unwrap();
@@ -675,24 +689,24 @@ mod tests {
}
#[test]
- #[ignore]
fn test_create_failure() {
- let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
- let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err();
- assert!(Master::connect(UNIX_SOCKET_MASTER2, 1).is_err());
+ let path = temp_path();
+ let _ = Listener::new(&path, true).unwrap();
+ let _ = Listener::new(&path, false).is_err();
+ assert!(Master::connect(&path, 1).is_err());
- let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
- assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err());
+ let listener = Listener::new(&path, true).unwrap();
+ assert!(Listener::new(&path, false).is_err());
listener.set_nonblocking(true).unwrap();
- let _master = Master::connect(UNIX_SOCKET_MASTER2, 1).unwrap();
+ let _master = Master::connect(&path, 1).unwrap();
let _slave = listener.accept().unwrap().unwrap();
}
#[test]
- #[ignore]
fn test_features() {
- let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3);
+ let path = temp_path();
+ let (master, mut peer) = create_pair(&path);
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
@@ -709,6 +723,9 @@ mod tests {
let (_hdr, rfds) = peer.recv_header().unwrap();
assert!(rfds.is_none());
+ let hdr = VhostUserMsgHeader::new(MasterReq::SET_FEATURES, 0x4, 8);
+ let msg = VhostUserU64::new(0x15);
+ peer.send_message(&hdr, &msg, None).unwrap();
master.set_features(0x15).unwrap();
let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
assert!(rfds.is_none());
@@ -722,9 +739,9 @@ mod tests {
}
#[test]
- #[ignore]
fn test_protocol_features() {
- let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4);
+ let path = temp_path();
+ let (mut master, mut peer) = create_pair(&path);
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
@@ -773,12 +790,209 @@ mod tests {
}
#[test]
- fn test_set_mem_table() {
- // TODO
+ fn test_master_set_config_negative() {
+ let path = temp_path();
+ let (mut master, _peer) = create_pair(&path);
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap();
+ master
+ .set_config(0x0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .unwrap_err();
+ master
+ .set_config(
+ 0x100,
+ unsafe { VhostUserConfigFlags::from_bits_unchecked(0xffff_ffff) },
+ &buf[0..4],
+ )
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf)
+ .unwrap_err();
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[])
+ .unwrap_err();
+ }
+
+ fn create_pair2() -> (Master, Endpoint<MasterReq>) {
+ let path = temp_path();
+ let (master, peer) = create_pair(&path);
+
+ {
+ let mut node = master.node();
+ node.virtio_features = 0xffff_ffff;
+ node.acked_virtio_features = 0xffff_ffff;
+ node.protocol_features = 0xffff_ffff;
+ node.acked_protocol_features = 0xffff_ffff;
+ }
+
+ (master, peer)
+ }
+
+ #[test]
+ fn test_master_get_config_negative0() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_code(MasterReq::GET_FEATURES);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ hdr.set_code(MasterReq::GET_CONFIG);
+ }
+
+ #[test]
+ fn test_master_get_config_negative1() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ hdr.set_reply(false);
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
}
#[test]
- fn test_get_ring_num() {
- // TODO
+ fn test_master_get_config_negative2() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+ }
+
+ #[test]
+ fn test_master_get_config_negative3() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative4() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = 0x101;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative5() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.offset = (MAX_MSG_SIZE + 1) as u32;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_master_get_config_negative6() {
+ let (mut master, mut peer) = create_pair2();
+ let buf = vec![0x0; MAX_MSG_SIZE + 1];
+
+ let hdr = VhostUserMsgHeader::new(MasterReq::GET_CONFIG, 0x4, 16);
+ let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_ok());
+
+ msg.size = 6;
+ peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
+ .unwrap();
+ assert!(master
+ .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
+ .is_err());
+ }
+
+ #[test]
+ fn test_maset_set_mem_table_failure() {
+ let (master, _peer) = create_pair2();
+
+ master.set_mem_table(&[]).unwrap_err();
+ let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
+ master.set_mem_table(&tables).unwrap_err();
}
}
diff --git a/src/vhost_user/master_req_handler.rs b/src/vhost_user/master_req_handler.rs
index aadfeee..8cba188 100644
--- a/src/vhost_user/master_req_handler.rs
+++ b/src/vhost_user/master_req_handler.rs
@@ -1,9 +1,6 @@
-// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-//! Traits and Structs to handle vhost-user requests from the slave to the master.
-
-use libc;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -13,83 +10,189 @@ use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result};
-/// Trait to handle vhost-user requests from the slave to the master.
+/// Define services provided by masters for the slave communication channel.
+///
+/// The vhost-user specification defines a slave communication channel, by which slaves could
+/// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
+/// by masters, and it's used both on the master side and slave side.
+/// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
+/// service requests to masters. The [SlaveFsCacheReq] is an example stub forwarder.
+/// - on the master side, the [MasterReqHandler] will forward service requests to a handler
+/// implementing [VhostUserMasterReqHandler].
+///
+/// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
pub trait VhostUserMasterReqHandler {
+ /// Handle device configuration change notifications.
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs map file requests.
+ fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs unmap file requests.
+ fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs sync file requests.
+ fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
+ /// Handle virtio-fs file IO requests.
+ fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+
// fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
// fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
- /// Handle device configuration change notifications from the slave.
+/// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
+///
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
+pub trait VhostUserMasterReqHandlerMut {
+ /// Handle device configuration change notifications.
fn handle_config_change(&mut self) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs map file requests from the slave.
+ /// Handle virtio-fs map file requests.
fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
// Safe because we have just received the rawfd from kernel.
unsafe { libc::close(fd) };
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs unmap file requests from the slave.
+ /// Handle virtio-fs unmap file requests.
fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs sync file requests from the slave.
+ /// Handle virtio-fs sync file requests.
fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
- /// Handle virtio-fs file IO requests from the slave.
+ /// Handle virtio-fs file IO requests.
fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
// Safe because we have just received the rawfd from kernel.
unsafe { libc::close(fd) };
Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
}
+
+ // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
+ // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
+}
+
+impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
+ fn handle_config_change(&self) -> HandlerResult<u64> {
+ self.lock().unwrap().handle_config_change()
+ }
+
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_map(fs, fd)
+ }
+
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_unmap(fs)
+ }
+
+ fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_sync(fs)
+ }
+
+ fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ self.lock().unwrap().fs_slave_io(fs, fd)
+ }
}
-/// A vhost-user master request endpoint which relays all received requests from the slave to the
-/// provided request handler.
+/// Server to handle service requests from slaves from the slave communication channel.
+///
+/// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
+/// slaves on the slave communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserMasterReqHandler] to do the real work.
+///
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+/// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
// underlying Unix domain socket for communication
sub_sock: Endpoint<SlaveReq>,
tx_sock: UnixStream,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
// the VirtIO backend device object
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
// whether the endpoint has encountered any failure
error: Option<i32>,
}
impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
- /// Create a vhost-user slave request handler.
- /// This opens a pair of connected anonymous sockets.
- /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD.
- pub fn new(backend: Arc<Mutex<S>>) -> Result<Self> {
+ /// Create a server to handle service requests from slaves on the slave communication channel.
+ ///
+ /// This opens a pair of connected anonymous sockets to form the slave communication channel.
+ /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
+ /// [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
+ pub fn new(backend: Arc<S>) -> Result<Self> {
let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
Ok(MasterReqHandler {
sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
tx_sock: tx,
+ reply_ack_negotiated: false,
backend,
error: None,
})
}
- /// Get the raw fd to send to the slave as slave communication channel.
+ /// Get the socket fd for the slave to communication with the master.
+ ///
+ /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
+ ///
+ /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
pub fn get_tx_raw_fd(&self) -> RawFd {
self.tx_sock.as_raw_fd()
}
- /// Mark endpoint as failed or normal state.
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&mut self, enable: bool) {
+ self.reply_ack_negotiated = enable;
+ }
+
+ /// Mark endpoint as failed or in normal state.
pub fn set_failed(&mut self, error: i32) {
- self.error = Some(error);
+ if error == 0 {
+ self.error = None;
+ } else {
+ self.error = Some(error);
+ }
}
- /// Receive and handle one incoming request message from the slave.
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
/// The caller needs to:
- /// . serialize calls to this function
- /// . decide what to do when errer happens
- /// . optional recover from failure
+ /// - serialize calls to this function
+ /// - decide what to do when errer happens
+ /// - optional recover from failure
pub fn handle_request(&mut self) -> Result<u64> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
@@ -108,6 +211,9 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
+ if len as usize > MAX_MSG_SIZE {
+ return Err(Error::InvalidMessage);
+ }
let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
if size2 != len as usize {
return Err(Error::InvalidMessage);
@@ -120,41 +226,33 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
SlaveReq::CONFIG_CHANGE_MSG => {
self.check_msg_size(&hdr, size, 0)?;
self.backend
- .lock()
- .unwrap()
.handle_config_change()
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_MAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
self.backend
- .lock()
- .unwrap()
- .fs_slave_map(msg, rfds.unwrap()[0])
+ .fs_slave_map(&msg, rfds.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_UNMAP => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
- .lock()
- .unwrap()
- .fs_slave_unmap(msg)
+ .fs_slave_unmap(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_SYNC => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
- .lock()
- .unwrap()
- .fs_slave_sync(msg)
+ .fs_slave_sync(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_IO => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
+ // check_attached_rfds() has validated rfds
self.backend
- .lock()
- .unwrap()
- .fs_slave_io(msg, rfds.unwrap()[0])
+ .fs_slave_io(&msg, rfds.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
_ => Err(Error::InvalidMessage),
@@ -211,7 +309,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
_ => {
if rfds.is_some() {
Endpoint::<SlaveReq>::close_rfds(rfds);
- return Err(Error::InvalidMessage);
+ Err(Error::InvalidMessage)
} else {
Ok(rfds)
}
@@ -219,14 +317,14 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
}
- fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>(
+ fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
&self,
hdr: &VhostUserMsgHeader<SlaveReq>,
size: usize,
- buf: &'a [u8],
- ) -> Result<&'a T> {
+ buf: &[u8],
+ ) -> Result<T> {
self.check_msg_size(hdr, size, mem::size_of::<T>())?;
- let msg = unsafe { &*(buf.as_ptr() as *const T) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -253,7 +351,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
req: &VhostUserMsgHeader<SlaveReq>,
res: &Result<u64>,
) -> Result<()> {
- if req.is_need_reply() {
+ if self.reply_ack_negotiated && req.is_need_reply() {
let hdr = self.new_reply_header::<VhostUserU64>(req)?;
let def_err = libc::EINVAL;
let val = match res {
@@ -278,3 +376,102 @@ impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
self.sub_sock.as_raw_fd()
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[cfg(feature = "vhost-user-slave")]
+ use crate::vhost_user::SlaveFsCacheReq;
+ #[cfg(feature = "vhost-user-slave")]
+ use std::os::unix::io::FromRawFd;
+
+ struct MockMasterReqHandler {}
+
+ impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
+ /// Handle virtio-fs map file requests from the slave.
+ fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ // Safe because we have just received the rawfd from kernel.
+ unsafe { libc::close(fd) };
+ Ok(0)
+ }
+
+ /// Handle virtio-fs unmap file requests from the slave.
+ fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
+ }
+ }
+
+ #[test]
+ fn test_new_master_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ assert!(handler.get_tx_raw_fd() >= 0);
+ assert!(handler.as_raw_fd() >= 0);
+ handler.check_state().unwrap();
+
+ assert_eq!(handler.error, None);
+ handler.set_failed(libc::EAGAIN);
+ assert_eq!(handler.error, Some(libc::EAGAIN));
+ handler.check_state().unwrap_err();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
+ // slave side.
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-slave")]
+ #[test]
+ fn test_master_slave_req_handler_with_ack() {
+ let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
+ let mut handler = MasterReqHandler::new(backend).unwrap();
+ handler.set_reply_ack_flag(true);
+
+ let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
+ if fd < 0 {
+ panic!("failed to duplicated tx fd!");
+ }
+ let stream = unsafe { UnixStream::from_raw_fd(fd) };
+ let fs_cache = SlaveFsCacheReq::from_stream(stream);
+
+ std::thread::spawn(move || {
+ let res = handler.handle_request().unwrap();
+ assert_eq!(res, 0);
+ handler.handle_request().unwrap_err();
+ });
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+}
diff --git a/src/vhost_user/message.rs b/src/vhost_user/message.rs
index 4109b61..8600410 100644
--- a/src/vhost_user/message.rs
+++ b/src/vhost_user/message.rs
@@ -562,9 +562,9 @@ bitflags! {
/// Flags for the device configuration message.
pub struct VhostUserConfigFlags: u32 {
/// Vhost master messages used for writeable fields.
- const WRITABLE = 0x0;
+ const WRITABLE = 0x1;
/// Vhost master messages used for live migration.
- const LIVE_MIGRATION = 0x1;
+ const LIVE_MIGRATION = 0x2;
}
}
@@ -596,9 +596,11 @@ impl VhostUserMsgValidator for VhostUserConfig {
fn is_valid(&self) -> bool {
if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
return false;
+ } else if self.offset < 0x100 {
+ return false;
} else if self.size == 0
|| self.size > VHOST_USER_CONFIG_SIZE
- || self.size + self.offset >= VHOST_USER_CONFIG_SIZE
+ || self.size + self.offset > VHOST_USER_CONFIG_SIZE
{
return false;
}
@@ -656,9 +658,9 @@ pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserFSSlaveMsg {
- /// TODO:
+ /// File offset.
pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
- /// TODO:
+ /// Offset into the DAX window.
pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
/// Size of region to map.
pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
@@ -686,13 +688,31 @@ mod tests {
use std::mem;
#[test]
- fn check_request_code() {
+ fn check_master_request_code() {
let code = MasterReq::NOOP;
assert!(!code.is_valid());
let code = MasterReq::MAX_CMD;
assert!(!code.is_valid());
+ assert!(code > MasterReq::NOOP);
let code = MasterReq::GET_FEATURES;
assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) };
+ assert!(!code.is_valid());
+ }
+
+ #[test]
+ fn check_slave_request_code() {
+ let code = SlaveReq::NOOP;
+ assert!(!code.is_valid());
+ let code = SlaveReq::MAX_CMD;
+ assert!(!code.is_valid());
+ assert!(code > SlaveReq::NOOP);
+ let code = SlaveReq::CONFIG_CHANGE_MSG;
+ assert!(code.is_valid());
+ assert_eq!(code, code.clone());
+ let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) };
+ assert!(!code.is_valid());
}
#[test]
@@ -741,6 +761,20 @@ mod tests {
assert!(!hdr.is_valid());
hdr.set_version(0x1);
assert!(hdr.is_valid());
+
+ assert_eq!(hdr, hdr.clone());
+ }
+
+ #[test]
+ fn test_vhost_user_message_u64() {
+ let val = VhostUserU64::default();
+ let val1 = VhostUserU64::new(0);
+
+ let a = val.value;
+ let b = val1.value;
+ assert_eq!(a, b);
+ let a = VhostUserU64::new(1).value;
+ assert_eq!(a, 1);
}
#[test]
@@ -775,6 +809,104 @@ mod tests {
msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
msg.memory_size = 0;
assert!(!msg.is_valid());
+ let a = msg.guest_phys_addr;
+ let b = msg.guest_phys_addr;
+ assert_eq!(a, b);
+
+ let msg = VhostUserMemoryRegion::default();
+ let a = msg.guest_phys_addr;
+ assert_eq!(a, 0);
+ let a = msg.memory_size;
+ assert_eq!(a, 0);
+ let a = msg.user_addr;
+ assert_eq!(a, 0);
+ let a = msg.mmap_offset;
+ assert_eq!(a, 0);
+ }
+
+ #[test]
+ fn test_vhost_user_state() {
+ let state = VhostUserVringState::new(5, 8);
+
+ let a = state.index;
+ assert_eq!(a, 5);
+ let a = state.num;
+ assert_eq!(a, 8);
+ assert_eq!(state.is_valid(), true);
+
+ let state = VhostUserVringState::default();
+ let a = state.index;
+ assert_eq!(a, 0);
+ let a = state.num;
+ assert_eq!(a, 0);
+ assert_eq!(state.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_addr() {
+ let mut addr = VhostUserVringAddr::new(
+ 2,
+ VhostUserVringAddrFlags::VHOST_VRING_F_LOG,
+ 0x1000,
+ 0x2000,
+ 0x3000,
+ 0x4000,
+ );
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
+
+ addr.descriptor = 0x1001;
+ assert_eq!(addr.is_valid(), false);
+ addr.descriptor = 0x1000;
+
+ addr.available = 0x3001;
+ assert_eq!(addr.is_valid(), false);
+ addr.available = 0x3000;
+
+ addr.used = 0x2001;
+ assert_eq!(addr.is_valid(), false);
+ addr.used = 0x2000;
+ assert_eq!(addr.is_valid(), true);
+ }
+
+ #[test]
+ fn test_vhost_user_state_from_config() {
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits,
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ let addr = VhostUserVringAddr::from_config_data(2, &config);
+
+ let a = addr.index;
+ assert_eq!(a, 2);
+ let a = addr.flags;
+ assert_eq!(a, VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits());
+ let a = addr.descriptor;
+ assert_eq!(a, 0x1000);
+ let a = addr.used;
+ assert_eq!(a, 0x2000);
+ let a = addr.available;
+ assert_eq!(a, 0x3000);
+ let a = addr.log;
+ assert_eq!(a, 0x4000);
+ assert_eq!(addr.is_valid(), true);
}
#[test]
@@ -801,7 +933,6 @@ mod tests {
}
#[test]
- #[ignore]
fn check_user_config_msg() {
let mut msg = VhostUserConfig::new(
VHOST_USER_CONFIG_OFFSET,
@@ -828,4 +959,21 @@ mod tests {
msg.flags |= 0x4;
assert!(!msg.is_valid());
}
+
+ #[test]
+ fn test_vhost_user_fs_slave() {
+ let mut fs_slave = VhostUserFSSlaveMsg::default();
+
+ assert_eq!(fs_slave.is_valid(), true);
+
+ fs_slave.fd_offset[0] = 0xffff_ffff_ffff_ffff;
+ fs_slave.len[0] = 0x1;
+ assert_eq!(fs_slave.is_valid(), false);
+
+ assert_ne!(
+ VhostUserFSSlaveMsgFlags::MAP_R,
+ VhostUserFSSlaveMsgFlags::MAP_W
+ );
+ assert_eq!(VhostUserFSSlaveMsgFlags::EMPTY.bits(), 0);
+ }
}
diff --git a/src/vhost_user/mod.rs b/src/vhost_user/mod.rs
index 48a93ff..6a5b6a1 100644
--- a/src/vhost_user/mod.rs
+++ b/src/vhost_user/mod.rs
@@ -18,20 +18,23 @@
//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
//! equivalent ioctl to the kernel implementation.
-use libc;
use std::io::Error as IOError;
-mod connection;
pub mod message;
+
+mod connection;
pub use self::connection::Listener;
+
#[cfg(feature = "vhost-user-master")]
mod master;
#[cfg(feature = "vhost-user-master")]
pub use self::master::{Master, VhostUserMaster};
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(feature = "vhost-user")]
mod master_req_handler;
-#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
-pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler};
+#[cfg(feature = "vhost-user")]
+pub use self::master_req_handler::{
+ MasterReqHandler, VhostUserMasterReqHandler, VhostUserMasterReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave;
@@ -40,14 +43,14 @@ pub use self::slave::SlaveListener;
#[cfg(feature = "vhost-user-slave")]
mod slave_req_handler;
#[cfg(feature = "vhost-user-slave")]
-pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler};
+pub use self::slave_req_handler::{
+ SlaveReqHandler, VhostUserSlaveReqHandler, VhostUserSlaveReqHandlerMut,
+};
#[cfg(feature = "vhost-user-slave")]
mod slave_fs_cache;
#[cfg(feature = "vhost-user-slave")]
pub use self::slave_fs_cache::SlaveFsCacheReq;
-pub mod sock_ctrl_msg;
-
/// Errors for vhost-user operations
#[derive(Debug)]
pub enum Error {
@@ -102,6 +105,8 @@ impl std::fmt::Display for Error {
}
}
+impl std::error::Error for Error {}
+
impl Error {
/// Determine whether to rebuild the underline communication channel.
pub fn should_reconnect(&self) -> bool {
@@ -170,21 +175,32 @@ pub type Result<T> = std::result::Result<T, Error>;
/// Result of request handler.
pub type HandlerResult<T> = std::result::Result<T, IOError>;
-#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
+#[cfg(all(test, feature = "vhost-user-slave"))]
mod dummy_slave;
#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
mod tests {
+ use std::os::unix::io::AsRawFd;
+ use std::sync::{Arc, Barrier, Mutex};
+ use std::thread;
+ use vmm_sys_util::rand::rand_alphanumerics;
+
use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
use super::message::*;
use super::*;
use crate::backend::VhostBackend;
- use std::sync::{Arc, Barrier, Mutex};
- use std::thread;
+ use crate::{VhostUserMemoryRegionInfo, VringConfigData};
+
+ fn temp_path() -> String {
+ format!(
+ "/tmp/vhost_test_{}",
+ rand_alphanumerics(8).to_str().unwrap()
+ )
+ }
fn create_slave<S: VhostUserSlaveReqHandler>(
path: &str,
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
) -> (Master, SlaveReqHandler<S>) {
let listener = Listener::new(path, true).unwrap();
let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
@@ -194,7 +210,7 @@ mod tests {
#[test]
fn create_dummy_slave() {
- let mut slave = DummySlaveReqHandler::new();
+ let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
slave.set_owner().unwrap();
assert!(slave.set_owner().is_err());
@@ -203,8 +219,8 @@ mod tests {
#[test]
fn test_set_owner() {
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
- let (mut master, mut slave) =
- create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone());
+ let path = temp_path();
+ let (master, mut slave) = create_slave(&path, slave_be.clone());
assert_eq!(slave_be.lock().unwrap().owned, false);
master.set_owner().unwrap();
@@ -219,14 +235,60 @@ mod tests {
fn test_set_features() {
let mbar = Arc::new(Barrier::new(2));
let sbar = mbar.clone();
+ let path = temp_path();
+ let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
+
+ thread::spawn(move || {
+ slave.handle_request().unwrap();
+ assert_eq!(slave_be.lock().unwrap().owned, true);
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_features,
+ VIRTIO_FEATURES & !0x1
+ );
+
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ assert_eq!(
+ slave_be.lock().unwrap().acked_protocol_features,
+ VhostUserProtocolFeatures::all().bits()
+ );
+
+ sbar.wait();
+ });
+
+ master.set_owner().unwrap();
+
+ // set virtio features
+ let features = master.get_features().unwrap();
+ assert_eq!(features, VIRTIO_FEATURES);
+ master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
+
+ // set vhost protocol features
+ let features = master.get_protocol_features().unwrap();
+ assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
+ master.set_protocol_features(features).unwrap();
+
+ mbar.wait();
+ }
+
+ #[test]
+ fn test_master_slave_process() {
+ let mbar = Arc::new(Barrier::new(2));
+ let sbar = mbar.clone();
+ let path = temp_path();
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
- let (mut master, mut slave) =
- create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone());
+ let (mut master, mut slave) = create_slave(&path, slave_be.clone());
thread::spawn(move || {
+ // set_own()
slave.handle_request().unwrap();
assert_eq!(slave_be.lock().unwrap().owned, true);
+ // get/set_features()
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
@@ -241,6 +303,34 @@ mod tests {
VhostUserProtocolFeatures::all().bits()
);
+ // get_queue_num()
+ slave.handle_request().unwrap();
+
+ // set_mem_table()
+ slave.handle_request().unwrap();
+
+ // get/set_config()
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
+ // set_slave_request_fd
+ slave.handle_request().unwrap();
+
+ // set_vring_enable
+ slave.handle_request().unwrap();
+
+ // set_log_base,set_log_fd()
+ slave.handle_request().unwrap_err();
+ slave.handle_request().unwrap_err();
+
+ // set_vring_xxx
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+ slave.handle_request().unwrap();
+
sbar.wait();
});
@@ -256,6 +346,82 @@ mod tests {
assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
master.set_protocol_features(features).unwrap();
+ let num = master.get_queue_num().unwrap();
+ assert_eq!(num, 2);
+
+ let eventfd = vmm_sys_util::eventfd::EventFd::new(0).unwrap();
+ let mem = [VhostUserMemoryRegionInfo {
+ guest_phys_addr: 0,
+ memory_size: 0x10_0000,
+ userspace_addr: 0,
+ mmap_offset: 0,
+ mmap_handle: eventfd.as_raw_fd(),
+ }];
+ master.set_mem_table(&mem).unwrap();
+
+ master
+ .set_config(0x100, VhostUserConfigFlags::WRITABLE, &[0xa5u8])
+ .unwrap();
+ let buf = [0x0u8; 4];
+ let (reply_body, reply_payload) = master
+ .get_config(0x100, 4, VhostUserConfigFlags::empty(), &buf)
+ .unwrap();
+ let offset = reply_body.offset;
+ assert_eq!(offset, 0x100);
+ assert_eq!(reply_payload[0], 0xa5);
+
+ master.set_slave_request_fd(eventfd.as_raw_fd()).unwrap();
+ master.set_vring_enable(0, true).unwrap();
+
+ // unimplemented yet
+ master.set_log_base(0, Some(eventfd.as_raw_fd())).unwrap();
+ master.set_log_fd(eventfd.as_raw_fd()).unwrap();
+
+ master.set_vring_num(0, 256).unwrap();
+ master.set_vring_base(0, 0).unwrap();
+ let config = VringConfigData {
+ queue_max_size: 256,
+ queue_size: 128,
+ flags: VhostUserVringAddrFlags::VHOST_VRING_F_LOG.bits(),
+ desc_table_addr: 0x1000,
+ used_ring_addr: 0x2000,
+ avail_ring_addr: 0x3000,
+ log_addr: Some(0x4000),
+ };
+ master.set_vring_addr(0, &config).unwrap();
+ master.set_vring_call(0, &eventfd).unwrap();
+ master.set_vring_kick(0, &eventfd).unwrap();
+ master.set_vring_err(0, &eventfd).unwrap();
+
mbar.wait();
}
+
+ #[test]
+ fn test_error_display() {
+ assert_eq!(format!("{}", Error::InvalidParam), "invalid parameters");
+ assert_eq!(format!("{}", Error::InvalidOperation), "invalid operation");
+ }
+
+ #[test]
+ fn test_should_reconnect() {
+ assert_eq!(Error::PartialMessage.should_reconnect(), true);
+ assert_eq!(Error::SlaveInternalError.should_reconnect(), true);
+ assert_eq!(Error::MasterInternalError.should_reconnect(), true);
+ assert_eq!(Error::InvalidParam.should_reconnect(), false);
+ assert_eq!(Error::InvalidOperation.should_reconnect(), false);
+ assert_eq!(Error::InvalidMessage.should_reconnect(), false);
+ assert_eq!(Error::IncorrectFds.should_reconnect(), false);
+ assert_eq!(Error::OversizedMsg.should_reconnect(), false);
+ assert_eq!(Error::FeatureMismatch.should_reconnect(), false);
+ }
+
+ #[test]
+ fn test_error_from_sys_util_error() {
+ let e: Error = vmm_sys_util::errno::Error::new(libc::EAGAIN).into();
+ if let Error::SocketRetry(e1) = e {
+ assert_eq!(e1.raw_os_error().unwrap(), libc::EAGAIN);
+ } else {
+ panic!("invalid error code conversion!");
+ }
+ }
}
diff --git a/src/vhost_user/slave.rs b/src/vhost_user/slave.rs
index 5ac99af..fb65c41 100644
--- a/src/vhost_user/slave.rs
+++ b/src/vhost_user/slave.rs
@@ -3,7 +3,7 @@
//! Traits and Structs for vhost-user slave.
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
use super::connection::{Endpoint, Listener};
use super::message::*;
@@ -12,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
/// Vhost-user slave side connection listener.
pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
listener: Listener,
- backend: Option<Arc<Mutex<S>>>,
+ backend: Option<Arc<S>>,
}
/// Sets up a listener for incoming master connections, and handles construction
/// of a Slave on success.
impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
/// Create a unix domain socket for incoming master connections.
- pub fn new(listener: Listener, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn new(listener: Listener, backend: Arc<S>) -> Result<Self> {
Ok(SlaveListener {
listener,
backend: Some(backend),
@@ -44,3 +44,43 @@ impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
self.listener.set_nonblocking(block)
}
}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Mutex;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_listener_set_nonblocking() {
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener =
+ Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
+ let slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(false).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ slave_listener.set_nonblocking(true).unwrap();
+ }
+
+ #[cfg(feature = "vhost-user-master")]
+ #[test]
+ fn test_slave_listener_accept() {
+ use super::super::Master;
+
+ let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let listener = Listener::new(path, true).unwrap();
+ let mut slave_listener = SlaveListener::new(listener, backend).unwrap();
+
+ slave_listener.set_nonblocking(true).unwrap();
+ assert!(slave_listener.accept().unwrap().is_none());
+ assert!(slave_listener.accept().unwrap().is_none());
+
+ let _master = Master::connect(path, 1).unwrap();
+ let _slave = slave_listener.accept().unwrap().unwrap();
+ }
+}
diff --git a/src/vhost_user/slave_fs_cache.rs b/src/vhost_user/slave_fs_cache.rs
index 1804c7a..a9c4ed2 100644
--- a/src/vhost_user/slave_fs_cache.rs
+++ b/src/vhost_user/slave_fs_cache.rs
@@ -1,61 +1,59 @@
-// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved.
+// Copyright (C) 2020 Alibaba Cloud. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-use super::connection::Endpoint;
-use super::message::*;
-use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
use std::io;
use std::mem;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixStream;
-use std::sync::{Arc, Mutex};
+use std::sync::{Arc, Mutex, MutexGuard};
+
+use super::connection::Endpoint;
+use super::message::*;
+use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
struct SlaveFsCacheReqInternal {
sock: Endpoint<SlaveReq>,
-}
-/// A vhost-user slave endpoint which sends fs cache requests to the master
-#[derive(Clone)]
-pub struct SlaveFsCacheReq {
- // underlying Unix domain socket for communication
- node: Arc<Mutex<SlaveFsCacheReqInternal>>,
+ // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
+ reply_ack_negotiated: bool,
// whether the endpoint has encountered any failure
error: Option<i32>,
}
-impl SlaveFsCacheReq {
- fn new(ep: Endpoint<SlaveReq>) -> Self {
- SlaveFsCacheReq {
- node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })),
- error: None,
+impl SlaveFsCacheReqInternal {
+ fn check_state(&self) -> Result<u64> {
+ match self.error {
+ Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
+ None => Ok(0),
}
}
- /// Create a new instance.
- pub fn from_stream(sock: UnixStream) -> Self {
- Self::new(Endpoint::<SlaveReq>::from_stream(sock))
- }
-
fn send_message(
&mut self,
- flags: SlaveReq,
+ request: SlaveReq,
fs: &VhostUserFSSlaveMsg,
fds: Option<&[RawFd]>,
) -> Result<u64> {
self.check_state()?;
let len = mem::size_of::<VhostUserFSSlaveMsg>();
- let mut hdr = VhostUserMsgHeader::new(flags, 0, len as u32);
- hdr.set_need_reply(true);
- self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?;
+ let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
+ if self.reply_ack_negotiated {
+ hdr.set_need_reply(true);
+ }
+ self.sock.send_message(&hdr, fs, fds)?;
self.wait_for_ack(&hdr)
}
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
self.check_state()?;
- let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::<VhostUserU64>()?;
+ if !self.reply_ack_negotiated {
+ return Ok(0);
+ }
+
+ let (reply, body, rfds) = self.sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
Endpoint::<SlaveReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
@@ -63,32 +61,166 @@ impl SlaveFsCacheReq {
if body.value != 0 {
return Err(Error::MasterInternalError);
}
- Ok(0)
+
+ Ok(body.value)
}
+}
- fn check_state(&self) -> Result<u64> {
- match self.error {
- Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
- None => Ok(0),
+/// Request proxy to send vhost-user-fs slave requests to the master through the slave
+/// communication channel.
+///
+/// The [SlaveFsCacheReq] acts as a message proxy to forward vhost-user-fs slave requests to the
+/// master through the vhost-user slave communication channel. The forwarded messages will be
+/// handled by the [MasterReqHandler] server.
+///
+/// [SlaveFsCacheReq]: struct.SlaveFsCacheReq.html
+/// [MasterReqHandler]: struct.MasterReqHandler.html
+#[derive(Clone)]
+pub struct SlaveFsCacheReq {
+ // underlying Unix domain socket for communication
+ node: Arc<Mutex<SlaveFsCacheReqInternal>>,
+}
+
+impl SlaveFsCacheReq {
+ fn new(ep: Endpoint<SlaveReq>) -> Self {
+ SlaveFsCacheReq {
+ node: Arc::new(Mutex::new(SlaveFsCacheReqInternal {
+ sock: ep,
+ reply_ack_negotiated: false,
+ error: None,
+ })),
}
}
+ fn node(&self) -> MutexGuard<SlaveFsCacheReqInternal> {
+ self.node.lock().unwrap()
+ }
+
+ fn send_message(
+ &self,
+ request: SlaveReq,
+ fs: &VhostUserFSSlaveMsg,
+ fds: Option<&[RawFd]>,
+ ) -> io::Result<u64> {
+ self.node()
+ .send_message(request, fs, fds)
+ .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
+ }
+
+ /// Create a new instance from a `UnixStream` object.
+ pub fn from_stream(sock: UnixStream) -> Self {
+ Self::new(Endpoint::<SlaveReq>::from_stream(sock))
+ }
+
+ /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
+ ///
+ /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
+ /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
+ /// message.
+ pub fn set_reply_ack_flag(&self, enable: bool) {
+ self.node().reply_ack_negotiated = enable;
+ }
+
/// Mark endpoint as failed with specified error code.
- pub fn set_failed(&mut self, error: i32) {
- self.error = Some(error);
+ pub fn set_failed(&self, error: i32) {
+ self.node().error = Some(error);
}
}
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
- /// Handle virtio-fs map file requests from the slave.
- fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
+ /// Forward vhost-user-fs map file requests to the slave.
+ fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
- .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
}
- /// Handle virtio-fs unmap file requests from the slave.
- fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
+ /// Forward vhost-user-fs unmap file requests to the master.
+ fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
self.send_message(SlaveReq::FS_UNMAP, fs, None)
- .or_else(|e| Err(io::Error::new(io::ErrorKind::Other, format!("{}", e))))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+
+ #[test]
+ fn test_slave_fs_cache_req_set_failed() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ assert!(fs_cache.node().error.is_none());
+ fs_cache.set_failed(libc::EAGAIN);
+ assert_eq!(fs_cache.node().error, Some(libc::EAGAIN));
+ }
+
+ #[test]
+ fn test_slave_fs_cache_send_failure() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+
+ fs_cache.set_failed(libc::ECONNRESET);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ fs_cache.node().error = None;
+
+ drop(p2);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ fs_cache
+ .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
+ .unwrap_err();
+ }
+
+ #[test]
+ fn test_slave_fs_cache_recv_negative() {
+ let (p1, p2) = UnixStream::pair().unwrap();
+ let fd = p2.as_raw_fd();
+ let fs_cache = SlaveFsCacheReq::from_stream(p1);
+ let mut master = Endpoint::<SlaveReq>::from_stream(p2);
+
+ let len = mem::size_of::<VhostUserFSSlaveMsg>();
+ let mut hdr = VhostUserMsgHeader::new(
+ SlaveReq::FS_MAP,
+ VhostUserHeaderFlag::REPLY.bits(),
+ len as u32,
+ );
+ let body = VhostUserU64::new(0);
+
+ master.send_message(&hdr, &body, Some(&[fd])).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
+
+ fs_cache.set_reply_ack_flag(true);
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ hdr.set_code(SlaveReq::FS_UNMAP);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+ hdr.set_code(SlaveReq::FS_MAP);
+
+ let body = VhostUserU64::new(1);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap_err();
+
+ let body = VhostUserU64::new(0);
+ master.send_message(&hdr, &body, None).unwrap();
+ fs_cache
+ .fs_slave_map(&VhostUserFSSlaveMsg::default(), fd)
+ .unwrap();
}
}
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs
index f3b0770..3b44e4c 100644
--- a/src/vhost_user/slave_req_handler.rs
+++ b/src/vhost_user/slave_req_handler.rs
@@ -1,8 +1,6 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
-//! Traits and Structs to handle vhost-user requests from the master to the slave.
-
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
@@ -14,9 +12,63 @@ use super::message::*;
use super::slave_fs_cache::SlaveFsCacheReq;
use super::{Error, Result};
-/// Trait to handle vhost-user requests from the master to the slave.
+/// Services provided to the master by the slave with interior mutability.
+///
+/// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
+/// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
+/// but without interior mutability.
+/// The vhost-user specification defines a master communication channel, by which masters could
+/// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
+/// slaves, and it's used both on the master side and slave side.
+///
+/// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
+/// service requests to slaves.
+/// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
+/// implementing [VhostUserSlaveReqHandler].
+///
+/// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
+/// for multi-threading.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
+ fn set_owner(&self) -> Result<()>;
+ fn reset_owner(&self) -> Result<()>;
+ fn get_features(&self) -> Result<u64>;
+ fn set_features(&self, features: u64) -> Result<()>;
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()>;
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()>;
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
+ fn set_protocol_features(&self, features: u64) -> Result<()>;
+ fn get_queue_num(&self) -> Result<u64>;
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
+ fn set_slave_req_fd(&self, _vu_req: SlaveFsCacheReq) {}
+}
+
+/// Services provided to the master by the slave without interior mutability.
+///
+/// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
+#[allow(missing_docs)]
+pub trait VhostUserSlaveReqHandlerMut {
fn set_owner(&mut self) -> Result<()>;
fn reset_owner(&mut self) -> Result<()>;
fn get_features(&mut self) -> Result<u64>;
@@ -52,16 +104,110 @@ pub trait VhostUserSlaveReqHandler {
fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
}
-/// A vhost-user slave endpoint which relays all received requests from the
-/// master to the virtio backend device object.
+impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
+ fn set_owner(&self) -> Result<()> {
+ self.lock().unwrap().set_owner()
+ }
+
+ fn reset_owner(&self) -> Result<()> {
+ self.lock().unwrap().reset_owner()
+ }
+
+ fn get_features(&self) -> Result<u64> {
+ self.lock().unwrap().get_features()
+ }
+
+ fn set_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_features(features)
+ }
+
+ fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()> {
+ self.lock().unwrap().set_mem_table(ctx, fds)
+ }
+
+ fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_num(index, num)
+ }
+
+ fn set_vring_addr(
+ &self,
+ index: u32,
+ flags: VhostUserVringAddrFlags,
+ descriptor: u64,
+ used: u64,
+ available: u64,
+ log: u64,
+ ) -> Result<()> {
+ self.lock()
+ .unwrap()
+ .set_vring_addr(index, flags, descriptor, used, available, log)
+ }
+
+ fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
+ self.lock().unwrap().set_vring_base(index, base)
+ }
+
+ fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
+ self.lock().unwrap().get_vring_base(index)
+ }
+
+ fn set_vring_kick(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_kick(index, fd)
+ }
+
+ fn set_vring_call(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_call(index, fd)
+ }
+
+ fn set_vring_err(&self, index: u8, fd: Option<RawFd>) -> Result<()> {
+ self.lock().unwrap().set_vring_err(index, fd)
+ }
+
+ fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
+ self.lock().unwrap().get_protocol_features()
+ }
+
+ fn set_protocol_features(&self, features: u64) -> Result<()> {
+ self.lock().unwrap().set_protocol_features(features)
+ }
+
+ fn get_queue_num(&self) -> Result<u64> {
+ self.lock().unwrap().get_queue_num()
+ }
+
+ fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
+ self.lock().unwrap().set_vring_enable(index, enable)
+ }
+
+ fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
+ self.lock().unwrap().get_config(offset, size, flags)
+ }
+
+ fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
+ self.lock().unwrap().set_config(offset, buf, flags)
+ }
+
+ fn set_slave_req_fd(&self, vu_req: SlaveFsCacheReq) {
+ self.lock().unwrap().set_slave_req_fd(vu_req)
+ }
+}
+
+/// Server to handle service requests from masters from the master communication channel.
+///
+/// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
+/// masters on the master communication channel. It's actually a proxy invoking the registered
+/// handler implementing [VhostUserSlaveReqHandler] to do the real work.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
+///
+/// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
+/// [SlaveReqHandler]: struct.SlaveReqHandler.html
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
// underlying Unix domain socket for communication
main_sock: Endpoint<MasterReq>,
// the vhost-user backend device object
- backend: Arc<Mutex<S>>,
+ backend: Arc<S>,
virtio_features: u64,
acked_virtio_features: u64,
@@ -76,7 +222,7 @@ pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// Create a vhost-user slave endpoint.
- pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self {
+ pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
SlaveReqHandler {
main_sock,
backend,
@@ -94,7 +240,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// # Arguments
/// * - `path` - path of Unix domain socket listener to connect to
/// * - `backend` - handler for requests from the master to the slave
- pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> {
+ pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
}
@@ -103,11 +249,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.error = Some(error);
}
- /// Receive and handle one incoming request message from the master.
- /// The caller needs to:
- /// . serialize calls to this function
- /// . decide what to do when error happens
- /// . optional recover from failure
+ /// Main entrance to server slave request from the slave communication channel.
+ ///
+ /// Receive and handle one incoming request message from the master. The caller needs to:
+ /// - serialize calls to this function
+ /// - decide what to do when error happens
+ /// - optional recover from failure
pub fn handle_request(&mut self) -> Result<()> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
@@ -137,15 +284,15 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
match hdr.get_code() {
MasterReq::SET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().set_owner()?;
+ self.backend.set_owner()?;
}
MasterReq::RESET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
- self.backend.lock().unwrap().reset_owner()?;
+ self.backend.reset_owner()?;
}
MasterReq::GET_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_features()?;
+ let features = self.backend.get_features()?;
let msg = VhostUserU64::new(features);
self.send_reply_message(&hdr, &msg)?;
self.virtio_features = features;
@@ -153,7 +300,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend.lock().unwrap().set_features(msg.value)?;
+ self.backend.set_features(msg.value)?;
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
}
@@ -163,11 +310,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_NUM => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_num(msg.index, msg.num);
+ let res = self.backend.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ADDR => {
@@ -176,7 +319,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self.backend.lock().unwrap().set_vring_addr(
+ let res = self.backend.set_vring_addr(
msg.index,
flags,
msg.descriptor,
@@ -188,39 +331,35 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_base(msg.index, msg.num);
+ let res = self.backend.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
- let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?;
+ let reply = self.backend.get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_call(index, rfds);
+ let res = self.backend.set_vring_call(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_kick(index, rfds);
+ let res = self.backend.set_vring_kick(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
- let res = self.backend.lock().unwrap().set_vring_err(index, rfds);
+ let res = self.backend.set_vring_err(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
- let features = self.backend.lock().unwrap().get_protocol_features()?;
+ let features = self.backend.get_protocol_features()?;
let msg = VhostUserU64::new(features.bits());
self.send_reply_message(&hdr, &msg)?;
self.protocol_features = features;
@@ -228,10 +367,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
MasterReq::SET_PROTOCOL_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
- self.backend
- .lock()
- .unwrap()
- .set_protocol_features(msg.value)?;
+ self.backend.set_protocol_features(msg.value)?;
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
}
@@ -240,7 +376,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, 0)?;
- let num = self.backend.lock().unwrap().get_queue_num()?;
+ let num = self.backend.get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
@@ -257,17 +393,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
_ => return Err(Error::InvalidParam),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_vring_enable(msg.index, enable);
+ let res = self.backend.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_CONFIG => {
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
}
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.get_config(&hdr, &buf)?;
}
MasterReq::SET_CONFIG => {
@@ -281,6 +414,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return Err(Error::InvalidOperation);
}
+ self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.set_slave_req_fd(&hdr, rfds)?;
}
_ => {
@@ -341,15 +475,18 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
- self.backend.lock().unwrap().set_mem_table(&regions, &fds)
+ self.backend.set_mem_table(&regions, &fds)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
+ let payload_offset = mem::size_of::<VhostUserConfig>();
+ if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
- let payload_offset = mem::size_of::<VhostUserConfig>();
if buf.len() - payload_offset != msg.size as usize {
return Err(Error::InvalidMessage);
}
@@ -357,11 +494,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
- let res = self
- .backend
- .lock()
- .unwrap()
- .get_config(msg.offset, msg.size, flags);
+ let res = self.backend.get_config(msg.offset, msg.size, flags);
// vhost-user slave's payload size MUST match master's request
// on success, uses zero length of payload to indicate an error
@@ -389,10 +522,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
size: usize,
buf: &[u8],
) -> Result<()> {
- if size < mem::size_of::<VhostUserConfig>() {
+ if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
return Err(Error::InvalidMessage);
}
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -405,11 +538,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
None => return Err(Error::InvalidMessage),
}
- let res = self
- .backend
- .lock()
- .unwrap()
- .set_config(msg.offset, buf, flags);
+ let res = self.backend.set_config(msg.offset, buf, flags);
self.send_ack_message(&hdr, res)?;
Ok(())
}
@@ -423,7 +552,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
if fds.len() == 1 {
let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
let vu_req = SlaveFsCacheReq::from_stream(sock);
- self.backend.lock().unwrap().set_slave_req_fd(vu_req);
+ self.backend.set_slave_req_fd(vu_req);
self.send_ack_message(&hdr, Ok(()))
} else {
Err(Error::InvalidMessage)
@@ -438,7 +567,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
buf: &[u8],
rfds: Option<Vec<RawFd>>,
) -> Result<(u8, Option<RawFd>)> {
- let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) };
+ if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
+ return Err(Error::InvalidMessage);
+ }
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -447,10 +579,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
// invalid FD flag. This flag is set when there is no file descriptor
// in the ancillary data. This signals that polling will be used
// instead of waiting for the call.
- let nofd = match msg.value & 0x100u64 {
- 0x100u64 => true,
- _ => false,
- };
+ let nofd = (msg.value & 0x100u64) == 0x100u64;
let mut rfd = None;
match rfds {
@@ -519,14 +648,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
}
- fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>(
+ fn extract_request_body<T: Sized + VhostUserMsgValidator>(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
- buf: &'a [u8],
- ) -> Result<&'a T> {
+ buf: &[u8],
+ ) -> Result<T> {
self.check_request_size(hdr, size, mem::size_of::<T>())?;
- let msg = unsafe { &*(buf.as_ptr() as *const T) };
+ let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
@@ -552,7 +681,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
req: &VhostUserMsgHeader<MasterReq>,
payload_size: usize,
) -> Result<VhostUserMsgHeader<MasterReq>> {
- if mem::size_of::<T>() > MAX_MSG_SIZE {
+ if mem::size_of::<T>() > MAX_MSG_SIZE
+ || payload_size > MAX_MSG_SIZE
+ || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
+ {
return Err(Error::InvalidParam);
}
self.check_state()?;
@@ -568,7 +700,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
req: &VhostUserMsgHeader<MasterReq>,
res: Result<()>,
) -> Result<()> {
- if self.reply_ack_enabled {
+ if self.reply_ack_enabled && req.is_need_reply() {
let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
let val = match res {
Ok(_) => 0,
@@ -590,16 +722,12 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
Ok(())
}
- fn send_reply_with_payload<T, P>(
+ fn send_reply_with_payload<T: Sized>(
&mut self,
req: &VhostUserMsgHeader<MasterReq>,
msg: &T,
- payload: &[P],
- ) -> Result<()>
- where
- T: Sized,
- P: Sized,
- {
+ payload: &[u8],
+ ) -> Result<()> {
let hdr = self.new_reply_header::<T>(req, payload.len())?;
self.main_sock
.send_message_with_payload(&hdr, msg, payload, None)?;
@@ -612,3 +740,24 @@ impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
self.main_sock.as_raw_fd()
}
}
+
+#[cfg(test)]
+mod tests {
+ use std::os::unix::io::AsRawFd;
+
+ use super::*;
+ use crate::vhost_user::dummy_slave::DummySlaveReqHandler;
+
+ #[test]
+ fn test_slave_req_handler_new() {
+ let (p1, _p2) = UnixStream::pair().unwrap();
+ let endpoint = Endpoint::<MasterReq>::from_stream(p1);
+ let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
+ let mut handler = SlaveReqHandler::new(endpoint, backend);
+
+ handler.check_state().unwrap();
+ handler.set_failed(libc::EAGAIN);
+ handler.check_state().unwrap_err();
+ assert!(handler.as_raw_fd() >= 0);
+ }
+}
diff --git a/src/vhost_user/sock_ctrl_msg.rs b/src/vhost_user/sock_ctrl_msg.rs
deleted file mode 100644
index db3ec2e..0000000
--- a/src/vhost_user/sock_ctrl_msg.rs
+++ /dev/null
@@ -1,499 +0,0 @@
-// Copyright 2017 The Chromium OS Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style license that can be
-// found in the LICENSE file.
-
-//! Used to send and receive messages with file descriptors on sockets that accept control messages
-//! (e.g. Unix domain sockets).
-
-// TODO: move this file into the vmm-sys-util crate
-
-use std::fs::File;
-use std::mem::size_of;
-use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
-use std::os::unix::net::{UnixDatagram, UnixStream};
-use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
-
-use libc::{
- c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
-};
-use vmm_sys_util::errno::{Error, Result};
-
-// Each of the following macros performs the same function as their C counterparts. They are each
-// macros because they are used to size statically allocated arrays.
-
-macro_rules! CMSG_ALIGN {
- ($len:expr) => {
- (($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1)
- };
-}
-
-macro_rules! CMSG_SPACE {
- ($len:expr) => {
- size_of::<cmsghdr>() + CMSG_ALIGN!($len)
- };
-}
-
-#[cfg(not(target_env = "musl"))]
-macro_rules! CMSG_LEN {
- ($len:expr) => {
- size_of::<cmsghdr>() + ($len)
- };
-}
-
-#[cfg(target_env = "musl")]
-macro_rules! CMSG_LEN {
- ($len:expr) => {{
- let sz = size_of::<cmsghdr>() + ($len);
- assert!(sz <= (std::u32::MAX as usize));
- sz as u32
- }};
-}
-
-#[cfg(not(target_env = "musl"))]
-fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
- msghdr {
- msg_name: null_mut(),
- msg_namelen: 0,
- msg_iov: iovecs.as_mut_ptr(),
- msg_iovlen: iovecs.len(),
- msg_control: null_mut(),
- msg_controllen: 0,
- msg_flags: 0,
- }
-}
-
-#[cfg(target_env = "musl")]
-fn new_msghdr(iovecs: &mut [iovec]) -> msghdr {
- assert!(iovecs.len() <= (std::i32::MAX as usize));
- let mut msg: msghdr = unsafe { std::mem::zeroed() };
- msg.msg_name = null_mut();
- msg.msg_iov = iovecs.as_mut_ptr();
- msg.msg_iovlen = iovecs.len() as i32;
- msg.msg_control = null_mut();
- msg
-}
-
-#[cfg(not(target_env = "musl"))]
-fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
- msg.msg_controllen = cmsg_capacity;
-}
-
-#[cfg(target_env = "musl")]
-fn set_msg_controllen(msg: &mut msghdr, cmsg_capacity: usize) {
- assert!(cmsg_capacity <= (std::u32::MAX as usize));
- msg.msg_controllen = cmsg_capacity as u32;
-}
-
-// This function (macro in the C version) is not used in any compile time constant slots, so is just
-// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
-// module supports.
-#[allow(non_snake_case)]
-#[inline(always)]
-fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
- // Essentially returns a pointer to just past the header.
- cmsg_buffer.wrapping_offset(1) as *mut RawFd
-}
-
-// This function is like CMSG_NEXT, but safer because it reads only from references, although it
-// does some pointer arithmetic on cmsg_ptr.
-#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))]
-fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
- let next_cmsg =
- (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len as usize)) as *mut cmsghdr;
- if next_cmsg
- .wrapping_offset(1)
- .wrapping_sub(msghdr.msg_control as usize) as usize
- > msghdr.msg_controllen as usize
- {
- null_mut()
- } else {
- next_cmsg
- }
-}
-
-const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
-
-enum CmsgBuffer {
- Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
- Heap(Box<[cmsghdr]>),
-}
-
-impl CmsgBuffer {
- fn with_capacity(capacity: usize) -> CmsgBuffer {
- let cap_in_cmsghdr_units =
- (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
- if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
- CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
- } else {
- CmsgBuffer::Heap(
- vec![
- cmsghdr {
- cmsg_len: 0,
- cmsg_level: 0,
- cmsg_type: 0,
- #[cfg(target_env = "musl")]
- __pad1: 0,
- };
- cap_in_cmsghdr_units
- ]
- .into_boxed_slice(),
- )
- }
- }
-
- fn as_mut_ptr(&mut self) -> *mut cmsghdr {
- match self {
- CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
- CmsgBuffer::Heap(a) => a.as_mut_ptr(),
- }
- }
-}
-
-fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
- let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
- let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
-
- let mut iovecs = Vec::with_capacity(out_data.len());
- for data in out_data {
- iovecs.push(iovec {
- iov_base: data.as_ptr() as *mut c_void,
- iov_len: data.size(),
- });
- }
-
- let mut msg = new_msghdr(&mut iovecs);
-
- if !out_fds.is_empty() {
- let cmsg = cmsghdr {
- cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
- cmsg_level: SOL_SOCKET,
- cmsg_type: SCM_RIGHTS,
- #[cfg(target_env = "musl")]
- __pad1: 0,
- };
- unsafe {
- // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
- write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
- // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
- // file descriptors.
- copy_nonoverlapping(
- out_fds.as_ptr(),
- CMSG_DATA(cmsg_buffer.as_mut_ptr()),
- out_fds.len(),
- );
- }
-
- msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
- set_msg_controllen(&mut msg, cmsg_capacity);
- }
-
- // Safe because the msghdr was properly constructed from valid (or null) pointers of the
- // indicated length and we check the return value.
- let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
-
- if write_count == -1 {
- Err(Error::last())
- } else {
- Ok(write_count as usize)
- }
-}
-
-fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
- let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
- let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
- let mut msg = new_msghdr(iovecs);
-
- if !in_fds.is_empty() {
- msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
- set_msg_controllen(&mut msg, cmsg_capacity);
- }
-
- // Safe because the msghdr was properly constructed from valid (or null) pointers of the
- // indicated length and we check the return value.
- let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) };
-
- if total_read == -1 {
- return Err(Error::last());
- }
-
- // When the connection is closed recvmsg() doesn't give an explicit error
- if total_read == 0 && (msg.msg_controllen as usize) < size_of::<cmsghdr>() {
- return Err(Error::new(libc::ECONNRESET));
- }
-
- let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
- let mut in_fds_count = 0;
- while !cmsg_ptr.is_null() {
- // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
- // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
- let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
-
- if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
- let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) as usize / size_of::<RawFd>();
- unsafe {
- copy_nonoverlapping(
- CMSG_DATA(cmsg_ptr),
- in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
- fd_count,
- );
- }
- in_fds_count += fd_count;
- }
-
- cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
- }
-
- Ok((total_read as usize, in_fds_count))
-}
-
-/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
-/// `recvmsg`.
-pub trait ScmSocket {
- /// Gets the file descriptor of this socket.
- fn socket_fd(&self) -> RawFd;
-
- /// Sends the given data and file descriptor over the socket.
- ///
- /// On success, returns the number of bytes sent.
- ///
- /// # Arguments
- ///
- /// * `buf` - A buffer of data to send on the `socket`.
- /// * `fd` - A file descriptors to be sent.
- fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
- self.send_with_fds(&[buf], &[fd])
- }
-
- /// Sends the given data and file descriptors over the socket.
- ///
- /// On success, returns the number of bytes sent.
- ///
- /// # Arguments
- ///
- /// * `bufs` - A list of data buffer to send on the `socket`.
- /// * `fds` - A list of file descriptors to be sent.
- fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
- raw_sendmsg(self.socket_fd(), bufs, fds)
- }
-
- /// Receives data and potentially a file descriptor from the socket.
- ///
- /// On success, returns the number of bytes and an optional file descriptor.
- ///
- /// # Arguments
- ///
- /// * `buf` - A buffer to receive data from the socket.
- fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
- let mut fd = [0];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
-
- let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?;
- let file = if fd_count == 0 {
- None
- } else {
- // Safe because the first fd from recv_with_fds is owned by us and valid because this
- // branch was taken.
- Some(unsafe { File::from_raw_fd(fd[0]) })
- };
- Ok((read_count, file))
- }
-
- /// Receives data and file descriptors from the socket.
- ///
- /// On success, returns the number of bytes and file descriptors received as a tuple
- /// `(bytes count, files count)`.
- ///
- /// # Arguments
- ///
- /// * `iovecs` - A list of iovec to receive data from the socket.
- /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
- /// number of valid file descriptors is indicated by the second element of the
- /// returned tuple. The caller owns these file descriptors, but they will not be
- /// closed on drop like a `File`-like type would be. It is recommended that each valid
- /// file descriptor gets wrapped in a drop type that closes it after this returns.
- fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> {
- raw_recvmsg(self.socket_fd(), iovecs, fds)
- }
-}
-
-impl ScmSocket for UnixDatagram {
- fn socket_fd(&self) -> RawFd {
- self.as_raw_fd()
- }
-}
-
-impl ScmSocket for UnixStream {
- fn socket_fd(&self) -> RawFd {
- self.as_raw_fd()
- }
-}
-
-/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
-/// the lifetime of this object.
-///
-/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size
-/// being accurate.
-pub unsafe trait IntoIovec {
- /// Gets the base pointer of this `iovec`.
- fn as_ptr(&self) -> *const c_void;
-
- /// Gets the size in bytes of this `iovec`.
- fn size(&self) -> usize;
-}
-
-// Safe because this slice can not have another mutable reference and it's pointer and size are
-// guaranteed to be valid.
-unsafe impl<'a> IntoIovec for &'a [u8] {
- // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480
- #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))]
- fn as_ptr(&self) -> *const c_void {
- self.as_ref().as_ptr() as *const c_void
- }
-
- fn size(&self) -> usize {
- self.len()
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use std::io::Write;
- use std::mem::size_of;
- use std::os::raw::c_long;
- use std::os::unix::net::UnixDatagram;
- use std::slice::from_raw_parts;
-
- use libc::cmsghdr;
-
- use vmm_sys_util::eventfd::EventFd;
-
- #[test]
- fn buffer_len() {
- assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>());
- assert_eq!(
- CMSG_SPACE!(1 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>()
- );
- if size_of::<RawFd>() == 4 {
- assert_eq!(
- CMSG_SPACE!(2 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>()
- );
- assert_eq!(
- CMSG_SPACE!(3 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- assert_eq!(
- CMSG_SPACE!(4 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- } else if size_of::<RawFd>() == 8 {
- assert_eq!(
- CMSG_SPACE!(2 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 2
- );
- assert_eq!(
- CMSG_SPACE!(3 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 3
- );
- assert_eq!(
- CMSG_SPACE!(4 * size_of::<RawFd>()),
- size_of::<cmsghdr>() + size_of::<c_long>() * 4
- );
- }
- }
-
- #[test]
- fn send_recv_no_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let write_count = s1
- .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
- .expect("failed to send data");
-
- assert_eq!(write_count, 6);
-
- let mut buf = [0u8; 6];
- let mut files = [0; 1];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
- let (read_count, file_count) = s2
- .recv_with_fds(&mut iovecs[..], &mut files)
- .expect("failed to recv data");
-
- assert_eq!(read_count, 6);
- assert_eq!(file_count, 0);
- assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
- }
-
- #[test]
- fn send_recv_only_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let evt = EventFd::new(0).expect("failed to create eventfd");
- let write_count = s1
- .send_with_fd([].as_ref(), evt.as_raw_fd())
- .expect("failed to send fd");
-
- assert_eq!(write_count, 0);
-
- let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");
-
- let mut file = file_opt.unwrap();
-
- assert_eq!(read_count, 0);
- assert!(file.as_raw_fd() >= 0);
- assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
- assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
- assert_ne!(file.as_raw_fd(), evt.as_raw_fd());
-
- file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
- .expect("failed to write to sent fd");
-
- assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
- }
-
- #[test]
- fn send_recv_with_fd() {
- let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
-
- let evt = EventFd::new(0).expect("failed to create eventfd");
- let write_count = s1
- .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()])
- .expect("failed to send fd");
-
- assert_eq!(write_count, 1);
-
- let mut files = [0; 2];
- let mut buf = [0u8];
- let mut iovecs = [iovec {
- iov_base: buf.as_mut_ptr() as *mut c_void,
- iov_len: buf.len(),
- }];
- let (read_count, file_count) = s2
- .recv_with_fds(&mut iovecs[..], &mut files)
- .expect("failed to recv fd");
-
- assert_eq!(read_count, 1);
- assert_eq!(buf[0], 237);
- assert_eq!(file_count, 1);
- assert!(files[0] >= 0);
- assert_ne!(files[0], s1.as_raw_fd());
- assert_ne!(files[0], s2.as_raw_fd());
- assert_ne!(files[0], evt.as_raw_fd());
-
- let mut file = unsafe { File::from_raw_fd(files[0]) };
-
- file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
- .expect("failed to write to sent fd");
-
- assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
- }
-}
diff --git a/src/vsock.rs b/src/vsock.rs
index 4fb75f5..1e1b0b9 100644
--- a/src/vsock.rs
+++ b/src/vsock.rs
@@ -20,11 +20,11 @@ pub trait VhostVsock: VhostBackend {
///
/// # Arguments
/// * `cid` - CID to assign to the guest
- fn set_guest_cid(&mut self, cid: u64) -> Result<()>;
+ fn set_guest_cid(&self, cid: u64) -> Result<()>;
/// Tell the VHOST driver to start performing data transfer.
- fn start(&mut self) -> Result<()>;
+ fn start(&self) -> Result<()>;
/// Tell the VHOST driver to stop performing data transfer.
- fn stop(&mut self) -> Result<()>;
+ fn stop(&self) -> Result<()>;
}