aboutsummaryrefslogtreecommitdiff
path: root/src/buf_read_iter.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/buf_read_iter.rs')
-rw-r--r--src/buf_read_iter.rs191
1 files changed, 99 insertions, 92 deletions
diff --git a/src/buf_read_iter.rs b/src/buf_read_iter.rs
index 37bc353..ff1fce2 100644
--- a/src/buf_read_iter.rs
+++ b/src/buf_read_iter.rs
@@ -3,6 +3,7 @@ use std::io::BufRead;
use std::io::BufReader;
use std::io::Read;
use std::mem;
+use std::mem::MaybeUninit;
use std::u64;
#[cfg(feature = "bytes")]
@@ -14,8 +15,11 @@ use bytes::Bytes;
#[cfg(feature = "bytes")]
use bytes::BytesMut;
+use crate::buf_read_or_reader::BufReadOrReader;
use crate::coded_input_stream::READ_RAW_BYTES_MAX_ALLOC;
use crate::error::WireError;
+use crate::misc::maybe_uninit_write_slice;
+use crate::misc::vec_spare_capacity_mut;
use crate::ProtobufError;
use crate::ProtobufResult;
@@ -29,8 +33,7 @@ const NO_LIMIT: u64 = u64::MAX;
/// Hold all possible combinations of input source
enum InputSource<'a> {
- BufRead(&'a mut dyn BufRead),
- Read(BufReader<&'a mut dyn Read>),
+ Read(BufReadOrReader<'a>),
Slice(&'a [u8]),
#[cfg(feature = "bytes")]
Bytes(&'a Bytes),
@@ -50,7 +53,7 @@ enum InputSource<'a> {
/// It is important for `CodedInputStream` performance that small reads
/// (e. g. 4 bytes reads) do not involve virtual calls or switches.
/// This is achievable with `BufReadIter`.
-pub struct BufReadIter<'a> {
+pub(crate) struct BufReadIter<'a> {
input_source: InputSource<'a>,
buf: &'a [u8],
pos_within_buf: usize,
@@ -62,22 +65,19 @@ pub struct BufReadIter<'a> {
impl<'a> Drop for BufReadIter<'a> {
fn drop(&mut self) {
match self.input_source {
- InputSource::BufRead(ref mut buf_read) => buf_read.consume(self.pos_within_buf),
- InputSource::Read(_) => {
- // Nothing to flush, because we own BufReader
- }
+ InputSource::Read(ref mut buf_read) => buf_read.consume(self.pos_within_buf),
_ => {}
}
}
}
impl<'ignore> BufReadIter<'ignore> {
- pub fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> {
+ pub(crate) fn from_read<'a>(read: &'a mut dyn Read) -> BufReadIter<'a> {
BufReadIter {
- input_source: InputSource::Read(BufReader::with_capacity(
+ input_source: InputSource::Read(BufReadOrReader::BufReader(BufReader::with_capacity(
INPUT_STREAM_BUFFER_SIZE,
read,
- )),
+ ))),
buf: &[],
pos_within_buf: 0,
limit_within_buf: 0,
@@ -86,9 +86,9 @@ impl<'ignore> BufReadIter<'ignore> {
}
}
- pub fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
+ pub(crate) fn from_buf_read<'a>(buf_read: &'a mut dyn BufRead) -> BufReadIter<'a> {
BufReadIter {
- input_source: InputSource::BufRead(buf_read),
+ input_source: InputSource::Read(BufReadOrReader::BufRead(buf_read)),
buf: &[],
pos_within_buf: 0,
limit_within_buf: 0,
@@ -97,7 +97,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
}
- pub fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> {
+ pub(crate) fn from_byte_slice<'a>(bytes: &'a [u8]) -> BufReadIter<'a> {
BufReadIter {
input_source: InputSource::Slice(bytes),
buf: bytes,
@@ -109,7 +109,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[cfg(feature = "bytes")]
- pub fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> {
+ pub(crate) fn from_bytes<'a>(bytes: &'a Bytes) -> BufReadIter<'a> {
BufReadIter {
input_source: InputSource::Bytes(bytes),
buf: &bytes,
@@ -128,7 +128,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn pos(&self) -> u64 {
+ pub(crate) fn pos(&self) -> u64 {
self.pos_of_buf_start + self.pos_within_buf as u64
}
@@ -144,7 +144,7 @@ impl<'ignore> BufReadIter<'ignore> {
self.assertions();
}
- pub fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> {
+ pub(crate) fn push_limit(&mut self, limit: u64) -> ProtobufResult<u64> {
let new_limit = match self.pos().checked_add(limit) {
Some(new_limit) => new_limit,
None => return Err(ProtobufError::WireError(WireError::Other)),
@@ -162,7 +162,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline]
- pub fn pop_limit(&mut self, limit: u64) {
+ pub(crate) fn pop_limit(&mut self, limit: u64) {
assert!(limit >= self.limit);
self.limit = limit;
@@ -171,7 +171,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline]
- pub fn remaining_in_buf(&self) -> &[u8] {
+ pub(crate) fn remaining_in_buf(&self) -> &[u8] {
if USE_UNSAFE_FOR_SPEED {
unsafe {
&self
@@ -184,12 +184,12 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn remaining_in_buf_len(&self) -> usize {
+ pub(crate) fn remaining_in_buf_len(&self) -> usize {
self.limit_within_buf - self.pos_within_buf
}
#[inline(always)]
- pub fn bytes_until_limit(&self) -> u64 {
+ pub(crate) fn bytes_until_limit(&self) -> u64 {
if self.limit == NO_LIMIT {
NO_LIMIT
} else {
@@ -198,7 +198,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn eof(&mut self) -> ProtobufResult<bool> {
+ pub(crate) fn eof(&mut self) -> ProtobufResult<bool> {
if self.pos_within_buf == self.limit_within_buf {
Ok(self.fill_buf()?.is_empty())
} else {
@@ -207,7 +207,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn read_byte(&mut self) -> ProtobufResult<u8> {
+ pub(crate) fn read_byte(&mut self) -> ProtobufResult<u8> {
if self.pos_within_buf == self.limit_within_buf {
self.do_fill_buf()?;
if self.remaining_in_buf_len() == 0 {
@@ -239,52 +239,8 @@ impl<'ignore> BufReadIter<'ignore> {
Ok(len)
}
- /// Read exact number of bytes into `Vec`.
- ///
- /// `Vec` is cleared in the beginning.
- pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> {
- // TODO: also do some limits when reading from unlimited source
- if count as u64 > self.bytes_until_limit() {
- return Err(ProtobufError::WireError(WireError::TruncatedMessage));
- }
-
- target.clear();
-
- if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
- // avoid calling `reserve` on buf with very large buffer: could be a malformed message
-
- target.reserve(READ_RAW_BYTES_MAX_ALLOC);
-
- while target.len() < count {
- let need_to_read = count - target.len();
- if need_to_read <= target.len() {
- target.reserve_exact(need_to_read);
- } else {
- target.reserve(1);
- }
-
- let max = cmp::min(target.capacity() - target.len(), need_to_read);
- let read = self.read_to_vec(target, max)?;
- if read == 0 {
- return Err(ProtobufError::WireError(WireError::TruncatedMessage));
- }
- }
- } else {
- target.reserve_exact(count);
-
- unsafe {
- self.read_exact(&mut target.get_unchecked_mut(..count))?;
- target.set_len(count);
- }
- }
-
- debug_assert_eq!(count, target.len());
-
- Ok(())
- }
-
#[cfg(feature = "bytes")]
- pub fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> {
+ pub(crate) fn read_exact_bytes(&mut self, len: usize) -> ProtobufResult<Bytes> {
if let InputSource::Bytes(bytes) = self.input_source {
let end = match self.pos_within_buf.checked_add(len) {
Some(end) => end,
@@ -318,13 +274,13 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[cfg(feature = "bytes")]
- unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [u8] {
+ unsafe fn uninit_slice_as_mut_slice(slice: &mut UninitSlice) -> &mut [MaybeUninit<u8>] {
use std::slice;
- slice::from_raw_parts_mut(slice.as_mut_ptr(), slice.len())
+ slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut MaybeUninit<u8>, slice.len())
}
/// Returns 0 when EOF or limit reached.
- pub fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> {
+ pub(crate) fn read(&mut self, buf: &mut [u8]) -> ProtobufResult<usize> {
self.fill_buf()?;
let rem = &self.buf[self.pos_within_buf..self.limit_within_buf];
@@ -335,14 +291,7 @@ impl<'ignore> BufReadIter<'ignore> {
Ok(len)
}
- pub fn read_exact(&mut self, buf: &mut [u8]) -> ProtobufResult<()> {
- if self.remaining_in_buf_len() >= buf.len() {
- let buf_len = buf.len();
- buf.copy_from_slice(&self.buf[self.pos_within_buf..self.pos_within_buf + buf_len]);
- self.pos_within_buf += buf_len;
- return Ok(());
- }
-
+ fn read_exact_slow(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> {
if self.bytes_until_limit() < buf.len() as u64 {
return Err(ProtobufError::WireError(WireError::UnexpectedEof));
}
@@ -356,11 +305,7 @@ impl<'ignore> BufReadIter<'ignore> {
match self.input_source {
InputSource::Read(ref mut buf_read) => {
buf_read.consume(consume);
- buf_read.read_exact(buf)?;
- }
- InputSource::BufRead(ref mut buf_read) => {
- buf_read.consume(consume);
- buf_read.read_exact(buf)?;
+ buf_read.read_exact_uninit(buf)?;
}
_ => {
return Err(ProtobufError::WireError(WireError::UnexpectedEof));
@@ -374,6 +319,65 @@ impl<'ignore> BufReadIter<'ignore> {
Ok(())
}
+ #[inline]
+ pub(crate) fn read_exact(&mut self, buf: &mut [MaybeUninit<u8>]) -> ProtobufResult<()> {
+ if self.remaining_in_buf_len() >= buf.len() {
+ let buf_len = buf.len();
+ maybe_uninit_write_slice(
+ buf,
+ &self.buf[self.pos_within_buf..self.pos_within_buf + buf_len],
+ );
+ self.pos_within_buf += buf_len;
+ return Ok(());
+ }
+
+ self.read_exact_slow(buf)
+ }
+
+ /// Read exact number of bytes into `Vec`.
+ ///
+ /// `Vec` is cleared in the beginning.
+ pub fn read_exact_to_vec(&mut self, count: usize, target: &mut Vec<u8>) -> ProtobufResult<()> {
+ // TODO: also do some limits when reading from unlimited source
+ if count as u64 > self.bytes_until_limit() {
+ return Err(ProtobufError::WireError(WireError::TruncatedMessage));
+ }
+
+ target.clear();
+
+ if count >= READ_RAW_BYTES_MAX_ALLOC && count > target.capacity() {
+ // avoid calling `reserve` on buf with very large buffer: could be a malformed message
+
+ target.reserve(READ_RAW_BYTES_MAX_ALLOC);
+
+ while target.len() < count {
+ let need_to_read = count - target.len();
+ if need_to_read <= target.len() {
+ target.reserve_exact(need_to_read);
+ } else {
+ target.reserve(1);
+ }
+
+ let max = cmp::min(target.capacity() - target.len(), need_to_read);
+ let read = self.read_to_vec(target, max)?;
+ if read == 0 {
+ return Err(ProtobufError::WireError(WireError::TruncatedMessage));
+ }
+ }
+ } else {
+ target.reserve_exact(count);
+
+ unsafe {
+ self.read_exact(&mut vec_spare_capacity_mut(target)[..count])?;
+ target.set_len(count);
+ }
+ }
+
+ debug_assert_eq!(count, target.len());
+
+ Ok(())
+ }
+
fn do_fill_buf(&mut self) -> ProtobufResult<()> {
debug_assert!(self.pos_within_buf == self.limit_within_buf);
@@ -394,10 +398,6 @@ impl<'ignore> BufReadIter<'ignore> {
buf_read.consume(consume);
self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
}
- InputSource::BufRead(ref mut buf_read) => {
- buf_read.consume(consume);
- self.buf = unsafe { mem::transmute(buf_read.fill_buf()?) };
- }
_ => {
return Ok(());
}
@@ -409,7 +409,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn fill_buf(&mut self) -> ProtobufResult<&[u8]> {
+ pub(crate) fn fill_buf(&mut self) -> ProtobufResult<&[u8]> {
if self.pos_within_buf == self.limit_within_buf {
self.do_fill_buf()?;
}
@@ -425,7 +425,7 @@ impl<'ignore> BufReadIter<'ignore> {
}
#[inline(always)]
- pub fn consume(&mut self, amt: usize) {
+ pub(crate) fn consume(&mut self, amt: usize) {
assert!(amt <= self.limit_within_buf - self.pos_within_buf);
self.pos_within_buf += amt;
}
@@ -433,9 +433,10 @@ impl<'ignore> BufReadIter<'ignore> {
#[cfg(all(test, feature = "bytes"))]
mod test_bytes {
- use super::*;
use std::io::Write;
+ use super::*;
+
fn make_long_string(len: usize) -> Vec<u8> {
let mut s = Vec::new();
while s.len() < len {
@@ -467,11 +468,12 @@ mod test_bytes {
#[cfg(test)]
mod test {
- use super::*;
use std::io;
use std::io::BufRead;
use std::io::Read;
+ use super::*;
+
#[test]
fn eof_at_limit() {
struct Read5ThenPanic {
@@ -509,7 +511,12 @@ mod test {
let _prev_limit = buf_read_iter.push_limit(5);
buf_read_iter.read_byte().expect("read_byte");
buf_read_iter
- .read_exact(&mut [1, 2, 3, 4])
+ .read_exact(&mut [
+ MaybeUninit::uninit(),
+ MaybeUninit::uninit(),
+ MaybeUninit::uninit(),
+ MaybeUninit::uninit(),
+ ])
.expect("read_exact");
assert!(buf_read_iter.eof().expect("eof"));
}