aboutsummaryrefslogtreecommitdiff
path: root/src/buf.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/buf.rs')
-rw-r--r--src/buf.rs573
1 files changed, 573 insertions, 0 deletions
diff --git a/src/buf.rs b/src/buf.rs
new file mode 100644
index 0000000..7731fd5
--- /dev/null
+++ b/src/buf.rs
@@ -0,0 +1,573 @@
+// Copyright 2019 TiKV Project Authors. Licensed under Apache-2.0.
+
+use grpcio_sys::*;
+use std::cell::UnsafeCell;
+use std::ffi::{c_void, CStr, CString};
+use std::fmt::{self, Debug, Formatter};
+use std::io::{self, BufRead, Read};
+use std::mem::{self, ManuallyDrop, MaybeUninit};
+
+/// A convenient rust wrapper for the type `grpc_slice`.
+///
+/// It's expected that the slice should be initialized.
+#[repr(C)]
+pub struct GrpcSlice(grpc_slice);
+
+impl GrpcSlice {
+ /// Get the length of the data.
+ pub fn len(&self) -> usize {
+ unsafe {
+ if !self.0.refcount.is_null() {
+ self.0.data.refcounted.length
+ } else {
+ self.0.data.inlined.length as usize
+ }
+ }
+ }
+
+ /// Returns a slice of inner buffer.
+ pub fn as_slice(&self) -> &[u8] {
+ unsafe {
+ if !self.0.refcount.is_null() {
+ let start = self.0.data.refcounted.bytes;
+ let len = self.0.data.refcounted.length;
+ std::slice::from_raw_parts(start, len)
+ } else {
+ let len = self.0.data.inlined.length;
+ &self.0.data.inlined.bytes[..len as usize]
+ }
+ }
+ }
+
+ pub fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+
+ /// Creates a slice from static rust slice.
+ ///
+ /// Same as `From<&[u8]>` but without copying the buffer.
+ #[inline]
+ pub fn from_static_slice(s: &'static [u8]) -> GrpcSlice {
+ GrpcSlice(unsafe { grpc_slice_from_static_buffer(s.as_ptr() as _, s.len()) })
+ }
+
+ /// Creates a `GrpcSlice` from static rust str.
+ ///
+ /// Same as `from_str` but without allocation.
+ #[inline]
+ pub fn from_static_str(s: &'static str) -> GrpcSlice {
+ GrpcSlice::from_static_slice(s.as_bytes())
+ }
+}
+
+impl Clone for GrpcSlice {
+ /// Clone the slice.
+ ///
+ /// If the slice is not inlined, the reference count will be increased
+ /// instead of copy.
+ fn clone(&self) -> Self {
+ GrpcSlice(unsafe { grpc_slice_ref(self.0) })
+ }
+}
+
+impl Default for GrpcSlice {
+ /// Returns a default slice, which is empty.
+ fn default() -> Self {
+ GrpcSlice(unsafe { grpc_empty_slice() })
+ }
+}
+
+impl Debug for GrpcSlice {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ self.as_slice().fmt(f)
+ }
+}
+
+impl Drop for GrpcSlice {
+ fn drop(&mut self) {
+ unsafe {
+ grpc_slice_unref(self.0);
+ }
+ }
+}
+
+impl PartialEq<[u8]> for GrpcSlice {
+ fn eq(&self, r: &[u8]) -> bool {
+ // Technically, the equal function inside vtable should be used.
+ // But it's not cheap or safe to create a grpc_slice from rust slice.
+ self.as_slice() == r
+ }
+}
+
+impl PartialEq<GrpcSlice> for GrpcSlice {
+ fn eq(&self, r: &GrpcSlice) -> bool {
+ unsafe { grpc_slice_eq(self.0, r.0) != 0 }
+ }
+}
+
+unsafe extern "C" fn drop_vec(ptr: *mut c_void, len: usize) {
+ Vec::from_raw_parts(ptr as *mut u8, len, len);
+}
+
+impl From<Vec<u8>> for GrpcSlice {
+ /// Converts a `Vec<u8>` into `GrpcSlice`.
+ ///
+ /// If v can't fit inline, there will be allocations.
+ #[inline]
+ fn from(mut v: Vec<u8>) -> GrpcSlice {
+ if v.is_empty() {
+ return GrpcSlice::default();
+ }
+
+ if v.len() == v.capacity() {
+ let slice = unsafe {
+ grpcio_sys::grpc_slice_new_with_len(v.as_mut_ptr() as _, v.len(), Some(drop_vec))
+ };
+ mem::forget(v);
+ return GrpcSlice(slice);
+ }
+
+ unsafe {
+ GrpcSlice(grpcio_sys::grpc_slice_from_copied_buffer(
+ v.as_mut_ptr() as _,
+ v.len(),
+ ))
+ }
+ }
+}
+
+/// Creates a `GrpcSlice` from rust string.
+///
+/// If the string can't fit inline, there will be allocations.
+impl From<String> for GrpcSlice {
+ #[inline]
+ fn from(s: String) -> GrpcSlice {
+ GrpcSlice::from(s.into_bytes())
+ }
+}
+
+/// Creates a `GrpcSlice` from rust cstring.
+///
+/// If the cstring can't fit inline, there will be allocations.
+impl From<CString> for GrpcSlice {
+ #[inline]
+ fn from(s: CString) -> GrpcSlice {
+ GrpcSlice::from(s.into_bytes())
+ }
+}
+
+/// Creates a `GrpcSlice` from rust slice.
+///
+/// The data inside slice will be cloned. If the data can't fit inline,
+/// necessary buffer will be allocated.
+impl From<&'_ [u8]> for GrpcSlice {
+ #[inline]
+ fn from(s: &'_ [u8]) -> GrpcSlice {
+ GrpcSlice(unsafe { grpc_slice_from_copied_buffer(s.as_ptr() as _, s.len()) })
+ }
+}
+
+/// Creates a `GrpcSlice` from rust str.
+///
+/// The data inside str will be cloned. If the data can't fit inline,
+/// necessary buffer will be allocated.
+impl From<&'_ str> for GrpcSlice {
+ #[inline]
+ fn from(s: &'_ str) -> GrpcSlice {
+ GrpcSlice::from(s.as_bytes())
+ }
+}
+
+/// Creates a `GrpcSlice` from rust `CStr`.
+///
+/// The data inside `CStr` will be cloned. If the data can't fit inline,
+/// necessary buffer will be allocated.
+impl From<&'_ CStr> for GrpcSlice {
+ #[inline]
+ fn from(s: &'_ CStr) -> GrpcSlice {
+ GrpcSlice::from(s.to_bytes())
+ }
+}
+
+/// A collection of `GrpcBytes`.
+#[repr(C)]
+pub struct GrpcByteBuffer(*mut grpc_byte_buffer);
+
+impl GrpcByteBuffer {
+ #[inline]
+ pub unsafe fn from_raw(ptr: *mut grpc_byte_buffer) -> GrpcByteBuffer {
+ GrpcByteBuffer(ptr)
+ }
+}
+
+impl<'a> From<&'a [GrpcSlice]> for GrpcByteBuffer {
+ /// Create a buffer from the given slice array.
+ ///
+ /// A buffer is allocated for the whole slice array, and every slice will
+ /// be `Clone::clone` into the buffer.
+ fn from(slice: &'a [GrpcSlice]) -> Self {
+ let len = slice.len();
+ unsafe {
+ let s = slice.as_ptr() as *const grpc_slice as *const UnsafeCell<grpc_slice>;
+ // hack: see From<&GrpcSlice>.
+ GrpcByteBuffer(grpc_raw_byte_buffer_create((*s).get(), len))
+ }
+ }
+}
+
+impl<'a> From<&'a GrpcSlice> for GrpcByteBuffer {
+ /// Create a buffer from the given single slice.
+ ///
+ /// A buffer, which length is 1, is allocated for the slice.
+ #[allow(clippy::cast_ref_to_mut)]
+ fn from(s: &'a GrpcSlice) -> GrpcByteBuffer {
+ unsafe {
+ // hack: buffer_create accepts an mutable pointer to indicate it mutate
+ // ref count. Ref count is recorded by atomic variable, which is considered
+ // `Sync` in rust. This is an interesting difference in what is *mutable*
+ // between C++ and rust.
+ // Using `UnsafeCell` to avoid raw cast, which is UB.
+ let s = &*(s as *const GrpcSlice as *const grpc_slice as *const UnsafeCell<grpc_slice>);
+ GrpcByteBuffer(grpc_raw_byte_buffer_create((*s).get(), 1))
+ }
+ }
+}
+
+impl Clone for GrpcByteBuffer {
+ fn clone(&self) -> Self {
+ unsafe { GrpcByteBuffer(grpc_byte_buffer_copy(self.0)) }
+ }
+}
+
+impl Drop for GrpcByteBuffer {
+ fn drop(&mut self) {
+ unsafe { grpc_byte_buffer_destroy(self.0) }
+ }
+}
+
+/// A zero-copy reader for the message payload.
+///
+/// To achieve zero-copy, use the BufRead API `fill_buf` and `consume`
+/// to operate the reader.
+#[repr(C)]
+pub struct GrpcByteBufferReader {
+ reader: grpc_byte_buffer_reader,
+ /// Current reading buffer.
+ // This is a temporary buffer that may need to be dropped before every
+ // iteration. So use `ManuallyDrop` to control the behavior more clean
+ // and precisely.
+ slice: ManuallyDrop<GrpcSlice>,
+ /// The offset of `slice` that has not been read.
+ offset: usize,
+ /// How many bytes pending for reading.
+ remain: usize,
+}
+
+impl GrpcByteBufferReader {
+ /// Creates a reader for the `GrpcByteBuffer`.
+ ///
+ /// `buf` is stored inside the reader, and dropped when the reader is dropped.
+ pub fn new(buf: GrpcByteBuffer) -> GrpcByteBufferReader {
+ let mut reader = MaybeUninit::uninit();
+ let mut s = MaybeUninit::uninit();
+ unsafe {
+ let code = grpc_byte_buffer_reader_init(reader.as_mut_ptr(), buf.0);
+ assert_eq!(code, 1);
+ if 0 == grpc_byte_buffer_reader_next(reader.as_mut_ptr(), s.as_mut_ptr()) {
+ s.as_mut_ptr().write(grpc_empty_slice());
+ }
+ let remain = grpc_byte_buffer_length((*reader.as_mut_ptr()).buffer_out);
+ // buf is stored inside `reader` as `buffer_in`, so do not drop it.
+ mem::forget(buf);
+
+ GrpcByteBufferReader {
+ reader: reader.assume_init(),
+ slice: ManuallyDrop::new(GrpcSlice(s.assume_init())),
+ offset: 0,
+ remain,
+ }
+ }
+ }
+
+ /// Get the next slice from reader.
+ fn load_next_slice(&mut self) {
+ unsafe {
+ ManuallyDrop::drop(&mut self.slice);
+ if 0 == grpc_byte_buffer_reader_next(&mut self.reader, &mut self.slice.0) {
+ self.slice = ManuallyDrop::new(GrpcSlice::default());
+ }
+ }
+ self.offset = 0;
+ }
+
+ #[inline]
+ pub fn len(&self) -> usize {
+ self.remain
+ }
+
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.remain == 0
+ }
+}
+
+impl Read for GrpcByteBufferReader {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let read = self.fill_buf()?.read(buf)?;
+ self.consume(read);
+ Ok(read)
+ }
+
+ fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
+ let cap = self.remain;
+ buf.reserve(cap);
+ let old_len = buf.len();
+ while self.remain > 0 {
+ let read = {
+ let s = match self.fill_buf() {
+ Ok(s) => s,
+ Err(e) => {
+ unsafe {
+ buf.set_len(old_len);
+ }
+ return Err(e);
+ }
+ };
+ buf.extend_from_slice(s);
+ s.len()
+ };
+ self.consume(read);
+ }
+ Ok(cap)
+ }
+}
+
+impl BufRead for GrpcByteBufferReader {
+ #[inline]
+ fn fill_buf(&mut self) -> io::Result<&[u8]> {
+ if self.slice.is_empty() {
+ return Ok(&[]);
+ }
+ Ok(unsafe { self.slice.as_slice().get_unchecked(self.offset..) })
+ }
+
+ fn consume(&mut self, mut amt: usize) {
+ if amt > self.remain {
+ amt = self.remain;
+ }
+ self.remain -= amt;
+ let mut offset = self.offset + amt;
+ while offset >= self.slice.len() && offset > 0 {
+ offset -= self.slice.len();
+ self.load_next_slice();
+ }
+ self.offset = offset;
+ }
+}
+
+impl Drop for GrpcByteBufferReader {
+ fn drop(&mut self) {
+ unsafe {
+ grpc_byte_buffer_reader_destroy(&mut self.reader);
+ ManuallyDrop::drop(&mut self.slice);
+ grpc_byte_buffer_destroy(self.reader.buffer_in);
+ }
+ }
+}
+
+unsafe impl Sync for GrpcByteBufferReader {}
+unsafe impl Send for GrpcByteBufferReader {}
+
+#[cfg(feature = "prost-codec")]
+impl bytes::Buf for GrpcByteBufferReader {
+ fn remaining(&self) -> usize {
+ self.remain
+ }
+
+ fn bytes(&self) -> &[u8] {
+ // This is similar but not identical to `BuffRead::fill_buf`, since `self`
+ // is not mutable, we can only return bytes up to the end of the current
+ // slice.
+
+ // Optimization for empty slice
+ if self.slice.is_empty() {
+ return &[];
+ }
+
+ unsafe { self.slice.as_slice().get_unchecked(self.offset..) }
+ }
+
+ fn advance(&mut self, cnt: usize) {
+ self.consume(cnt);
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn new_message_reader(seed: Vec<u8>, copy_count: usize) -> GrpcByteBufferReader {
+ let data = vec![GrpcSlice::from(seed); copy_count];
+ let buf = GrpcByteBuffer::from(data.as_slice());
+ GrpcByteBufferReader::new(buf)
+ }
+
+ #[test]
+ fn test_grpc_slice() {
+ let empty = GrpcSlice::default();
+ assert!(empty.is_empty());
+ assert_eq!(empty.len(), 0);
+ assert!(empty.as_slice().is_empty());
+
+ let a = vec![0, 2, 1, 3, 8];
+ let slice = GrpcSlice::from(a.clone());
+ assert_eq!(a.as_slice(), slice.as_slice());
+ assert_eq!(a.len(), slice.len());
+ assert_eq!(&slice, &*a);
+
+ let a = vec![5; 64];
+ let slice = GrpcSlice::from(a.clone());
+ assert_eq!(a.as_slice(), slice.as_slice());
+ assert_eq!(a.len(), slice.len());
+ assert_eq!(&slice, &*a);
+
+ let a = vec![];
+ let slice = GrpcSlice::from(a);
+ assert_eq!(empty, slice);
+ }
+
+ #[test]
+ // Old code crashes under a very weird circumstance, due to a typo in `MessageReader::consume`
+ fn test_typo_len_offset() {
+ let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
+ // half of the size of `data`
+ let half_size = data.len() / 2;
+ let slice = GrpcSlice::from(data.clone());
+ let buffer = GrpcByteBuffer::from(&slice);
+ let mut reader = GrpcByteBufferReader::new(buffer);
+ assert_eq!(reader.len(), data.len());
+ // first 3 elements of `data`
+ let mut buf = vec![0; half_size];
+ reader.read(buf.as_mut_slice()).unwrap();
+ assert_eq!(data[..half_size], *buf.as_slice());
+ assert_eq!(reader.len(), data.len() - half_size);
+ assert!(!reader.is_empty());
+ reader.read(&mut buf).unwrap();
+ assert_eq!(data[half_size..], *buf.as_slice());
+ assert_eq!(reader.len(), 0);
+ assert!(reader.is_empty());
+ }
+
+ #[test]
+ fn test_message_reader() {
+ for len in 0..=1024 {
+ for n_slice in 1..=4 {
+ let source = vec![len as u8; len];
+ let expect = vec![len as u8; len * n_slice];
+ // Test read.
+ let mut reader = new_message_reader(source.clone(), n_slice);
+ let mut dest = [0; 7];
+ let amt = reader.read(&mut dest).unwrap();
+
+ assert_eq!(
+ dest[..amt],
+ expect[..amt],
+ "len: {}, nslice: {}",
+ len,
+ n_slice
+ );
+
+ // Read after move.
+ let mut box_reader = Box::new(reader);
+ let amt = box_reader.read(&mut dest).unwrap();
+ assert_eq!(
+ dest[..amt],
+ expect[..amt],
+ "len: {}, nslice: {}",
+ len,
+ n_slice
+ );
+
+ // Test read_to_end.
+ let mut reader = new_message_reader(source.clone(), n_slice);
+ let mut dest = vec![];
+ reader.read_to_end(&mut dest).unwrap();
+ assert_eq!(dest, expect, "len: {}, nslice: {}", len, n_slice);
+
+ assert_eq!(0, reader.len());
+ assert_eq!(0, reader.read(&mut [1]).unwrap());
+
+ // Test arbitrary consuming.
+ let mut reader = new_message_reader(source.clone(), n_slice);
+ reader.consume(source.len() * (n_slice - 1));
+ let mut dest = vec![];
+ reader.read_to_end(&mut dest).unwrap();
+ assert_eq!(
+ dest.len(),
+ source.len(),
+ "len: {}, nslice: {}",
+ len,
+ n_slice
+ );
+ assert_eq!(
+ *dest,
+ expect[expect.len() - source.len()..],
+ "len: {}, nslice: {}",
+ len,
+ n_slice
+ );
+ assert_eq!(0, reader.len());
+ assert_eq!(0, reader.read(&mut [1]).unwrap());
+ }
+ }
+ }
+
+ #[test]
+ fn test_converter() {
+ let a = vec![1, 2, 3, 0];
+ assert_eq!(GrpcSlice::from(a.clone()).as_slice(), a.as_slice());
+ assert_eq!(GrpcSlice::from(a.as_slice()).as_slice(), a.as_slice());
+
+ let s = "abcd".to_owned();
+ assert_eq!(GrpcSlice::from(s.clone()).as_slice(), s.as_bytes());
+ assert_eq!(GrpcSlice::from(s.as_str()).as_slice(), s.as_bytes());
+
+ let cs = CString::new(s.clone()).unwrap();
+ assert_eq!(GrpcSlice::from(cs.clone()).as_slice(), s.as_bytes());
+ assert_eq!(GrpcSlice::from(cs.as_c_str()).as_slice(), s.as_bytes());
+ }
+
+ #[cfg(feature = "prost-codec")]
+ #[test]
+ fn test_buf_impl() {
+ use bytes::Buf;
+
+ for len in 0..1024 + 1 {
+ for n_slice in 1..4 {
+ let source = vec![len as u8; len];
+
+ let mut reader = new_message_reader(source.clone(), n_slice);
+
+ let mut remaining = len * n_slice;
+ let mut count = 100;
+ while reader.remaining() > 0 {
+ assert_eq!(remaining, reader.remaining());
+ let bytes = Buf::bytes(&reader);
+ bytes.iter().for_each(|b| assert_eq!(*b, len as u8));
+ let mut read = bytes.len();
+ // We don't have to advance by the whole amount we read.
+ if read > 5 && len % 2 == 0 {
+ read -= 5;
+ }
+ reader.advance(read);
+ remaining -= read;
+ count -= 1;
+ assert!(count > 0);
+ }
+
+ assert_eq!(0, remaining);
+ assert_eq!(0, reader.remaining());
+ }
+ }
+ }
+}