diff options
Diffstat (limited to 'src/coded_output_stream.rs')
-rw-r--r-- | src/coded_output_stream.rs | 121 |
1 files changed, 86 insertions, 35 deletions
diff --git a/src/coded_output_stream.rs b/src/coded_output_stream.rs index bbfe228..2bbe0a3 100644 --- a/src/coded_output_stream.rs +++ b/src/coded_output_stream.rs @@ -1,17 +1,23 @@ -use crate::misc::remaining_capacity_as_slice_mut; -use crate::misc::remove_lifetime_mut; +use std::io; +use std::io::Write; +use std::mem; +use std::mem::MaybeUninit; +use std::ptr; +use std::slice; + +use crate::misc::maybe_uninit_write; +use crate::misc::maybe_uninit_write_slice; +use crate::misc::vec_spare_capacity_mut; use crate::varint; use crate::wire_format; use crate::zigzag::encode_zig_zag_32; use crate::zigzag::encode_zig_zag_64; use crate::Message; use crate::ProtobufEnum; +use crate::ProtobufError; use crate::ProtobufResult; use crate::UnknownFields; use crate::UnknownValueRef; -use std::io; -use std::io::Write; -use std::mem; /// Equal to the default buffer size of `BufWriter`, so when /// `CodedOutputStream` wraps `BufWriter`, it often skips double buffering. @@ -58,17 +64,27 @@ where Ok(v) } +/// Output buffer/writer for `CodedOutputStream`. enum OutputTarget<'a> { Write(&'a mut dyn Write, Vec<u8>), Vec(&'a mut Vec<u8>), + /// The buffer is passed as `&[u8]` to `CodedOutputStream` constructor + /// and immediately converted to `buffer` field of `CodedOutputStream`, + /// it is not needed to be stored here. + /// Lifetime parameter of `CodedOutputStream` guarantees the buffer is valid + /// during the lifetime of `CodedOutputStream`. Bytes, } /// Buffered write with handy utilities pub struct CodedOutputStream<'a> { target: OutputTarget<'a>, - // alias to buf from target - buffer: &'a mut [u8], + // Actual buffer is owned by `OutputTarget`, + // and here we alias the buffer so access to the buffer is branchless: + // access does not require switch by actual target type: `&[], `Vec`, `Write` etc. + // We don't access the actual buffer in `OutputTarget` except when + // we initialize `buffer` field here. + buffer: *mut [MaybeUninit<u8>], // within buffer position: usize, } @@ -81,15 +97,16 @@ impl<'a> CodedOutputStream<'a> { let buffer_len = OUTPUT_STREAM_BUFFER_SIZE; let mut buffer_storage = Vec::with_capacity(buffer_len); - unsafe { - buffer_storage.set_len(buffer_len); - } - let buffer = unsafe { remove_lifetime_mut(&mut buffer_storage as &mut [u8]) }; + // SAFETY: we are not using the `buffer_storage` + // except for initializing the `buffer` field. + // See `buffer` field documentation. + let buffer = vec_spare_capacity_mut(&mut buffer_storage); + let buffer: *mut [MaybeUninit<u8>] = buffer; CodedOutputStream { target: OutputTarget::Write(writer, buffer_storage), - buffer: buffer, + buffer, position: 0, } } @@ -98,9 +115,12 @@ impl<'a> CodedOutputStream<'a> { /// /// Attempt to write more than bytes capacity results in error. pub fn bytes(bytes: &'a mut [u8]) -> CodedOutputStream<'a> { + // SAFETY: it is safe to cast from &mut [u8] to &mut [MaybeUninit<u8>]. + let buffer = + ptr::slice_from_raw_parts_mut(bytes.as_mut_ptr() as *mut MaybeUninit<u8>, bytes.len()); CodedOutputStream { target: OutputTarget::Bytes, - buffer: bytes, + buffer, position: 0, } } @@ -110,9 +130,10 @@ impl<'a> CodedOutputStream<'a> { /// Caller should call `flush` at the end to guarantee vec contains /// all written data. pub fn vec(vec: &'a mut Vec<u8>) -> CodedOutputStream<'a> { + let buffer: *mut [MaybeUninit<u8>] = &mut []; CodedOutputStream { target: OutputTarget::Vec(vec), - buffer: &mut [], + buffer, position: 0, } } @@ -125,7 +146,7 @@ impl<'a> CodedOutputStream<'a> { pub fn check_eof(&self) { match self.target { OutputTarget::Bytes => { - assert_eq!(self.buffer.len() as u64, self.position as u64); + assert_eq!(self.buffer().len() as u64, self.position as u64); } OutputTarget::Write(..) | OutputTarget::Vec(..) => { panic!("must not be called with Writer or Vec"); @@ -133,10 +154,25 @@ impl<'a> CodedOutputStream<'a> { } } + #[inline(always)] + fn buffer(&self) -> &[MaybeUninit<u8>] { + // SAFETY: see the `buffer` field documentation about invariants. + unsafe { &*(self.buffer as *mut [MaybeUninit<u8>]) } + } + + #[inline(always)] + fn filled_buffer_impl<'s>(buffer: *mut [MaybeUninit<u8>], position: usize) -> &'s [u8] { + // SAFETY: this function is safe assuming `buffer` and `position` + // are `self.buffer` and `safe.position`: + // * `CodedOutputStream` has invariant that `position <= buffer.len()`. + // * `buffer` is filled up to `position`. + unsafe { slice::from_raw_parts_mut(buffer as *mut u8, position) } + } + fn refresh_buffer(&mut self) -> ProtobufResult<()> { match self.target { OutputTarget::Write(ref mut write, _) => { - write.write_all(&self.buffer[0..self.position as usize])?; + write.write_all(Self::filled_buffer_impl(self.buffer, self.position))?; self.position = 0; } OutputTarget::Vec(ref mut vec) => unsafe { @@ -144,11 +180,14 @@ impl<'a> CodedOutputStream<'a> { assert!(vec_len + self.position <= vec.capacity()); vec.set_len(vec_len + self.position); vec.reserve(1); - self.buffer = remove_lifetime_mut(remaining_capacity_as_slice_mut(vec)); + self.buffer = vec_spare_capacity_mut(vec); self.position = 0; }, OutputTarget::Bytes => { - panic!("refresh_buffer must not be called on CodedOutputStream create from slice"); + return Err(ProtobufError::IoError(io::Error::new( + io::ErrorKind::Other, + "given slice is too small to serialize the message", + ))); } } Ok(()) @@ -167,20 +206,22 @@ impl<'a> CodedOutputStream<'a> { /// Write a byte pub fn write_raw_byte(&mut self, byte: u8) -> ProtobufResult<()> { - if self.position as usize == self.buffer.len() { + if self.position as usize == self.buffer().len() { self.refresh_buffer()?; } - self.buffer[self.position as usize] = byte; + unsafe { maybe_uninit_write(&mut (&mut *self.buffer)[self.position as usize], byte) }; self.position += 1; Ok(()) } /// Write bytes pub fn write_raw_bytes(&mut self, bytes: &[u8]) -> ProtobufResult<()> { - if bytes.len() <= self.buffer.len() - self.position { + if bytes.len() <= self.buffer().len() - self.position { let bottom = self.position as usize; let top = bottom + (bytes.len() as usize); - self.buffer[bottom..top].copy_from_slice(bytes); + // SAFETY: see the `buffer` field documentation about invariants. + let buffer = unsafe { &mut (&mut *self.buffer)[bottom..top] }; + maybe_uninit_write_slice(buffer, bytes); self.position += bytes.len(); return Ok(()); } @@ -189,8 +230,11 @@ impl<'a> CodedOutputStream<'a> { assert!(self.position == 0); - if self.position + bytes.len() < self.buffer.len() { - self.buffer[self.position..self.position + bytes.len()].copy_from_slice(bytes); + if self.position + bytes.len() < self.buffer().len() { + // SAFETY: see the `buffer` field documentation about invariants. + let buffer = + unsafe { &mut (&mut *self.buffer)[self.position..self.position + bytes.len()] }; + maybe_uninit_write_slice(buffer, bytes); self.position += bytes.len(); return Ok(()); } @@ -204,9 +248,7 @@ impl<'a> CodedOutputStream<'a> { } OutputTarget::Vec(ref mut vec) => { vec.extend(bytes); - unsafe { - self.buffer = remove_lifetime_mut(remaining_capacity_as_slice_mut(vec)); - } + self.buffer = vec_spare_capacity_mut(vec) } } Ok(()) @@ -223,30 +265,38 @@ impl<'a> CodedOutputStream<'a> { /// Write varint pub fn write_raw_varint32(&mut self, value: u32) -> ProtobufResult<()> { - if self.buffer.len() - self.position >= 5 { + if self.buffer().len() - self.position >= 5 { // fast path - let len = varint::encode_varint32(value, &mut self.buffer[self.position..]); + let len = unsafe { + varint::encode_varint32(value, &mut (&mut *self.buffer)[self.position..]) + }; self.position += len; Ok(()) } else { // slow path let buf = &mut [0u8; 5]; - let len = varint::encode_varint32(value, buf); + let len = varint::encode_varint32(value, unsafe { + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len()) + }); self.write_raw_bytes(&buf[..len]) } } /// Write varint pub fn write_raw_varint64(&mut self, value: u64) -> ProtobufResult<()> { - if self.buffer.len() - self.position >= 10 { + if self.buffer().len() - self.position >= 10 { // fast path - let len = varint::encode_varint64(value, &mut self.buffer[self.position..]); + let len = unsafe { + varint::encode_varint64(value, &mut (&mut *self.buffer)[self.position..]) + }; self.position += len; Ok(()) } else { // slow path let buf = &mut [0u8; 10]; - let len = varint::encode_varint64(value, buf); + let len = varint::encode_varint64(value, unsafe { + slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut MaybeUninit<u8>, buf.len()) + }); self.write_raw_bytes(&buf[..len]) } } @@ -532,13 +582,14 @@ impl<'a> Write for CodedOutputStream<'a> { #[cfg(test)] mod test { + use std::io::Write; + use std::iter; + use crate::coded_output_stream::CodedOutputStream; use crate::hex::decode_hex; use crate::hex::encode_hex; use crate::wire_format; use crate::ProtobufResult; - use std::io::Write; - use std::iter; fn test_write<F>(expected: &str, mut gen: F) where |