diff options
-rw-r--r-- | pw_protobuf/public/pw_protobuf/stream_decoder.h | 160 | ||||
-rw-r--r-- | pw_protobuf/stream_decoder.cc | 148 | ||||
-rw-r--r-- | pw_protobuf/stream_decoder_test.cc | 90 |
3 files changed, 233 insertions, 165 deletions
diff --git a/pw_protobuf/public/pw_protobuf/stream_decoder.h b/pw_protobuf/public/pw_protobuf/stream_decoder.h index 4b8a0ae0e..1e1a4f741 100644 --- a/pw_protobuf/public/pw_protobuf/stream_decoder.h +++ b/pw_protobuf/public/pw_protobuf/stream_decoder.h @@ -168,7 +168,9 @@ class StreamDecoder { // // Reads a proto int32 value from the current position. - Result<int32_t> ReadInt32(); + Result<int32_t> ReadInt32() { + return ReadVarintField<int32_t>(VarintDecodeType::kNormal); + } // Reads repeated int32 values from the current position using packed // encoding. @@ -176,11 +178,15 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedInt32(std::span<int32_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kNormal); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(int32_t), + VarintDecodeType::kNormal); } // Reads a proto uint32 value from the current position. - Result<uint32_t> ReadUint32(); + Result<uint32_t> ReadUint32() { + return ReadVarintField<uint32_t>(VarintDecodeType::kUnsigned); + } // Reads repeated uint32 values from the current position using packed // encoding. @@ -188,11 +194,15 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedUint32(std::span<uint32_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kNormal); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(uint32_t), + VarintDecodeType::kUnsigned); } // Reads a proto int64 value from the current position. - Result<int64_t> ReadInt64(); + Result<int64_t> ReadInt64() { + return ReadVarintField<int64_t>(VarintDecodeType::kNormal); + } // Reads repeated int64 values from the current position using packed // encoding. @@ -200,16 +210,14 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedInt64(std::span<int64_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kNormal); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(int64_t), + VarintDecodeType::kNormal); } // Reads a proto uint64 value from the current position. Result<uint64_t> ReadUint64() { - uint64_t varint; - if (Status status = ReadVarintField(&varint); !status.ok()) { - return status; - } - return varint; + return ReadVarintField<uint64_t>(VarintDecodeType::kUnsigned); } // Reads repeated uint64 values from the current position using packed @@ -218,11 +226,15 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedUint64(std::span<uint64_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kNormal); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(uint64_t), + VarintDecodeType::kUnsigned); } // Reads a proto sint32 value from the current position. - Result<int32_t> ReadSint32(); + Result<int32_t> ReadSint32() { + return ReadVarintField<int32_t>(VarintDecodeType::kZigZag); + } // Reads repeated sint32 values from the current position using packed // encoding. @@ -230,11 +242,15 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedSint32(std::span<int32_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kZigZag); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(int32_t), + VarintDecodeType::kZigZag); } // Reads a proto sint64 value from the current position. - Result<int64_t> ReadSint64(); + Result<int64_t> ReadSint64() { + return ReadVarintField<int64_t>(VarintDecodeType::kZigZag); + } // Reads repeated int64 values from the current position using packed // encoding. @@ -242,11 +258,15 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedSint64(std::span<int64_t> out) { - return ReadPackedVarintField(out, VarintDecodeType::kZigZag); + return ReadPackedVarintField(std::as_writable_bytes(out), + sizeof(int64_t), + VarintDecodeType::kZigZag); } // Reads a proto bool value from the current position. - Result<bool> ReadBool(); + Result<bool> ReadBool() { + return ReadVarintField<bool>(VarintDecodeType::kUnsigned); + } // Reads repeated bool values from the current position using packed // encoding. @@ -254,7 +274,8 @@ class StreamDecoder { // Returns the number of values read. In the case of error, the return value // indicates the number of values successfully read, in addition to the error. StatusWithSize ReadPackedBool(std::span<bool> out) { - return ReadPackedVarintField(out, VarintDecodeType::kNormal); + return ReadPackedVarintField( + std::as_writable_bytes(out), sizeof(bool), VarintDecodeType::kUnsigned); } // Reads a proto fixed32 value from the current position. @@ -280,13 +301,7 @@ class StreamDecoder { } // Reads a proto sfixed32 value from the current position. - Result<int32_t> ReadSfixed32() { - Result<uint32_t> fixed32 = ReadFixed32(); - if (!fixed32.ok()) { - return fixed32.status(); - } - return fixed32.value(); - } + Result<int32_t> ReadSfixed32() { return ReadFixedField<int32_t>(); } // Reads repeated sfixed32 values from the current position using packed // encoding. @@ -297,13 +312,7 @@ class StreamDecoder { } // Reads a proto sfixed64 value from the current position. - Result<int64_t> ReadSfixed64() { - Result<uint64_t> fixed64 = ReadFixed64(); - if (!fixed64.ok()) { - return fixed64.status(); - } - return fixed64.value(); - } + Result<int64_t> ReadSfixed64() { return ReadFixedField<int64_t>(); } // Reads repeated sfixed64 values from the current position using packed // encoding. @@ -451,6 +460,7 @@ class StreamDecoder { friend class BytesReader; enum class VarintDecodeType { + kUnsigned, kNormal, kZigZag, }; @@ -503,87 +513,29 @@ class StreamDecoder { Status ReadFieldKey(); Status SkipField(); - Status ReadVarintField(uint64_t* out); + Status ReadVarintField(std::span<std::byte> out, + VarintDecodeType decode_type); - template <typename T> - typename std::enable_if_t<std::is_signed_v<T>, Status> StoreCheckedValue( - uint64_t varint, VarintDecodeType decode_type, T& value) { - int64_t signed_value = static_cast<int64_t>(varint); - if (decode_type == VarintDecodeType::kZigZag) { - signed_value = varint::ZigZagDecode(varint); - } - - if (signed_value > std::numeric_limits<T>::max() || - signed_value < std::numeric_limits<T>::min()) { - return Status::OutOfRange(); - } - - value = signed_value; - return OkStatus(); - } + StatusWithSize ReadOneVarint(std::span<std::byte> out, + VarintDecodeType decode_type); template <typename T> - std::enable_if_t<std::is_unsigned_v<T>, Status> StoreCheckedValue( - uint64_t varint, VarintDecodeType /*decode_type*/, T& value) { - if (varint > std::numeric_limits<T>::max()) { - return Status::OutOfRange(); - } - - value = varint; - return OkStatus(); - } - - // Reads repeated varint values from the current position using packed - // encoding. - // - // Returns the number of values read. In the case of error, the return value - // indicates the number of values successfully read, in addition to the error. - template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>> - StatusWithSize ReadPackedVarintField(std::span<T> out, - VarintDecodeType decode_type) { + Result<T> ReadVarintField(VarintDecodeType decode_type) { static_assert( std::is_same_v<T, bool> || std::is_same_v<T, uint32_t> || std::is_same_v<T, int32_t> || std::is_same_v<T, uint64_t> || std::is_same_v<T, int64_t>, - "Packed varints must be of type bool, uint32_t, int32_t, uint64_t, " + "Protobuf varints must be of type bool, uint32_t, int32_t, uint64_t, " "or int64_t"); - if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) { - return StatusWithSize(status, 0); - } - - size_t bytes_read = 0; - size_t number_out = 0; - for (T& val : out) { - uint64_t varint; - StatusWithSize sws = varint::Read(reader_, &varint); - if (sws.IsOutOfRange()) { - // Out of range indicates the end of the stream. As a value is expected - // here, report it as a data loss and terminate the decode operation. - return StatusWithSize(Status::DataLoss(), number_out); - } - if (!sws.ok()) { - return StatusWithSize(sws.status(), number_out); - } - - bytes_read += sws.size(); - if (const auto status = StoreCheckedValue(varint, decode_type, val); - !status.ok()) { - return StatusWithSize(status, number_out); - } - ++number_out; - - if (bytes_read == delimited_field_size_) { - break; - } - } - - if (bytes_read < delimited_field_size_) { - return StatusWithSize(Status::ResourceExhausted(), number_out); + T result; + if (Status status = ReadVarintField( + std::as_writable_bytes(std::span(&result, 1)), decode_type); + !status.ok()) { + return status; } - field_consumed_ = true; - return StatusWithSize(OkStatus(), number_out); + return result; } Status ReadFixedField(std::span<std::byte> out); @@ -609,6 +561,10 @@ class StreamDecoder { StatusWithSize ReadPackedFixedField(std::span<std::byte> out, size_t elem_size); + StatusWithSize ReadPackedVarintField(std::span<std::byte> out, + size_t elem_size, + VarintDecodeType decode_type); + Status CheckOkToRead(WireType type); stream::Reader& reader_; diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc index 9990b0862..e6a88aea0 100644 --- a/pw_protobuf/stream_decoder.cc +++ b/pw_protobuf/stream_decoder.cc @@ -16,6 +16,8 @@ #include <algorithm> #include <bit> +#include <cstdint> +#include <cstring> #include <limits> #include "pw_assert/check.h" @@ -114,59 +116,6 @@ Status StreamDecoder::Next() { return status_; } -Result<int32_t> StreamDecoder::ReadInt32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - int64_t signed_value = static_cast<int64_t>(varint); - if (signed_value > std::numeric_limits<int32_t>::max() || - signed_value < std::numeric_limits<int32_t>::min()) { - return Status::OutOfRange(); - } - - return signed_value; -} - -Result<uint32_t> StreamDecoder::ReadUint32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - if (varint > std::numeric_limits<uint32_t>::max()) { - return Status::OutOfRange(); - } - return varint; -} - -Result<int64_t> StreamDecoder::ReadInt64() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint; -} - -Result<int32_t> StreamDecoder::ReadSint32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - int64_t signed_value = varint::ZigZagDecode(varint); - if (signed_value > std::numeric_limits<int32_t>::max() || - signed_value < std::numeric_limits<int32_t>::min()) { - return Status::OutOfRange(); - } - return signed_value; -} - -Result<int64_t> StreamDecoder::ReadSint64() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint::ZigZagDecode(varint); -} - -Result<bool> StreamDecoder::ReadBool() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint; -} - StreamDecoder::BytesReader StreamDecoder::GetBytesReader() { Status status = CheckOkToRead(WireType::kDelimited); @@ -330,24 +279,68 @@ Status StreamDecoder::SkipField() { return OkStatus(); } -Status StreamDecoder::ReadVarintField(uint64_t* out) { +Status StreamDecoder::ReadVarintField(std::span<std::byte> out, + VarintDecodeType decode_type) { + PW_CHECK(out.size() == sizeof(bool) || out.size() == sizeof(uint32_t) || + out.size() == sizeof(uint64_t), + "Protobuf varints must only be used with bool, int32_t, uint32_t, " + "int64_t, or uint64_t"); PW_TRY(CheckOkToRead(WireType::kVarint)); + const StatusWithSize sws = ReadOneVarint(out, decode_type); + if (sws.status() != Status::DataLoss()) + field_consumed_ = true; + return sws.status(); +} + +StatusWithSize StreamDecoder::ReadOneVarint(std::span<std::byte> out, + VarintDecodeType decode_type) { uint64_t value; StatusWithSize sws = varint::Read(reader_, &value); if (sws.IsOutOfRange()) { // Out of range indicates the end of the stream. As a value is expected // here, report it as a data loss and terminate the decode operation. status_ = Status::DataLoss(); - return status_; + return StatusWithSize(status_, sws.size()); + } + if (!sws.ok()) { + return sws; } - PW_TRY(sws); position_ += sws.size(); - field_consumed_ = true; - *out = value; - return OkStatus(); + if (out.size() == sizeof(uint64_t)) { + if (decode_type == VarintDecodeType::kUnsigned) { + std::memcpy(out.data(), &value, out.size()); + } else { + const int64_t signed_value = decode_type == VarintDecodeType::kZigZag + ? varint::ZigZagDecode(value) + : static_cast<int64_t>(value); + std::memcpy(out.data(), &signed_value, out.size()); + } + } else if (out.size() == sizeof(uint32_t)) { + if (decode_type == VarintDecodeType::kUnsigned) { + if (value > std::numeric_limits<uint32_t>::max()) { + return StatusWithSize(Status::OutOfRange(), sws.size()); + } + std::memcpy(out.data(), &value, out.size()); + } else { + const int64_t signed_value = decode_type == VarintDecodeType::kZigZag + ? varint::ZigZagDecode(value) + : static_cast<int64_t>(value); + if (signed_value > std::numeric_limits<int32_t>::max() || + signed_value < std::numeric_limits<int32_t>::min()) { + return StatusWithSize(Status::OutOfRange(), sws.size()); + } + std::memcpy(out.data(), &signed_value, out.size()); + } + } else if (out.size() == sizeof(bool)) { + PW_CHECK(decode_type == VarintDecodeType::kUnsigned, + "Protobuf bool can never be signed"); + std::memcpy(out.data(), &value, out.size()); + } + + return sws; } Status StreamDecoder::ReadFixedField(std::span<std::byte> out) { @@ -435,6 +428,43 @@ StatusWithSize StreamDecoder::ReadPackedFixedField(std::span<std::byte> out, return StatusWithSize(result.value().size() / elem_size); } +StatusWithSize StreamDecoder::ReadPackedVarintField( + std::span<std::byte> out, size_t elem_size, VarintDecodeType decode_type) { + PW_CHECK(elem_size == sizeof(bool) || elem_size == sizeof(uint32_t) || + elem_size == sizeof(uint64_t), + "Protobuf varints must only be used with bool, int32_t, uint32_t, " + "int64_t, or uint64_t"); + + if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) { + return StatusWithSize(status, 0); + } + + if (reader_.ConservativeReadLimit() < delimited_field_size_) { + status_ = Status::DataLoss(); + return StatusWithSize(status_, 0); + } + + size_t bytes_read = 0; + size_t number_out = 0; + while (bytes_read < delimited_field_size_ && !out.empty()) { + const StatusWithSize sws = ReadOneVarint(out.first(elem_size), decode_type); + if (!sws.ok()) { + return StatusWithSize(sws.status(), number_out); + } + + bytes_read += sws.size(); + out = out.subspan(elem_size); + ++number_out; + } + + if (bytes_read < delimited_field_size_) { + return StatusWithSize(Status::ResourceExhausted(), number_out); + } + + field_consumed_ = true; + return StatusWithSize(OkStatus(), number_out); +} + Status StreamDecoder::CheckOkToRead(WireType type) { PW_CHECK(!nested_reader_open_, "Cannot read from a decoder while a nested decoder is open"); diff --git a/pw_protobuf/stream_decoder_test.cc b/pw_protobuf/stream_decoder_test.cc index 302d134b9..db90d3382 100644 --- a/pw_protobuf/stream_decoder_test.cc +++ b/pw_protobuf/stream_decoder_test.cc @@ -61,6 +61,12 @@ TEST(StreamDecoder, Decode) { 0x2d, 0xef, 0xbe, 0xad, 0xde, // type=string, k=6, v="Hello world" 0x32, 0x0b, 'H', 'e', 'l', 'l', 'o', ' ', 'w', 'o', 'r', 'l', 'd', + // type=sfixed32, k=7, v=-50 + 0x3d, 0xce, 0xff, 0xff, 0xff, + // type=sfixed64, k=8, v=-1647993274 + 0x41, 0x46, 0x9e, 0xc5, 0x9d, 0xff, 0xff, 0xff, 0xff, + // type=float, k=9, v=2.718 + 0x4d, 0xb6, 0xf3, 0x2d, 0x40, }; // clang-format on @@ -105,6 +111,24 @@ TEST(StreamDecoder, Decode) { buffer[sws.size()] = '\0'; EXPECT_STREQ(buffer, "Hello world"); + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 7u); + Result<int32_t> sfixed32 = decoder.ReadSfixed32(); + ASSERT_EQ(sfixed32.status(), OkStatus()); + EXPECT_EQ(sfixed32.value(), -50); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 8u); + Result<int64_t> sfixed64 = decoder.ReadSfixed64(); + ASSERT_EQ(sfixed64.status(), OkStatus()); + EXPECT_EQ(sfixed64.value(), -1647993274); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 9u); + Result<float> flt = decoder.ReadFloat(); + ASSERT_EQ(flt.status(), OkStatus()); + EXPECT_EQ(flt.value(), 2.718f); + EXPECT_EQ(decoder.Next(), Status::OutOfRange()); } @@ -917,7 +941,23 @@ TEST(StreamDecoder, PackedFixed) { 0xc8, 0x00, 0x00, 0x00, // type=fixed64[], v=2, v={0x0102030405060708} 0x12, 0x08, - 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01 + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + // type=sfixed32[], k=3, v={0, -50, 100, -150, 200} + 0x1a, 0x14, + 0x00, 0x00, 0x00, 0x00, + 0xce, 0xff, 0xff, 0xff, + 0x64, 0x00, 0x00, 0x00, + 0x6a, 0xff, 0xff, 0xff, + 0xc8, 0x00, 0x00, 0x00, + // type=sfixed64[], v=4, v={-1647993274} + 0x22, 0x08, + 0x46, 0x9e, 0xc5, 0x9d, 0xff, 0xff, 0xff, 0xff, + // type=double[], k=5, v=3.14159 + 0x2a, 0x08, + 0x6e, 0x86, 0x1b, 0xf0, 0xf9, 0x21, 0x09, 0x40, + // type=float[], k=6, v=2.718 + 0x32, 0x04, + 0xb6, 0xf3, 0x2d, 0x40, }; // clang-format on @@ -945,12 +985,54 @@ TEST(StreamDecoder, PackedFixed) { EXPECT_EQ(size.size(), 1u); EXPECT_EQ(fixed64[0], 0x0102030405060708u); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 3u); + std::array<int32_t, 8> sfixed32{}; + size = decoder.ReadPackedSfixed32(sfixed32); + ASSERT_EQ(size.status(), OkStatus()); + EXPECT_EQ(size.size(), 5u); + + EXPECT_EQ(sfixed32[0], 0); + EXPECT_EQ(sfixed32[1], -50); + EXPECT_EQ(sfixed32[2], 100); + EXPECT_EQ(sfixed32[3], -150); + EXPECT_EQ(sfixed32[4], 200); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 4u); + std::array<int64_t, 8> sfixed64{}; + size = decoder.ReadPackedSfixed64(sfixed64); + ASSERT_EQ(size.status(), OkStatus()); + EXPECT_EQ(size.size(), 1u); + + EXPECT_EQ(sfixed64[0], -1647993274); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 5u); + std::array<double, 8> dbl{}; + size = decoder.ReadPackedDouble(dbl); + ASSERT_EQ(size.status(), OkStatus()); + EXPECT_EQ(size.size(), 1u); + + EXPECT_EQ(dbl[0], 3.14159); + + EXPECT_EQ(decoder.Next(), OkStatus()); + ASSERT_EQ(decoder.FieldNumber().value(), 6u); + std::array<float, 8> flt{}; + size = decoder.ReadPackedFloat(flt); + ASSERT_EQ(size.status(), OkStatus()); + EXPECT_EQ(size.size(), 1u); + + EXPECT_EQ(flt[0], 2.718f); + + EXPECT_EQ(decoder.Next(), Status::OutOfRange()); } TEST(StreamDecoder, PackedFixedInsufficientSpace) { // clang-format off constexpr uint8_t encoded_proto[] = { - // type=sfixed32[], k=1, v={0, 50, 100, 150, 200} + // type=fixed32[], k=1, v={0, 50, 100, 150, 200} 0x0a, 0x14, 0x00, 0x00, 0x00, 0x00, 0x32, 0x00, 0x00, 0x00, @@ -965,8 +1047,8 @@ TEST(StreamDecoder, PackedFixedInsufficientSpace) { EXPECT_EQ(decoder.Next(), OkStatus()); ASSERT_EQ(decoder.FieldNumber().value(), 1u); - std::array<uint32_t, 2> sfixed32{}; - StatusWithSize size = decoder.ReadPackedFixed32(sfixed32); + std::array<uint32_t, 2> fixed32{}; + StatusWithSize size = decoder.ReadPackedFixed32(fixed32); ASSERT_EQ(size.status(), Status::ResourceExhausted()); } |