aboutsummaryrefslogtreecommitdiff
path: root/pw_protobuf/stream_decoder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'pw_protobuf/stream_decoder.cc')
-rw-r--r--pw_protobuf/stream_decoder.cc148
1 files changed, 89 insertions, 59 deletions
diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc
index 9990b0862..e6a88aea0 100644
--- a/pw_protobuf/stream_decoder.cc
+++ b/pw_protobuf/stream_decoder.cc
@@ -16,6 +16,8 @@
#include <algorithm>
#include <bit>
+#include <cstdint>
+#include <cstring>
#include <limits>
#include "pw_assert/check.h"
@@ -114,59 +116,6 @@ Status StreamDecoder::Next() {
return status_;
}
-Result<int32_t> StreamDecoder::ReadInt32() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
-
- int64_t signed_value = static_cast<int64_t>(varint);
- if (signed_value > std::numeric_limits<int32_t>::max() ||
- signed_value < std::numeric_limits<int32_t>::min()) {
- return Status::OutOfRange();
- }
-
- return signed_value;
-}
-
-Result<uint32_t> StreamDecoder::ReadUint32() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
-
- if (varint > std::numeric_limits<uint32_t>::max()) {
- return Status::OutOfRange();
- }
- return varint;
-}
-
-Result<int64_t> StreamDecoder::ReadInt64() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
- return varint;
-}
-
-Result<int32_t> StreamDecoder::ReadSint32() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
-
- int64_t signed_value = varint::ZigZagDecode(varint);
- if (signed_value > std::numeric_limits<int32_t>::max() ||
- signed_value < std::numeric_limits<int32_t>::min()) {
- return Status::OutOfRange();
- }
- return signed_value;
-}
-
-Result<int64_t> StreamDecoder::ReadSint64() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
- return varint::ZigZagDecode(varint);
-}
-
-Result<bool> StreamDecoder::ReadBool() {
- uint64_t varint = 0;
- PW_TRY(ReadVarintField(&varint));
- return varint;
-}
-
StreamDecoder::BytesReader StreamDecoder::GetBytesReader() {
Status status = CheckOkToRead(WireType::kDelimited);
@@ -330,24 +279,68 @@ Status StreamDecoder::SkipField() {
return OkStatus();
}
-Status StreamDecoder::ReadVarintField(uint64_t* out) {
+Status StreamDecoder::ReadVarintField(std::span<std::byte> out,
+ VarintDecodeType decode_type) {
+ PW_CHECK(out.size() == sizeof(bool) || out.size() == sizeof(uint32_t) ||
+ out.size() == sizeof(uint64_t),
+ "Protobuf varints must only be used with bool, int32_t, uint32_t, "
+ "int64_t, or uint64_t");
PW_TRY(CheckOkToRead(WireType::kVarint));
+ const StatusWithSize sws = ReadOneVarint(out, decode_type);
+ if (sws.status() != Status::DataLoss())
+ field_consumed_ = true;
+ return sws.status();
+}
+
+StatusWithSize StreamDecoder::ReadOneVarint(std::span<std::byte> out,
+ VarintDecodeType decode_type) {
uint64_t value;
StatusWithSize sws = varint::Read(reader_, &value);
if (sws.IsOutOfRange()) {
// Out of range indicates the end of the stream. As a value is expected
// here, report it as a data loss and terminate the decode operation.
status_ = Status::DataLoss();
- return status_;
+ return StatusWithSize(status_, sws.size());
+ }
+ if (!sws.ok()) {
+ return sws;
}
- PW_TRY(sws);
position_ += sws.size();
- field_consumed_ = true;
- *out = value;
- return OkStatus();
+ if (out.size() == sizeof(uint64_t)) {
+ if (decode_type == VarintDecodeType::kUnsigned) {
+ std::memcpy(out.data(), &value, out.size());
+ } else {
+ const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
+ ? varint::ZigZagDecode(value)
+ : static_cast<int64_t>(value);
+ std::memcpy(out.data(), &signed_value, out.size());
+ }
+ } else if (out.size() == sizeof(uint32_t)) {
+ if (decode_type == VarintDecodeType::kUnsigned) {
+ if (value > std::numeric_limits<uint32_t>::max()) {
+ return StatusWithSize(Status::OutOfRange(), sws.size());
+ }
+ std::memcpy(out.data(), &value, out.size());
+ } else {
+ const int64_t signed_value = decode_type == VarintDecodeType::kZigZag
+ ? varint::ZigZagDecode(value)
+ : static_cast<int64_t>(value);
+ if (signed_value > std::numeric_limits<int32_t>::max() ||
+ signed_value < std::numeric_limits<int32_t>::min()) {
+ return StatusWithSize(Status::OutOfRange(), sws.size());
+ }
+ std::memcpy(out.data(), &signed_value, out.size());
+ }
+ } else if (out.size() == sizeof(bool)) {
+ PW_CHECK(decode_type == VarintDecodeType::kUnsigned,
+ "Protobuf bool can never be signed");
+ std::memcpy(out.data(), &value, out.size());
+ }
+
+ return sws;
}
Status StreamDecoder::ReadFixedField(std::span<std::byte> out) {
@@ -435,6 +428,43 @@ StatusWithSize StreamDecoder::ReadPackedFixedField(std::span<std::byte> out,
return StatusWithSize(result.value().size() / elem_size);
}
+StatusWithSize StreamDecoder::ReadPackedVarintField(
+ std::span<std::byte> out, size_t elem_size, VarintDecodeType decode_type) {
+ PW_CHECK(elem_size == sizeof(bool) || elem_size == sizeof(uint32_t) ||
+ elem_size == sizeof(uint64_t),
+ "Protobuf varints must only be used with bool, int32_t, uint32_t, "
+ "int64_t, or uint64_t");
+
+ if (Status status = CheckOkToRead(WireType::kDelimited); !status.ok()) {
+ return StatusWithSize(status, 0);
+ }
+
+ if (reader_.ConservativeReadLimit() < delimited_field_size_) {
+ status_ = Status::DataLoss();
+ return StatusWithSize(status_, 0);
+ }
+
+ size_t bytes_read = 0;
+ size_t number_out = 0;
+ while (bytes_read < delimited_field_size_ && !out.empty()) {
+ const StatusWithSize sws = ReadOneVarint(out.first(elem_size), decode_type);
+ if (!sws.ok()) {
+ return StatusWithSize(sws.status(), number_out);
+ }
+
+ bytes_read += sws.size();
+ out = out.subspan(elem_size);
+ ++number_out;
+ }
+
+ if (bytes_read < delimited_field_size_) {
+ return StatusWithSize(Status::ResourceExhausted(), number_out);
+ }
+
+ field_consumed_ = true;
+ return StatusWithSize(OkStatus(), number_out);
+}
+
Status StreamDecoder::CheckOkToRead(WireType type) {
PW_CHECK(!nested_reader_open_,
"Cannot read from a decoder while a nested decoder is open");