aboutsummaryrefslogtreecommitdiff
path: root/src/coded_output_stream.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/coded_output_stream.rs')
-rw-r--r--src/coded_output_stream.rs121
1 files changed, 86 insertions, 35 deletions
diff --git a/src/coded_output_stream.rs b/src/coded_output_stream.rs
index bbfe228..2bbe0a3 100644
--- a/src/coded_output_stream.rs
+++ b/src/coded_output_stream.rs
@@ -1,17 +1,23 @@
-use crate::misc::remaining_capacity_as_slice_mut;
-use crate::misc::remove_lifetime_mut;
+use std::io;
+use std::io::Write;
+use std::mem;
+use std::mem::MaybeUninit;
+use std::ptr;
+use std::slice;
+
+use crate::misc::maybe_uninit_write;
+use crate::misc::maybe_uninit_write_slice;
+use crate::misc::vec_spare_capacity_mut;
use crate::varint;
use crate::wire_format;
use crate::zigzag::encode_zig_zag_32;
use crate::zigzag::encode_zig_zag_64;
use crate::Message;
use crate::ProtobufEnum;
+use crate::ProtobufError;
use crate::ProtobufResult;
use crate::UnknownFields;
use crate::UnknownValueRef;
-use std::io;
-use std::io::Write;
-use std::mem;
/// Equal to the default buffer size of `BufWriter`, so when
/// `CodedOutputStream` wraps `BufWriter`, it often skips double buffering.
@@ -58,17 +64,27 @@ where
Ok(v)
}
+/// Output buffer/writer for `CodedOutputStream`.
enum OutputTarget<'a> {
Write(&'a mut dyn Write, Vec<u8>),
Vec(&'a mut Vec<u8>),
+ /// The buffer is passed as `&[u8]` to `CodedOutputStream` constructor
+ /// and immediately converted to `buffer` field of `CodedOutputStream`,
+ /// it is not needed to be stored here.
+ /// Lifetime parameter of `CodedOutputStream` guarantees the buffer is valid
+ /// during the lifetime of `CodedOutputStream`.
Bytes,
}
/// Buffered write with handy utilities
pub struct CodedOutputStream<'a> {
target: OutputTarget<'a>,
- // alias to buf from target
- buffer: &'a mut [u8],
+ // Actual buffer is owned by `OutputTarget`,
+ // and here we alias the buffer so access to the buffer is branchless:
+ // access does not require switch by actual target type: `&[], `Vec`, `Write` etc.
+ // We don't access the actual buffer in `OutputTarget` except when
+ // we initialize `buffer` field here.
+ buffer: *mut [MaybeUninit<u8>],
// within buffer
position: usize,
}
@@ -81,15 +97,16 @@ impl<'a> CodedOutputStream<'a> {
let buffer_len = OUTPUT_STREAM_BUFFER_SIZE;
let mut buffer_storage = Vec::with_capacity(buffer_len);
- unsafe {
- buffer_storage.set_len(buffer_len);
- }
- let buffer = unsafe { remove_lifetime_mut(&mut buffer_storage as &mut [u8]) };
+ // SAFETY: we are not using the `buffer_storage`
+ // except for initializing the `buffer` field.
+ // See `buffer` field documentation.
+ let buffer = vec_spare_capacity_mut(&mut buffer_storage);
+ let buffer: *mut [MaybeUninit<u8>] = buffer;
CodedOutputStream {
target: OutputTarget::Write(writer, buffer_storage),
- buffer: buffer,
+ buffer,
position: 0,
}
}
@@ -98,9 +115,12 @@ impl<'a> CodedOutputStream<'a> {
///
/// Attempt to write more than bytes capacity results in error.
pub fn bytes(bytes: &'a mut [u8]) -> CodedOutputStream<'a> {
+ // SAFETY: it is safe to cast from &mut [u8] to &mut [MaybeUninit<u8>].
+ let buffer =
+ ptr::slice_from_raw_parts_mut(bytes.as_mut_ptr() as *mut MaybeUninit<u8>, bytes.len());
CodedOutputStream {
target: OutputTarget::Bytes,
- buffer: bytes,
+ buffer,
position: 0,
}
}
@@ -110,9 +130,10 @@ impl<'a> CodedOutputStream<'a> {
/// Caller should call `flush` at the end to guarantee vec contains
/// all written data.
pub fn vec(vec: &'a mut Vec<u8>) -> CodedOutputStream<'a> {
+ let buffer: *mut [MaybeUninit<u8>] = &mut [];
CodedOutputStream {
target: OutputTarget::Vec(vec),
- buffer: &mut [],
+ buffer,
position: 0,
}
}
@@ -125,7 +146,7 @@ impl<'a> CodedOutputStream<'a> {
pub fn check_eof(&self) {
match self.target {
OutputTarget::Bytes => {
- assert_eq!(self.buffer.len() as u64, self.position as u64);
+ assert_eq!(self.buffer().len() as u64, self.position as u64);
}
OutputTarget::Write(..) | OutputTarget::Vec(..) => {
panic!("must not be called with Writer or Vec");
@@ -133,10 +154,25 @@ impl<'a> CodedOutputStream<'a> {
}
}
+ #[inline(always)]
+ fn buffer(&self) -> &[MaybeUninit<u8>] {
+ // SAFETY: see the `buffer` field documentation about invariants.
+ unsafe { &*(self.buffer as *mut [MaybeUninit<u8>]) }
+ }
+
+ #[inline(always)]
+ fn filled_buffer_impl<'s>(buffer: *mut [MaybeUninit<u8>], position: usize) -> &'s [u8] {
+ // SAFETY: this function is safe assuming `buffer` and `position`
+ // are `self.buffer` and `safe.position`:
+ // * `CodedOutputStream` has invariant that `position <= buffer.len()`.
+ // * `buffer` is filled up to `position`.
+ unsafe { slice::from_raw_parts_mut(buffer as *mut u8, position) }
+ }
+
fn refresh_buffer(&mut self) -> ProtobufResult<()> {
match self.target {
OutputTarget::Write(ref mut write, _) => {
- write.write_all(&self.buffer[0..self.position as usize])?;
+ write.write_all(Self::filled_buffer_impl(self.buffer, self.position))?;
self.position = 0;
}
OutputTarget::Vec(ref mut vec) => unsafe {
@@ -144,11 +180,14 @@ impl<'a> CodedOutputStream<'a> {
assert!(vec_len + self.position <= vec.capacity());
vec.set_len(vec_len + self.position);
vec.reserve(1);
- self.buffer = remove_lifetime_mut(remaining_capacity_as_slice_mut(vec));
+ self.buffer = vec_spare_capacity_mut(vec);
self.position = 0;
},
OutputTarget::Bytes => {
- panic!("refresh_buffer must not be called on CodedOutputStream create from slice");
+ return Err(ProtobufError::IoError(io::Error::new(
+ io::ErrorKind::Other,
+ "given slice is too small to serialize the message",
+ )));
}
}
Ok(())
@@ -167,20 +206,22 @@ impl<'a> CodedOutputStream<'a> {
/// Write a byte
pub fn write_raw_byte(&mut self, byte: u8) -> ProtobufResult<()> {
- if self.position as usize == self.buffer.len() {
+ if self.position as usize == self.buffer().len() {
self.refresh_buffer()?;
}
- self.buffer[self.position as usize] = byte;
+ unsafe { maybe_uninit_write(&mut (&mut *self.buffer)[self.position as usize], byte) };
self.position += 1;
Ok(())
}
/// Write bytes
pub fn write_raw_bytes(&mut self, bytes: &[u8]) -> ProtobufResult<()> {
- if bytes.len() <= self.buffer.len() - self.position {
+ if bytes.len() <= self.buffer().len() - self.position {
let bottom = self.position as usize;
let top = bottom + (bytes.len() as usize);
- self.buffer[bottom..top].copy_from_slice(bytes);
+ // SAFETY: see the `buffer` field documentation about invariants.
+ let buffer = unsafe { &mut (&mut *self.buffer)[bottom..top] };
+ maybe_uninit_write_slice(buffer, bytes);
self.position += bytes.len();
return Ok(());
}
@@ -189,8 +230,11 @@ impl<'a> CodedOutputStream<'a> {
assert!(self.position == 0);
- if self.position + bytes.len() < self.buffer.len() {
- self.buffer[self.position..self.position + bytes.len()].copy_from_slice(bytes);
+ if self.position + bytes.len() < self.buffer().len() {
+ // SAFETY: see the `buffer` field documentation about invariants.
+ let buffer =
+ unsafe { &mut (&mut *self.buffer)[self.position..self.position + bytes.len()] };
+ maybe_uninit_write_slice(buffer, bytes);
self.position += bytes.len();
return Ok(());
}
@@ -204,9 +248,7 @@ impl<'a> CodedOutputStream<'a> {
}
OutputTarget::Vec(ref mut vec) => {
vec.extend(bytes);
- unsafe {
- self.buffer = remove_lifetime_mut(remaining_capacity_as_slice_mut(vec));
- }
+ self.buffer = vec_spare_capacity_mut(vec)
}
}
Ok(())
@@ -223,30 +265,38 @@ impl<'a> CodedOutputStream<'a> {
/// Write varint
pub fn write_raw_varint32(&mut self, value: u32) -> ProtobufResult<()> {
- if self.buffer.len() - self.position >= 5 {
+ if self.buffer().len() - self.position >= 5 {
// fast path
- let len = varint::encode_varint32(value, &mut self.buffer[self.position..]);
+ let len = unsafe {
+ varint::encode_varint32(value, &mut (&mut *self.buffer)[self.position..])
+ };
self.position += len;
Ok(())
} else {
// slow path
let buf = &mut [0u8; 5];
- let len = varint::encode_varint32(value, buf);
+ let len = varint::encode_varint32(value, unsafe {
+ slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len())
+ });
self.write_raw_bytes(&buf[..len])
}
}
/// Write varint
pub fn write_raw_varint64(&mut self, value: u64) -> ProtobufResult<()> {
- if self.buffer.len() - self.position >= 10 {
+ if self.buffer().len() - self.position >= 10 {
// fast path
- let len = varint::encode_varint64(value, &mut self.buffer[self.position..]);
+ let len = unsafe {
+ varint::encode_varint64(value, &mut (&mut *self.buffer)[self.position..])
+ };
self.position += len;
Ok(())
} else {
// slow path
let buf = &mut [0u8; 10];
- let len = varint::encode_varint64(value, buf);
+ let len = varint::encode_varint64(value, unsafe {
+ slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len())
+ });
self.write_raw_bytes(&buf[..len])
}
}
@@ -532,13 +582,14 @@ impl<'a> Write for CodedOutputStream<'a> {
#[cfg(test)]
mod test {
+ use std::io::Write;
+ use std::iter;
+
use crate::coded_output_stream::CodedOutputStream;
use crate::hex::decode_hex;
use crate::hex::encode_hex;
use crate::wire_format;
use crate::ProtobufResult;
- use std::io::Write;
- use std::iter;
fn test_write<F>(expected: &str, mut gen: F)
where