diff options
Diffstat (limited to 'pw_protobuf/stream_decoder.cc')
-rw-r--r-- | pw_protobuf/stream_decoder.cc | 148 |
1 files changed, 89 insertions, 59 deletions
diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc index 9990b0862..e6a88aea0 100644 --- a/pw_protobuf/stream_decoder.cc +++ b/pw_protobuf/stream_decoder.cc @@ -16,6 +16,8 @@ #include <algorithm> #include <bit> +#include <cstdint> +#include <cstring> #include <limits> #include "pw_assert/check.h" @@ -114,59 +116,6 @@ Status StreamDecoder::Next() { return status_; } -Result<int32_t> StreamDecoder::ReadInt32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - int64_t signed_value = static_cast<int64_t>(varint); - if (signed_value > std::numeric_limits<int32_t>::max() || - signed_value < std::numeric_limits<int32_t>::min()) { - return Status::OutOfRange(); - } - - return signed_value; -} - -Result<uint32_t> StreamDecoder::ReadUint32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - if (varint > std::numeric_limits<uint32_t>::max()) { - return Status::OutOfRange(); - } - return varint; -} - -Result<int64_t> StreamDecoder::ReadInt64() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint; -} - -Result<int32_t> StreamDecoder::ReadSint32() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - - int64_t signed_value = varint::ZigZagDecode(varint); - if (signed_value > std::numeric_limits<int32_t>::max() || - signed_value < std::numeric_limits<int32_t>::min()) { - return Status::OutOfRange(); - } - return signed_value; -} - -Result<int64_t> StreamDecoder::ReadSint64() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint::ZigZagDecode(varint); -} - -Result<bool> StreamDecoder::ReadBool() { - uint64_t varint = 0; - PW_TRY(ReadVarintField(&varint)); - return varint; -} - StreamDecoder::BytesReader StreamDecoder::GetBytesReader() { Status status = CheckOkToRead(WireType::kDelimited); @@ -330,24 +279,68 @@ Status StreamDecoder::SkipField() { return OkStatus(); } -Status StreamDecoder::ReadVarintField(uint64_t* out) { +Status StreamDecoder::ReadVarintField(std::span<std::byte> out, + VarintDecodeType decode_type) { + PW_CHECK(out.size() == sizeof(bool) || out.size() == sizeof(uint32_t) || + out.size() == sizeof(uint64_t), + "Protobuf varints must only be used with bool, int32_t, uint32_t, " + "int64_t, or uint64_t"); PW_TRY(CheckOkToRead(WireType::kVarint)); + const StatusWithSize sws = ReadOneVarint(out, decode_type); + if (sws.status() != Status::DataLoss()) + field_consumed_ = true; + return sws.status(); +} + +StatusWithSize StreamDecoder::ReadOneVarint(std::span<std::byte> out, + VarintDecodeType decode_type) { uint64_t value; StatusWithSize sws = varint::Read(reader_, &value); if (sws.IsOutOfRange()) { // Out of range indicates the end of the stream. As a value is expected // here, report it as a data loss and terminate the decode operation. status_ = Status::DataLoss(); - return status_; + return StatusWithSize(status_, sws.size()); + } + if (!sws.ok()) { + return sws; } - PW_TRY(sws); position_ += sws.size(); - field_consumed_ = true; - *out = value; - return OkStatus(); + if (out.size() == sizeof(uint64_t)) { + if (decode_type == VarintDecodeType::kUnsigned) { + std::memcpy(out.data(), &value, out.size()); + } else { + const int64_t signed_value = decode_type == VarintDecodeType::kZigZag + ? varint::ZigZagDecode(value) + : static_cast<int64_t>(value); + std::memcpy(out.data(), &signed_value, out.size()); + } + } else if (out.size() == sizeof(uint32_t)) { + if (decode_type == VarintDecodeType::kUnsigned) { + if (value > std::numeric_limits<uint32_t>::max()) { + return StatusWithSize(Status::OutOfRange(), sws.size()); + } + std::memcpy(out.data(), &value, out.size()); + } else { + const int64_t signed_value = decode_type == VarintDecodeType::kZigZag + ? varint::ZigZagDecode(value) + : static_cast<int64_t>(value); + if (signed_value > std::numeric_limits<int32_t>::max() || + signed_value < std::numeric_limits<int32_t>::min()) { + return StatusWithSize(Status::OutOfRange(), sws.size()); + } + std::memcpy(out.data(), &signed_value, out.size()); + } + } else if (out.size() == sizeof(bool)) { + PW_CHECK(decode_type == VarintDecodeType::kUnsigned, + "Protobuf bool can never be signed"); + std::memcpy(out.data(), &value, out.size()); + } + + return sws; } Status StreamDecoder::ReadFixedField(std::span<std::byte> out) { @@ -435,6 +428,43 @@ StatusWithSize StreamDecoder::ReadPackedFixedField(std::span<std::byte> out, return StatusWithSize(result.value().size() / elem_size); } +StatusWithSize StreamDecoder::ReadPackedVarintField( + std::span<std::byte> out, size_t elem_size, VarintDecodeType decode_type) { + PW_CHECK(elem_size == sizeof(bool) || elem_size == sizeof(uint32_t) || + elem_size == sizeof(uint64_t), + "Protobuf varints must only be used with bool, int32_t, uint32_t, " + "int64_t, or uint64_t"); + + if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) { + return StatusWithSize(status, 0); + } + + if (reader_.ConservativeReadLimit() < delimited_field_size_) { + status_ = Status::DataLoss(); + return StatusWithSize(status_, 0); + } + + size_t bytes_read = 0; + size_t number_out = 0; + while (bytes_read < delimited_field_size_ && !out.empty()) { + const StatusWithSize sws = ReadOneVarint(out.first(elem_size), decode_type); + if (!sws.ok()) { + return StatusWithSize(sws.status(), number_out); + } + + bytes_read += sws.size(); + out = out.subspan(elem_size); + ++number_out; + } + + if (bytes_read < delimited_field_size_) { + return StatusWithSize(Status::ResourceExhausted(), number_out); + } + + field_consumed_ = true; + return StatusWithSize(OkStatus(), number_out); +} + Status StreamDecoder::CheckOkToRead(WireType type) { PW_CHECK(!nested_reader_open_, "Cannot read from a decoder while a nested decoder is open"); |