diff options
Diffstat (limited to 'src/coded_input_stream.rs')
-rw-r--r-- | src/coded_input_stream.rs | 84 |
1 files changed, 28 insertions, 56 deletions
diff --git a/src/coded_input_stream.rs b/src/coded_input_stream.rs index 52a13a6..a49563c 100644 --- a/src/coded_input_stream.rs +++ b/src/coded_input_stream.rs @@ -6,19 +6,21 @@ use std::io; use std::io::BufRead; use std::io::Read; use std::mem; +use std::mem::MaybeUninit; use std::slice; #[cfg(feature = "bytes")] -use crate::chars::Chars; -#[cfg(feature = "bytes")] use bytes::Bytes; use crate::buf_read_iter::BufReadIter; +#[cfg(feature = "bytes")] +use crate::chars::Chars; use crate::enums::ProtobufEnum; use crate::error::ProtobufError; use crate::error::ProtobufResult; use crate::error::WireError; use crate::message::Message; +use crate::misc::maybe_ununit_array_assume_init; use crate::unknown::UnknownValue; use crate::wire_format; use crate::zigzag::decode_zig_zag_32; @@ -105,12 +107,21 @@ impl<'a> CodedInputStream<'a> { } /// Read bytes into given `buf`. + #[inline] + fn read_exact_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> { + self.source.read_exact(buf) + } + + /// Read bytes into given `buf`. /// /// Return `0` on EOF. // TODO: overload with `Read::read` pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<()> { - self.source.read_exact(buf)?; - Ok(()) + // SAFETY: same layout + let buf = unsafe { + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len()) + }; + self.read_exact_uninit(buf) } /// Read exact number of bytes as `Bytes` object. @@ -248,24 +259,20 @@ impl<'a> CodedInputStream<'a> { /// Read little-endian 32-bit integer pub fn read_raw_little_endian32(&mut self) -> ProtobufResult<u32> { - let mut r = 0u32; - let bytes: &mut [u8] = unsafe { - let p: *mut u8 = mem::transmute(&mut r); - slice::from_raw_parts_mut(p, mem::size_of::<u32>()) - }; - self.read(bytes)?; - Ok(r.to_le()) + let mut bytes = [MaybeUninit::uninit(); 4]; + self.read_exact_uninit(&mut bytes)?; + // SAFETY: `read_exact` guarantees that the buffer is filled. + let bytes = unsafe { maybe_ununit_array_assume_init(bytes) }; + Ok(u32::from_le_bytes(bytes)) } /// Read little-endian 64-bit integer pub fn read_raw_little_endian64(&mut self) -> ProtobufResult<u64> { - let mut r = 0u64; - let bytes: &mut [u8] = unsafe { - let p: *mut u8 = mem::transmute(&mut r); - slice::from_raw_parts_mut(p, mem::size_of::<u64>()) - }; - self.read(bytes)?; - Ok(r.to_le()) + let mut bytes = [MaybeUninit::uninit(); 8]; + self.read_exact_uninit(&mut bytes)?; + // SAFETY: `read_exact` guarantees that the buffer is filled. + let bytes = unsafe { maybe_ununit_array_assume_init(bytes) }; + Ok(u64::from_le_bytes(bytes)) } /// Read tag @@ -596,41 +603,7 @@ impl<'a> CodedInputStream<'a> { /// Read raw bytes into the supplied vector. The vector will be resized as needed and /// overwritten. pub fn read_raw_bytes_into(&mut self, count: u32, target: &mut Vec<u8>) -> ProtobufResult<()> { - if false { - // Master uses this version, but keep existing version for a while - // to avoid possible breakages. - return self.source.read_exact_to_vec(count as usize, target); - } - - let count = count as usize; - - // TODO: also do some limits when reading from unlimited source - if count as u64 > self.source.bytes_until_limit() { - return Err(ProtobufError::WireError(WireError::TruncatedMessage)); - } - - unsafe { - target.set_len(0); - } - - if count >= READ_RAW_BYTES_MAX_ALLOC { - // avoid calling `reserve` on buf with very large buffer: could be a malformed message - - let mut take = self.by_ref().take(count as u64); - take.read_to_end(target)?; - - if target.len() != count { - return Err(ProtobufError::WireError(WireError::TruncatedMessage)); - } - } else { - target.reserve(count); - unsafe { - target.set_len(count); - } - - self.source.read_exact(target)?; - } - Ok(()) + self.source.read_exact_to_vec(count as usize, target) } /// Read exact number of bytes @@ -795,13 +768,12 @@ mod test { use std::io::BufRead; use std::io::Read; + use super::CodedInputStream; + use super::READ_RAW_BYTES_MAX_ALLOC; use crate::error::ProtobufError; use crate::error::ProtobufResult; use crate::hex::decode_hex; - use super::CodedInputStream; - use super::READ_RAW_BYTES_MAX_ALLOC; - fn test_read_partial<F>(hex: &str, mut callback: F) where F: FnMut(&mut CodedInputStream), |