path: root/src/vhost_user/slave_fs_cache.rs
blob: 1804c7a87d3e2773bd809815865d4c0d372ece5d
// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
use std::io;
use std::mem;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex};

struct SlaveFsCacheReqInternal {
    sock: Endpoint<SlaveReq>,
}

/// A vhost-user slave endpoint which sends virtio-fs cache requests (map/unmap)
/// to the master.
#[derive(Clone)]
pub struct SlaveFsCacheReq {
    // underlying Unix domain socket for communication
    node: Arc<Mutex<SlaveFsCacheReqInternal>>,

    // whether the endpoint has encountered any failure
    error: Option<i32>,
}

impl SlaveFsCacheReq {
    fn new(ep: Endpoint<SlaveReq>) -> Self {
        SlaveFsCacheReq {
            node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })),
            error: None,
        }
    }

    /// Create a new instance from a connected `UnixStream`.
    pub fn from_stream(sock: UnixStream) -> Self {
        Self::new(Endpoint::<SlaveReq>::from_stream(sock))
    }

    // Send a slave request to the master and block until the master acks it.
    fn send_message(
        &mut self,
        request: SlaveReq,
        fs: &VhostUserFSSlaveMsg,
        fds: Option<&[RawFd]>,
    ) -> Result<u64> {
        self.check_state()?;

        let len = mem::size_of::<VhostUserFSSlaveMsg>();
        let mut hdr = VhostUserMsgHeader::new(request, 0, len as u32);
        hdr.set_need_reply(true);
        self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?;

        self.wait_for_ack(&hdr)
    }

    // Wait for the master's reply to `hdr` and validate it: it must be a reply
    // for the original request, carry no file descriptors, and report success.
    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<u64> {
        self.check_state()?;
        let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::<VhostUserU64>()?;
        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
            Endpoint::<SlaveReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }
        if body.value != 0 {
            return Err(Error::MasterInternalError);
        }
        Ok(0)
    }

    // Fail fast if the endpoint has already been marked as broken.
    fn check_state(&self) -> Result<u64> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(0),
        }
    }

    /// Mark endpoint as failed with specified error code.
    pub fn set_failed(&mut self, error: i32) {
        self.error = Some(error);
    }
}

impl VhostUserMasterReqHandler for SlaveFsCacheReq {
    /// Handle virtio-fs map file requests from the slave.
    fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<u64> {
        self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
    }

    /// Handle virtio-fs unmap file requests from the slave.
    fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
        self.send_message(SlaveReq::FS_UNMAP, fs, None)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("{}", e)))
    }
}
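
// Illustrative sketch: a minimal unit test exercising the failure-tracking
// path over a socketpair. It relies only on items defined above; the test
// name and the errno value (5, i.e. EIO) are arbitrary choices.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn slave_fs_cache_set_failed() {
        // One half of the pair stands in for the master; no traffic is
        // exchanged because only the error-tracking logic is exercised.
        let (slave_side, _master_side) = UnixStream::pair().unwrap();
        let mut fs_cache = SlaveFsCacheReq::from_stream(slave_side);

        // A fresh endpoint is healthy.
        assert!(fs_cache.check_state().is_ok());

        // Marking it failed makes every subsequent request fail fast with
        // Error::SocketBroken instead of touching the socket.
        fs_cache.set_failed(5);
        assert!(fs_cache.check_state().is_err());
    }
}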