aboutsummaryrefslogtreecommitdiff
path: root/src/coded_input_stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/coded_input_stream.rs')
-rw-r--r--src/coded_input_stream.rs84
1 files changed, 28 insertions, 56 deletions
diff --git a/src/coded_input_stream.rs b/src/coded_input_stream.rs
index 52a13a6..a49563c 100644
--- a/src/coded_input_stream.rs
+++ b/src/coded_input_stream.rs
@@ -6,19 +6,21 @@ use std::io;
use std::io::BufRead;
use std::io::Read;
use std::mem;
+use std::mem::MaybeUninit;
use std::slice;
#[cfg(feature = "bytes")]
-use crate::chars::Chars;
-#[cfg(feature = "bytes")]
use bytes::Bytes;
use crate::buf_read_iter::BufReadIter;
+#[cfg(feature = "bytes")]
+use crate::chars::Chars;
use crate::enums::ProtobufEnum;
use crate::error::ProtobufError;
use crate::error::ProtobufResult;
use crate::error::WireError;
use crate::message::Message;
+use crate::misc::maybe_ununit_array_assume_init;
use crate::unknown::UnknownValue;
use crate::wire_format;
use crate::zigzag::decode_zig_zag_32;
@@ -105,12 +107,21 @@ impl<'a> CodedInputStream<'a> {
}
/// Read bytes into given `buf`.
+ #[inline]
+ fn read_exact_uninit(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> {
+ self.source.read_exact(buf)
+ }
+
+ /// Read bytes into given `buf`.
///
/// Return `0` on EOF.
// TODO: overload with `Read::read`
pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<()> {
- self.source.read_exact(buf)?;
- Ok(())
+ // SAFETY: same layout
+ let buf = unsafe {
+ slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len())
+ };
+ self.read_exact_uninit(buf)
}
/// Read exact number of bytes as `Bytes` object.
@@ -248,24 +259,20 @@ impl<'a> CodedInputStream<'a> {
/// Read little-endian 32-bit integer
pub fn read_raw_little_endian32(&mut self) -> ProtobufResult<u32> {
- let mut r = 0u32;
- let bytes: &mut [u8] = unsafe {
- let p: *mut u8 = mem::transmute(&mut r);
- slice::from_raw_parts_mut(p, mem::size_of::<u32>())
- };
- self.read(bytes)?;
- Ok(r.to_le())
+ let mut bytes = [MaybeUninit::uninit(); 4];
+ self.read_exact_uninit(&mut bytes)?;
+ // SAFETY: `read_exact` guarantees that the buffer is filled.
+ let bytes = unsafe { maybe_ununit_array_assume_init(bytes) };
+ Ok(u32::from_le_bytes(bytes))
}
/// Read little-endian 64-bit integer
pub fn read_raw_little_endian64(&mut self) -> ProtobufResult<u64> {
- let mut r = 0u64;
- let bytes: &mut [u8] = unsafe {
- let p: *mut u8 = mem::transmute(&mut r);
- slice::from_raw_parts_mut(p, mem::size_of::<u64>())
- };
- self.read(bytes)?;
- Ok(r.to_le())
+ let mut bytes = [MaybeUninit::uninit(); 8];
+ self.read_exact_uninit(&mut bytes)?;
+ // SAFETY: `read_exact` guarantees that the buffer is filled.
+ let bytes = unsafe { maybe_ununit_array_assume_init(bytes) };
+ Ok(u64::from_le_bytes(bytes))
}
/// Read tag
@@ -596,41 +603,7 @@ impl<'a> CodedInputStream<'a> {
/// Read raw bytes into the supplied vector. The vector will be resized as needed and
/// overwritten.
pub fn read_raw_bytes_into(&mut self, count: u32, target: &mut Vec<u8>) -> ProtobufResult<()> {
- if false {
- // Master uses this version, but keep existing version for a while
- // to avoid possible breakages.
- return self.source.read_exact_to_vec(count as usize, target);
- }
-
- let count = count as usize;
-
- // TODO: also do some limits when reading from unlimited source
- if count as u64 > self.source.bytes_until_limit() {
- return Err(ProtobufError::WireError(WireError::TruncatedMessage));
- }
-
- unsafe {
- target.set_len(0);
- }
-
- if count >= READ_RAW_BYTES_MAX_ALLOC {
- // avoid calling `reserve` on buf with very large buffer: could be a malformed message
-
- let mut take = self.by_ref().take(count as u64);
- take.read_to_end(target)?;
-
- if target.len() != count {
- return Err(ProtobufError::WireError(WireError::TruncatedMessage));
- }
- } else {
- target.reserve(count);
- unsafe {
- target.set_len(count);
- }
-
- self.source.read_exact(target)?;
- }
- Ok(())
+ self.source.read_exact_to_vec(count as usize, target)
}
/// Read exact number of bytes
@@ -795,13 +768,12 @@ mod test {
use std::io::BufRead;
use std::io::Read;
+ use super::CodedInputStream;
+ use super::READ_RAW_BYTES_MAX_ALLOC;
use crate::error::ProtobufError;
use crate::error::ProtobufResult;
use crate::hex::decode_hex;
- use super::CodedInputStream;
- use super::READ_RAW_BYTES_MAX_ALLOC;
-
fn test_read_partial<F>(hex: &str, mut callback: F)
where
F: FnMut(&mut CodedInputStream),