diff options
Diffstat (limited to 'src/buf_read_iter.rs')
-rw-r--r-- | src/buf_read_iter.rs | 191 |
1 files changed, 99 insertions, 92 deletions
diff --git a/src/buf_read_iter.rs b/src/buf_read_iter.rs index 37bc353..ff1fce2 100644 --- a/src/buf_read_iter.rs +++ b/src/buf_read_iter.rs @@ -3,6 +3,7 @@ use std::io::BufRead; use std::io::BufReader; use std::io::Read; use std::mem; +use std::mem::MaybeUninit; use std::u64; #[cfg(feature = "bytes")] @@ -14,8 +15,11 @@ use bytes::Bytes; #[cfg(feature = "bytes")] use bytes::BytesMut; +use crate::buf_read_or_reader::BufReadOrReader; use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC; use crate::error::WireError; +use crate::misc::maybe_uninit_write_slice; +use crate::misc::vec_spare_capacity_mut; use crate::ProtobufError; use crate::ProtobufResult; @@ -29,8 +33,7 @@ const NO_LIMIT: u64 = u64::MAX; /// Hold all possible combinations of input source enum InputSource<'a> { - BufRead(&'a mut dyn BufRead), - Read(BufReader<&'a mut dyn Read>), + Read(BufReadOrReader<'a>), Slice(&'a [u8]), #[cfg(feature = "bytes")] Bytes(&'a Bytes), @@ -50,7 +53,7 @@ enum InputSource<'a> { /// It is important for `CodedInputStream` performance that small reads /// (e. g. 4 bytes reads) do not involve virtual calls or switches. /// This is achievable with `BufReadIter`. -pub struct BufReadIter<'a> { +pub(crate) struct BufReadIter<'a> { input_source: InputSource<'a>, buf: &'a [u8], pos_within_buf: usize, @@ -62,22 +65,19 @@ pub struct BufReadIter<'a> { impl<'a> Drop for BufReadIter<'a> { fn drop(&mut self) { match self.input_source { - InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf), - InputSource::Read(_) => { - // Nothing to flush, because we own BufReader - } + InputSource::Read(ref mut buf_read) => buf_read.consume(self.pos_within_buf), _ => {} } } } impl<'ignore> BufReadIter<'ignore> { - pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> { + pub(crate) fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> { BufReadIter { - input_source: InputSource::Read(BufReader::with_capacity( + input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity( INPUT_STREAM_BUFFER_SIZE, read, - )), + ))), buf: &[], pos_within_buf: 0, limit_within_buf: 0, @@ -86,9 +86,9 @@ impl<'ignore> BufReadIter<'ignore> { } } - pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { + pub(crate) fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> { BufReadIter { - input_source: InputSource::BufRead(buf_read), + input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)), buf: &[], pos_within_buf: 0, limit_within_buf: 0, @@ -97,7 +97,7 @@ impl<'ignore> BufReadIter<'ignore> { } } - pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> { + pub(crate) fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> { BufReadIter { input_source: InputSource::Slice(bytes), buf: bytes, @@ -109,7 +109,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[cfg(feature = "bytes")] - pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> { + pub(crate) fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> { BufReadIter { input_source: InputSource::Bytes(bytes), buf: &bytes, @@ -128,7 +128,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn pos(&self) -> u64 { + pub(crate) fn pos(&self) -> u64 { self.pos_of_buf_start + self.pos_within_buf as u64 } @@ -144,7 +144,7 @@ impl<'ignore> BufReadIter<'ignore> { self.assertions(); } - pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> { + pub(crate) fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> { let new_limit = match self.pos().checked_add(limit) { Some(new_limit) => new_limit, None => return Err(ProtobufError::WireError(WireError::Other)), @@ -162,7 +162,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline] - pub fn pop_limit(&mut self, limit: u64) { + pub(crate) fn pop_limit(&mut self, limit: u64) { assert!(limit >= self.limit); self.limit = limit; @@ -171,7 +171,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline] - pub fn remaining_in_buf(&self) -> &[u8] { + pub(crate) fn remaining_in_buf(&self) -> &[u8] { if USE_UNSAFE_FOR_SPEED { unsafe { &self @@ -184,12 +184,12 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn remaining_in_buf_len(&self) -> usize { + pub(crate) fn remaining_in_buf_len(&self) -> usize { self.limit_within_buf - self.pos_within_buf } #[inline(always)] - pub fn bytes_until_limit(&self) -> u64 { + pub(crate) fn bytes_until_limit(&self) -> u64 { if self.limit == NO_LIMIT { NO_LIMIT } else { @@ -198,7 +198,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn eof(&mut self) -> ProtobufResult<bool> { + pub(crate) fn eof(&mut self) -> ProtobufResult<bool> { if self.pos_within_buf == self.limit_within_buf { Ok(self.fill_buf()?.is_empty()) } else { @@ -207,7 +207,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn read_byte(&mut self) -> ProtobufResult<u8> { + pub(crate) fn read_byte(&mut self) -> ProtobufResult<u8> { if self.pos_within_buf == self.limit_within_buf { self.do_fill_buf()?; if self.remaining_in_buf_len() == 0 { @@ -239,52 +239,8 @@ impl<'ignore> BufReadIter<'ignore> { Ok(len) } - /// Read exact number of bytes into `Vec`. - /// - /// `Vec` is cleared in the beginning. - pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> { - // TODO: also do some limits when reading from unlimited source - if count as u64 > self.bytes_until_limit() { - return Err(ProtobufError::WireError(WireError::TruncatedMessage)); - } - - target.clear(); - - if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { - // avoid calling `reserve` on buf with very large buffer: could be a malformed message - - target.reserve(READ_RAW_BYTES_MAX_ALLOC); - - while target.len() < count { - let need_to_read = count - target.len(); - if need_to_read <= target.len() { - target.reserve_exact(need_to_read); - } else { - target.reserve(1); - } - - let max = cmp::min(target.capacity() - target.len(), need_to_read); - let read = self.read_to_vec(target, max)?; - if read == 0 { - return Err(ProtobufError::WireError(WireError::TruncatedMessage)); - } - } - } else { - target.reserve_exact(count); - - unsafe { - self.read_exact(&mut target.get_unchecked_mut(..count))?; - target.set_len(count); - } - } - - debug_assert_eq!(count, target.len()); - - Ok(()) - } - #[cfg(feature = "bytes")] - pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> { + pub(crate) fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> { if let InputSource::Bytes(bytes) = self.input_source { let end = match self.pos_within_buf.checked_add(len) { Some(end) => end, @@ -318,13 +274,13 @@ impl<'ignore> BufReadIter<'ignore> { } #[cfg(feature = "bytes")] - unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] { + unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] { use std::slice; - slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len()) + slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len()) } /// Returns 0 when EOF or limit reached. - pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> { + pub(crate) fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> { self.fill_buf()?; let rem = &self.buf[self.pos_within_buf..self.limit_within_buf]; @@ -335,14 +291,7 @@ impl<'ignore> BufReadIter<'ignore> { Ok(len) } - pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> { - if self.remaining_in_buf_len() >= buf.len() { - let buf_len = buf.len(); - buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]); - self.pos_within_buf += buf_len; - return Ok(()); - } - + fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> { if self.bytes_until_limit() < buf.len() as u64 { return Err(ProtobufError::WireError(WireError::UnexpectedEof)); } @@ -356,11 +305,7 @@ impl<'ignore> BufReadIter<'ignore> { match self.input_source { InputSource::Read(ref mut buf_read) => { buf_read.consume(consume); - buf_read.read_exact(buf)?; - } - InputSource::BufRead(ref mut buf_read) => { - buf_read.consume(consume); - buf_read.read_exact(buf)?; + buf_read.read_exact_uninit(buf)?; } _ => { return Err(ProtobufError::WireError(WireError::UnexpectedEof)); @@ -374,6 +319,65 @@ impl<'ignore> BufReadIter<'ignore> { Ok(()) } + #[inline] + pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> { + if self.remaining_in_buf_len() >= buf.len() { + let buf_len = buf.len(); + maybe_uninit_write_slice( + buf, + &self.buf[self.pos_within_buf..self.pos_within_buf + buf_len], + ); + self.pos_within_buf += buf_len; + return Ok(()); + } + + self.read_exact_slow(buf) + } + + /// Read exact number of bytes into `Vec`. + /// + /// `Vec` is cleared in the beginning. + pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> { + // TODO: also do some limits when reading from unlimited source + if count as u64 > self.bytes_until_limit() { + return Err(ProtobufError::WireError(WireError::TruncatedMessage)); + } + + target.clear(); + + if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() { + // avoid calling `reserve` on buf with very large buffer: could be a malformed message + + target.reserve(READ_RAW_BYTES_MAX_ALLOC); + + while target.len() < count { + let need_to_read = count - target.len(); + if need_to_read <= target.len() { + target.reserve_exact(need_to_read); + } else { + target.reserve(1); + } + + let max = cmp::min(target.capacity() - target.len(), need_to_read); + let read = self.read_to_vec(target, max)?; + if read == 0 { + return Err(ProtobufError::WireError(WireError::TruncatedMessage)); + } + } + } else { + target.reserve_exact(count); + + unsafe { + self.read_exact(&mut vec_spare_capacity_mut(target)[..count])?; + target.set_len(count); + } + } + + debug_assert_eq!(count, target.len()); + + Ok(()) + } + fn do_fill_buf(&mut self) -> ProtobufResult<()> { debug_assert!(self.pos_within_buf == self.limit_within_buf); @@ -394,10 +398,6 @@ impl<'ignore> BufReadIter<'ignore> { buf_read.consume(consume); self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; } - InputSource::BufRead(ref mut buf_read) => { - buf_read.consume(consume); - self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) }; - } _ => { return Ok(()); } @@ -409,7 +409,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> { + pub(crate) fn fill_buf(&mut self) -> ProtobufResult<&[u8]> { if self.pos_within_buf == self.limit_within_buf { self.do_fill_buf()?; } @@ -425,7 +425,7 @@ impl<'ignore> BufReadIter<'ignore> { } #[inline(always)] - pub fn consume(&mut self, amt: usize) { + pub(crate) fn consume(&mut self, amt: usize) { assert!(amt <= self.limit_within_buf - self.pos_within_buf); self.pos_within_buf += amt; } @@ -433,9 +433,10 @@ impl<'ignore> BufReadIter<'ignore> { #[cfg(all(test, feature = "bytes"))] mod test_bytes { - use super::*; use std::io::Write; + use super::*; + fn make_long_string(len: usize) -> Vec<u8> { let mut s = Vec::new(); while s.len() < len { @@ -467,11 +468,12 @@ mod test_bytes { #[cfg(test)] mod test { - use super::*; use std::io; use std::io::BufRead; use std::io::Read; + use super::*; + #[test] fn eof_at_limit() { struct Read5ThenPanic { @@ -509,7 +511,12 @@ mod test { let _prev_limit = buf_read_iter.push_limit(5); buf_read_iter.read_byte().expect("read_byte"); buf_read_iter - .read_exact(&mut [1, 2, 3, 4]) + .read_exact(&mut [ + MaybeUninit::uninit(), + MaybeUninit::uninit(), + MaybeUninit::uninit(), + MaybeUninit::uninit(), + ]) .expect("read_exact"); assert!(buf_read_iter.eof().expect("eof")); } |