diff options
-rw-r--r-- | src/vhost_user/slave_req_handler.rs | 26 |
1 files changed, 17 insertions, 9 deletions
diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index 190501b..344afff 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -400,6 +400,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } MasterReq::SET_CONFIG => { @@ -413,6 +414,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { return Err(Error::InvalidOperation); } + self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.set_slave_req_fd(&hdr, rfds)?; } _ => { @@ -477,11 +479,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let payload_offset = mem::size_of::<VhostUserConfig>(); + if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } - let payload_offset = mem::size_of::<VhostUserConfig>(); if buf.len() - payload_offset != msg.size as usize { return Err(Error::InvalidMessage); } @@ -517,10 +522,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { size: usize, buf: &[u8], ) -> Result<()> { - if size < mem::size_of::<VhostUserConfig>() { + if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { return Err(Error::InvalidMessage); } - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -562,7 +567,10 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { buf: &[u8], rfds: Option<Vec<RawFd>>, ) -> Result<(u8, Option<RawFd>)> { - let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) }; + if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { + return Err(Error::InvalidMessage); + } + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } @@ -643,14 +651,14 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { } } - fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>( + fn extract_request_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, - buf: &'a [u8], - ) -> Result<&'a T> { + buf: &[u8], + ) -> Result<T> { self.check_request_size(hdr, size, mem::size_of::<T>())?; - let msg = unsafe { &*(buf.as_ptr() as *const T) }; + let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; if !msg.is_valid() { return Err(Error::InvalidMessage); } |