diff options
-rw-r--r-- | pw_protobuf/BUILD.bazel | 21 | ||||
-rw-r--r-- | pw_protobuf/BUILD.gn | 21 | ||||
-rw-r--r-- | pw_protobuf/docs.rst | 141 | ||||
-rw-r--r-- | pw_protobuf/map_utils.cc (renamed from pw_protobuf/helpers.cc) | 6 | ||||
-rw-r--r-- | pw_protobuf/map_utils_test.cc (renamed from pw_protobuf/helpers_test.cc) | 24 | ||||
-rw-r--r-- | pw_protobuf/message.cc | 184 | ||||
-rw-r--r-- | pw_protobuf/message_test.cc | 573 | ||||
-rw-r--r-- | pw_protobuf/public/pw_protobuf/internal/proto_integer_base.h | 39 | ||||
-rw-r--r-- | pw_protobuf/public/pw_protobuf/map_utils.h (renamed from pw_protobuf/public/pw_protobuf/helpers.h) | 4 | ||||
-rw-r--r-- | pw_protobuf/public/pw_protobuf/message.h | 554 | ||||
-rw-r--r-- | pw_protobuf/public/pw_protobuf/stream_decoder.h | 21 | ||||
-rw-r--r-- | pw_protobuf/stream_decoder.cc | 7 | ||||
-rw-r--r-- | pw_protobuf/stream_decoder_test.cc | 32 | ||||
-rw-r--r-- | pw_stream/BUILD.bazel | 16 | ||||
-rw-r--r-- | pw_stream/BUILD.gn | 17 | ||||
-rw-r--r-- | pw_stream/interval_reader.cc | 83 | ||||
-rw-r--r-- | pw_stream/interval_reader_test.cc | 89 | ||||
-rw-r--r-- | pw_stream/public/pw_stream/interval_reader.h | 93 |
18 files changed, 1896 insertions, 29 deletions
diff --git a/pw_protobuf/BUILD.bazel b/pw_protobuf/BUILD.bazel index c343ee44a..5aca132a6 100644 --- a/pw_protobuf/BUILD.bazel +++ b/pw_protobuf/BUILD.bazel @@ -38,14 +38,17 @@ pw_cc_library( "decoder.cc", "encoder.cc", "find.cc", - "helpers.cc", + "map_utils.cc", + "message.cc", "stream_decoder.cc", ], hdrs = [ "public/pw_protobuf/decoder.h", "public/pw_protobuf/encoder.h", "public/pw_protobuf/find.h", - "public/pw_protobuf/helpers.h", + "public/pw_protobuf/internal/proto_integer_base.h", + "public/pw_protobuf/map_utils.h", + "public/pw_protobuf/message.h", "public/pw_protobuf/serialized_size.h", "public/pw_protobuf/stream_decoder.h", "public/pw_protobuf/wire_format.h", @@ -59,6 +62,7 @@ pw_cc_library( "//pw_span", "//pw_status", "//pw_stream", + "//pw_stream:interval_reader", "//pw_varint", ], ) @@ -111,8 +115,17 @@ pw_cc_test( ) pw_cc_test( - name = "helpers_test", - srcs = ["helpers_test.cc"], + name = "map_utils_test", + srcs = ["map_utils_test.cc"], + deps = [ + ":pw_protobuf", + "//pw_unit_test", + ], +) + +pw_cc_test( + name = "message_test", + srcs = ["message_test.cc"], deps = [ ":pw_protobuf", "//pw_unit_test", diff --git a/pw_protobuf/BUILD.gn b/pw_protobuf/BUILD.gn index a35d2a6c3..f5597f0ee 100644 --- a/pw_protobuf/BUILD.gn +++ b/pw_protobuf/BUILD.gn @@ -45,8 +45,10 @@ pw_source_set("pw_protobuf") { public_configs = [ ":public_include_path" ] public_deps = [ ":config", + "$dir_pw_stream:interval_reader", dir_pw_assert, dir_pw_bytes, + dir_pw_log, dir_pw_result, dir_pw_status, dir_pw_stream, @@ -56,7 +58,9 @@ pw_source_set("pw_protobuf") { "public/pw_protobuf/decoder.h", "public/pw_protobuf/encoder.h", "public/pw_protobuf/find.h", - "public/pw_protobuf/helpers.h", + "public/pw_protobuf/internal/proto_integer_base.h", + "public/pw_protobuf/map_utils.h", + "public/pw_protobuf/message.h", "public/pw_protobuf/serialized_size.h", "public/pw_protobuf/stream_decoder.h", "public/pw_protobuf/wire_format.h", @@ -65,7 +69,8 @@ pw_source_set("pw_protobuf") { "decoder.cc", "encoder.cc", "find.cc", - "helpers.cc", + "map_utils.cc", + "message.cc", "stream_decoder.cc", ] } @@ -85,8 +90,9 @@ pw_test_group("tests") { ":encoder_test", ":encoder_fuzzer", ":find_test", + ":map_utils_test", + ":message_test", ":stream_decoder_test", - ":helpers_test", ":varint_size_test", ] } @@ -116,9 +122,14 @@ pw_test("stream_decoder_test") { sources = [ "stream_decoder_test.cc" ] } -pw_test("helpers_test") { +pw_test("map_utils_test") { deps = [ ":pw_protobuf" ] - sources = [ "helpers_test.cc" ] + sources = [ "map_utils_test.cc" ] +} + +pw_test("message_test") { + deps = [ ":pw_protobuf" ] + sources = [ "message_test.cc" ] } config("one_byte_varint") { diff --git a/pw_protobuf/docs.rst b/pw_protobuf/docs.rst index 7e2232aaf..0b3387c20 100644 --- a/pw_protobuf/docs.rst +++ b/pw_protobuf/docs.rst @@ -409,12 +409,147 @@ its parent decoder cannot be used. // parent decoder can be used again. } -Protobuf helpers -================ +Proto map encoding utils +======================== Some additional helpers for encoding more complex but common protobuf submessages (e.g. map<string, bytes>) are provided in -``pw_protobuf/helpers.h``. +``pw_protobuf/map_utils.h``. + +.. Note:: + The helper API are currently in-development and may not remain stable. + +Message +======= + +The module implements a message parsing class ``Message``, in +``pw_protobuf/message.h``, to faciliate proto message parsing and field access. +The class provides interfaces for searching fields in a proto message and +creating helper classes for it according to its interpreted field type, i.e. +uint32, bytes, string, map<>, repeated etc. The class works on top of +``StreamDecoder`` and thus requires a ``pw::stream::SeekableReader`` for proto +message access. The following gives examples for using the class to process +different fields in a proto message: + +.. code-block:: c++ + + // Consider the proto messages defined as follows: + // + // message Nested { + // string nested_str = 1; + // bytes nested_bytes = 2; + // } + // + // message { + // uint32 integer = 1; + // string str = 2; + // bytes bytes = 3; + // Nested nested = 4; + // repeated string rep_str = 5; + // repeated Nested rep_nested = 6; + // map<string, bytes> str_to_bytes = 7; + // map<string, Nested> str_to_nested = 8; + // } + + // Given a seekable `reader` that reads the top-level proto message, and + // a <proto_size> that gives the size of the proto message: + Message message(reader, proto_size); + + // Parse a proto integer field + Uint32 integer = messasge_parser.AsUint32(1); + if (!integer.ok()) { + // handle parsing error. i.e. return integer.status(). + } + uint32_t integer_value = integer.value(); // obtained the value + + // Parse a string field + String str = message.AsString(2); + if (!str.ok()) { + // handle parsing error. i.e. return str.status(); + } + + // check string equal + Result<bool> str_check = str.Equal("foo"); + + // Parse a bytes field + Bytes bytes = message.AsBytes(3); + if (!bytes.ok()) { + // handle parsing error. i.e. return bytes.status(); + } + + // Get a reader to the bytes. + stream::IntervalReader bytes_reader = bytes.GetBytesReader(); + + // Parse nested message `Nested nested = 4;` + Message nested = message.AsMessage(4). + // Get the fields in the nested message. + String nested_str = nested.AsString(1); + Bytes nested_bytes = nested.AsBytes(2); + + // Parse repeated field `repeated string rep_str = 5;` + RepeatedStrings rep_str = message.AsRepeatedString(5); + // Iterate through the entries. For iteration + for (String element : rep_str) { + // Process str + } + + // Parse repeated field `repeated Nested rep_nested = 6;` + RepeatedStrings rep_str = message.AsRepeatedString(6); + // Iterate through the entries. For iteration + for (Message element : rep_rep_nestedstr) { + // Process element + } + + // Parse map field `map<string, bytes> str_to_bytes = 7;` + StringToBytesMap str_to_bytes = message.AsStringToBytesMap(7); + // Access the entry by a given key value + Bytes bytes_for_key = str_to_bytes["key"]; + // Or iterate through map entries + for (StringToBytesMapEntry entry : str_to_bytes) { + String key = entry.Key(); + Bytes value = entry.Value(); + // process entry + } + + // Parse map field `map<string, Nested> str_to_nested = 8;` + StringToMessageMap str_to_nested = message.AsStringToBytesMap(8); + // Access the entry by a given key value + Message nested_for_key = str_to_nested["key"]; + // Or iterate through map entries + for (StringToMessageMapEntry entry : str_to_nested) { + String key = entry.Key(); + Message value = entry.Value(); + // process entry + } + +The methods in ``Message`` for parsing a single field, i.e. everty `AsXXX()` +method except AsRepeatedXXX() and AsStringMapXXX(), internally performs a +linear scan of the entire proto message to find the field with the given +field number. This can be expensive if performed multiple times, especially +on slow reader. The same applies to the ``operator[]`` of StringToXXXXMap +helper class. Therefore, for performance consideration, whenever possible, it +is recommended to use the following for-range style to iterate and process +single fields directly. + + +.. code-block:: c++ + + for (Message::Field field : message) { + if (field.field_number() == 1) { + Uint32 integer = field.As<Uint32>(); + ... + } else if (field.field_number() == 2) { + String str = field.As<String>(); + ... + } else if (field.field_number() == 3) { + Bytes bytes = field.As<Bytes>(); + ... + } else if (field.field_number() == 4) { + Message nested = field.As<Message>(); + ... + } + } + .. Note:: The helper API are currently in-development and may not remain stable. diff --git a/pw_protobuf/helpers.cc b/pw_protobuf/map_utils.cc index 8bf074c14..27b505bcf 100644 --- a/pw_protobuf/helpers.cc +++ b/pw_protobuf/map_utils.cc @@ -12,17 +12,17 @@ // License for the specific language governing permissions and limitations under // the License. -#include "pw_protobuf/helpers.h" +#include "pw_protobuf/map_utils.h" #include <cstddef> +#include "pw_bytes/span.h" #include "pw_protobuf/encoder.h" #include "pw_protobuf/serialized_size.h" +#include "pw_stream/stream.h" namespace pw::protobuf { -// TODO(pwbug/456): Generalize and move this helper to pw_protobuf -// // Note that a map<string, bytes> is essentially // // message Entry { diff --git a/pw_protobuf/helpers_test.cc b/pw_protobuf/map_utils_test.cc index 6d2f8cf7b..d1da109ca 100644 --- a/pw_protobuf/helpers_test.cc +++ b/pw_protobuf/map_utils_test.cc @@ -12,10 +12,15 @@ // License for the specific language governing permissions and limitations under // the License. -#include "pw_protobuf/helpers.h" +#include "pw_protobuf/map_utils.h" + +#include <string_view> #include "gtest/gtest.h" -#include "pw_protobuf/encoder.h" +#include "pw_stream/memory_stream.h" +#include "pw_stream/stream.h" + +#define ASSERT_OK(status) ASSERT_EQ(OkStatus(), status) namespace pw::protobuf { @@ -78,14 +83,13 @@ TEST(ProtoHelper, WriteProtoStringToBytesMapEntry) { for (auto ele : kMapData) { stream::MemoryReader key_reader(std::as_bytes(std::span{ele.key})); stream::MemoryReader value_reader(std::as_bytes(std::span{ele.value})); - ASSERT_TRUE(WriteProtoStringToBytesMapEntry(ele.field_number, - key_reader, - ele.key.size(), - value_reader, - ele.value.size(), - stream_pipe_buffer, - writer) - .ok()); + ASSERT_OK(WriteProtoStringToBytesMapEntry(ele.field_number, + key_reader, + ele.key.size(), + value_reader, + ele.value.size(), + stream_pipe_buffer, + writer)); } ASSERT_EQ(memcmp(dst_buffer, encoded_proto, sizeof(dst_buffer)), 0); diff --git a/pw_protobuf/message.cc b/pw_protobuf/message.cc new file mode 100644 index 000000000..fa4c69f63 --- /dev/null +++ b/pw_protobuf/message.cc @@ -0,0 +1,184 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_protobuf/message.h" + +#include <cstddef> + +#include "pw_protobuf/serialized_size.h" +#include "pw_protobuf/stream_decoder.h" +#include "pw_result/result.h" +#include "pw_status/status_with_size.h" +#include "pw_stream/interval_reader.h" +#include "pw_stream/stream.h" + +namespace pw::protobuf { + +template <> +Uint32 Message::Field::As<Uint32>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadUint32(); +} + +template <> +Int32 Message::Field::As<Int32>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadInt32(); +} + +template <> +Sint32 Message::Field::As<Sint32>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadSint32(); +} + +template <> +Fixed32 Message::Field::As<Fixed32>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadFixed32(); +} + +template <> +Sfixed32 Message::Field::As<Sfixed32>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadSfixed32(); +} + +template <> +Uint64 Message::Field::As<Uint64>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadUint64(); +} + +template <> +Int64 Message::Field::As<Int64>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadInt64(); +} + +template <> +Sint64 Message::Field::As<Sint64>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadSint64(); +} + +template <> +Fixed64 Message::Field::As<Fixed64>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadFixed64(); +} + +template <> +Sfixed64 Message::Field::As<Sfixed64>() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + return decoder.ReadSfixed64(); +} + +Result<bool> Bytes::Equal(ConstByteSpan bytes) { + stream::IntervalReader bytes_reader = GetBytesReader(); + if (bytes_reader.interval_size() != bytes.size()) { + return false; + } + + std::byte buf[1]; + for (size_t i = 0; i < bytes.size();) { + Result<ByteSpan> res = bytes_reader.Read(buf); + PW_TRY(res.status()); + if (res.value().size() == 1) { + if (buf[0] != bytes[i++]) + return false; + } + } + + return true; +} + +Result<bool> String::Equal(std::string_view str) { + return Bytes::Equal(std::as_bytes(std::span{str})); +} + +Message::iterator& Message::iterator::operator++() { + // Store the starting offset of the field. + size_t field_start = reader_.current(); + protobuf::StreamDecoder decoder(reader_); + Status status = decoder.Next(); + if (status.IsOutOfRange()) { + eof_ = true; + return *this; + } + + PW_CHECK(status.ok()); + Result<uint32_t> field_number = decoder.FieldNumber(); + // Consume the field so that the reader will be pointing to the start + // of the next field, which is equivalent to the end offset of the + // current field. + PW_CHECK(ConsumeCurrentField(decoder).ok()); + + // Create a Field object with the field interval. + current_ = Field(stream::IntervalReader( + reader_.source_reader(), field_start, reader_.current()), + field_number.value()); + return *this; +} + +Message::iterator Message::begin() { + PW_CHECK(ok()); + return iterator(reader_.Reset()); +} + +Message::iterator Message::end() { + PW_CHECK(ok()); + // The end iterator is created by using an exahusted stream::IntervalReader, + // i.e. the reader is pointing at the internval end. + stream::IntervalReader reader_end = reader_; + PW_CHECK(reader_end.Seek(0, stream::Stream::Whence::kEnd).ok()); + return iterator(reader_end); +} + +RepeatedBytes Message::AsRepeatedBytes(uint32_t field_number) { + return AsRepeated<Bytes>(field_number); +} + +RepeatedFieldParser<String> Message::AsRepeatedStrings(uint32_t field_number) { + return AsRepeated<String>(field_number); +} + +RepeatedFieldParser<Message> Message::AsRepeatedMessages( + uint32_t field_number) { + return AsRepeated<Message>(field_number); +} + +StringMapParser<Message> Message::AsStringToMessageMap(uint32_t field_number) { + return AsStringMap<Message>(field_number); +} + +StringMapParser<Bytes> Message::AsStringToBytesMap(uint32_t field_number) { + return AsStringMap<Bytes>(field_number); +} + +StringMapParser<String> Message::AsStringToStringMap(uint32_t field_number) { + return AsStringMap<String>(field_number); +} + +} // namespace pw::protobuf diff --git a/pw_protobuf/message_test.cc b/pw_protobuf/message_test.cc new file mode 100644 index 000000000..390e1ed77 --- /dev/null +++ b/pw_protobuf/message_test.cc @@ -0,0 +1,573 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_protobuf/message.h" + +#include "gtest/gtest.h" +#include "pw_stream/memory_stream.h" + +#define ASSERT_OK(status) ASSERT_EQ(OkStatus(), status) + +namespace pw::protobuf { + +TEST(ProtoHelper, IterateMessage) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // type=uint32, k=1, v=1 + 0x08, 0x01, + // type=uint32, k=2, v=2 + 0x10, 0x02, + // type=uint32, k=3, v=3 + 0x18, 0x03, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + uint32_t count = 0; + for (Message::Field field : parser) { + ++count; + EXPECT_EQ(field.field_number(), count); + Uint32 value = field.As<Uint32>(); + ASSERT_OK(value.status()); + EXPECT_EQ(value.value(), count); + } + + EXPECT_EQ(count, static_cast<uint32_t>(3)); +} + +TEST(ProtoHelper, MessageIterator) { + // clang-format off + std::uint8_t encoded_proto[] = { + // key = 1, str = "foo 1" + 0x0a, 0x05, 'f', 'o', 'o', ' ', '1', + // type=uint32, k=2, v=2 + 0x10, 0x02, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + Message::iterator iter = parser.begin(); + + Message::iterator first = iter++; + ASSERT_EQ(first, first); + ASSERT_EQ(first->field_number(), static_cast<uint32_t>(1)); + String str = first->As<String>(); + ASSERT_OK(str.status()); + Result<bool> cmp = str.Equal("foo 1"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + Message::iterator second = iter++; + ASSERT_EQ(second, second); + ASSERT_EQ(second->field_number(), static_cast<uint32_t>(2)); + Uint32 uint32_val = second->As<Uint32>(); + ASSERT_OK(uint32_val.status()); + ASSERT_EQ(uint32_val.value(), static_cast<uint32_t>(2)); + + ASSERT_NE(first, second); + ASSERT_NE(first, iter); + ASSERT_NE(second, iter); + ASSERT_EQ(iter, parser.end()); +} + +TEST(ProtoHelper, AsProtoInteger) { + // clang-format off + std::uint8_t encoded_proto[] = { + // type: int32, k = 1, val = -123 + 0x08, 0x85, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, + // type: uint32, k = 2, val = 123 + 0x10, 0x7b, + // type: sint32, k = 3, val = -456 + 0x18, 0x8f, 0x07, + // type: fixed32, k = 4, val = 268435457 + 0x25, 0x01, 0x00, 0x00, 0x10, + // type: sfixed32, k = 5, val = -268435457 + 0x2d, 0xff, 0xff, 0xff, 0xef, + // type: int64, k = 6, val = -1099511627776 + 0x30, 0x80, 0x80, 0x80, 0x80, 0x80, 0xe0, 0xff, 0xff, 0xff, 0x01, + // type: uint64, k = 7, val = 1099511627776 + 0x38, 0x80, 0x80, 0x80, 0x80, 0x80, 0x20, + // type: sint64, k = 8, val = -2199023255552 + 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + // type: fixed64, k = 9, val = 72057594037927937 + 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + // type: sfixed64, k = 10, val = -72057594037927937 + 0x51, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + { + Int32 value = parser.AsInt32(1); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int32_t>(-123)); + } + + { + Uint32 value = parser.AsUint32(2); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<uint32_t>(123)); + } + + { + Sint32 value = parser.AsSint32(3); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int32_t>(-456)); + } + + { + Fixed32 value = parser.AsFixed32(4); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<uint32_t>(268435457)); + } + + { + Sfixed32 value = parser.AsSfixed32(5); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int32_t>(-268435457)); + } + + { + Int64 value = parser.AsInt64(6); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int64_t>(-1099511627776)); + } + + { + Uint64 value = parser.AsUint64(7); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<uint64_t>(1099511627776)); + } + + { + Sint64 value = parser.AsSint64(8); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int64_t>(-2199023255552)); + } + + { + Fixed64 value = parser.AsFixed64(9); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<uint64_t>(72057594037927937)); + } + + { + Sfixed64 value = parser.AsSfixed64(10); + ASSERT_OK(value.status()); + ASSERT_EQ(value.value(), static_cast<int64_t>(-72057594037927937)); + } +} + +TEST(ProtoHelper, AsString) { + // message { + // string str = 1; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // `str`, k = 1, "string" + 0x0a, 0x06, 's', 't', 'r', 'i', 'n', 'g', + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + constexpr uint32_t kFieldNumber = 1; + String value = parser.AsString(kFieldNumber); + ASSERT_OK(value.status()); + Result<bool> cmp = value.Equal("string"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + cmp = value.Equal("other"); + ASSERT_OK(cmp.status()); + ASSERT_FALSE(cmp.value()); + + // The string is a prefix of the target string to compare. + cmp = value.Equal("string and more"); + ASSERT_OK(cmp.status()); + ASSERT_FALSE(cmp.value()); + + // The target string to compare is a sub prefix of this string + cmp = value.Equal("str"); + ASSERT_OK(cmp.status()); + ASSERT_FALSE(cmp.value()); +} + +TEST(ProtoHelper, AsRepeatedStrings) { + // Repeated field of string i.e. + // + // message RepeatedString { + // repeated string msg_a = 1; + // repeated string msg_b = 2; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // key = 1, str = "foo 1" + 0x0a, 0x05, 'f', 'o', 'o', ' ', '1', + // key = 2, str = "foo 2" + 0x12, 0x05, 'f', 'o', 'o', ' ', '2', + // key = 1, str = "bar 1" + 0x0a, 0x05, 'b', 'a', 'r', ' ', '1', + // key = 2, str = "bar 2" + 0x12, 0x05, 'b', 'a', 'r', ' ', '2', + }; + // clang-format on + + constexpr uint32_t kMsgAFieldNumber = 1; + constexpr uint32_t kMsgBFieldNumber = 2; + constexpr uint32_t kNonExistFieldNumber = 3; + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + // Field 'msg_a' + { + RepeatedStrings msg = parser.AsRepeatedStrings(kMsgAFieldNumber); + std::string_view expected[] = { + "foo 1", + "bar 1", + }; + + size_t count = 0; + for (String ele : msg) { + ASSERT_OK(ele.status()); + Result<bool> res = ele.Equal(expected[count++]); + ASSERT_OK(res.status()); + ASSERT_TRUE(res.value()); + } + + ASSERT_EQ(count, static_cast<size_t>(2)); + } + + // Field `msg_b` + { + RepeatedStrings msg = parser.AsRepeatedStrings(kMsgBFieldNumber); + std::string_view expected[] = { + "foo 2", + "bar 2", + }; + + size_t count = 0; + for (String ele : msg) { + ASSERT_OK(ele.status()); + Result<bool> res = ele.Equal(expected[count++]); + ASSERT_OK(res.status()); + ASSERT_TRUE(res.value()); + } + + ASSERT_EQ(count, static_cast<size_t>(2)); + } + + // non-existing field + { + RepeatedStrings msg = parser.AsRepeatedStrings(kNonExistFieldNumber); + size_t count = 0; + for ([[maybe_unused]] String ele : msg) { + count++; + } + + ASSERT_EQ(count, static_cast<size_t>(0)); + } +} + +TEST(ProtoHelper, RepeatedFieldIterator) { + // Repeated field of string i.e. + // + // message RepeatedString { + // repeated string msg = 1; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // key = 1, str = "foo 1" + 0x0a, 0x05, 'f', 'o', 'o', ' ', '1', + // key = 1, str = "bar 1" + 0x0a, 0x05, 'b', 'a', 'r', ' ', '1', + }; + // clang-format on + + constexpr uint32_t kFieldNumber = 1; + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + RepeatedStrings repeated_str = parser.AsRepeatedStrings(kFieldNumber); + + RepeatedStrings::iterator iter = repeated_str.begin(); + + RepeatedStrings::iterator first = iter++; + ASSERT_EQ(first, first); + Result<bool> cmp = first->Equal("foo 1"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + RepeatedStrings::iterator second = iter++; + ASSERT_EQ(second, second); + cmp = second->Equal("bar 1"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + ASSERT_NE(first, second); + ASSERT_NE(first, iter); + ASSERT_NE(second, iter); + ASSERT_EQ(iter, repeated_str.end()); +} + +TEST(ProtoHelper, AsMessage) { + // A nested message: + // + // message Contact { + // string number = 1; + // string email = 2; + // } + // + // message Person { + // Contact info = 2; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // Person.info.number = "123456", .email = "foo@email.com" + 0x12, 0x17, + 0x0a, 0x06, '1', '2', '3', '4', '5', '6', + 0x12, 0x0d, 'f', 'o', 'o', '@', 'e', 'm', 'a', 'i', 'l', '.', 'c', 'o', 'm', + }; + // clang-format on + + constexpr uint32_t kInfoFieldNumber = 2; + constexpr uint32_t kNumberFieldNumber = 1; + constexpr uint32_t kEmailFieldNumber = 2; + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + Message info = parser.AsMessage(kInfoFieldNumber); + ASSERT_OK(info.status()); + + String number = info.AsString(kNumberFieldNumber); + ASSERT_OK(number.status()); + Result<bool> cmp = number.Equal("123456"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + String email = info.AsString(kEmailFieldNumber); + ASSERT_OK(email.status()); + cmp = email.Equal("foo@email.com"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); +} + +TEST(ProtoHelper, AsRepeatedMessages) { + // message Contact { + // string number = 1; + // string email = 2; + // } + // + // message Person { + // repeated Contact info = 1; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // Person.Contact.number = "12345", .email = "foo@email.com" + 0x0a, 0x16, + 0x0a, 0x05, '1', '2', '3', '4', '5', + 0x12, 0x0d, 'f', 'o', 'o', '@', 'e', 'm', 'a', 'i', 'l', '.', 'c', 'o', 'm', + + // Person.Contact.number = "67890", .email = "bar@email.com" + 0x0a, 0x16, + 0x0a, 0x05, '6', '7', '8', '9', '0', + 0x12, 0x0d, 'b', 'a', 'r', '@', 'e', 'm', 'a', 'i', 'l', '.', 'c', 'o', 'm', + }; + // clang-format on + + constexpr uint32_t kInfoFieldNumber = 1; + constexpr uint32_t kNumberFieldNumber = 1; + constexpr uint32_t kEmailFieldNumber = 2; + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + RepeatedMessages messages = parser.AsRepeatedMessages(kInfoFieldNumber); + ASSERT_OK(messages.status()); + + struct { + std::string_view number; + std::string_view email; + } expected[] = { + {"12345", "foo@email.com"}, + {"67890", "bar@email.com"}, + }; + + size_t count = 0; + for (Message message : messages) { + String number = message.AsString(kNumberFieldNumber); + ASSERT_OK(number.status()); + Result<bool> cmp = number.Equal(expected[count].number); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + String email = message.AsString(kEmailFieldNumber); + ASSERT_OK(email.status()); + cmp = email.Equal(expected[count].email); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + count++; + } + + ASSERT_EQ(count, static_cast<size_t>(2)); +} + +TEST(ProtoHelper, AsStringToBytesMap) { + // message Maps { + // map<string, string> map_a = 1; + // map<string, string> map_b = 2; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // map_a["key_bar"] = "bar_a", key = 1 + 0x0a, 0x10, + 0x0a, 0x07, 'k', 'e', 'y', '_', 'b', 'a', 'r', // map key + 0x12, 0x05, 'b', 'a', 'r', '_', 'a', // map value + + // map_a["key_foo"] = "foo_a", key = 1 + 0x0a, 0x10, + 0x0a, 0x07, 'k', 'e', 'y', '_', 'f', 'o', 'o', + 0x12, 0x05, 'f', 'o', 'o', '_', 'a', + + // map_b["key_foo"] = "foo_b", key = 2 + 0x12, 0x10, + 0x0a, 0x07, 'k', 'e', 'y', '_', 'f', 'o', 'o', + 0x12, 0x05, 'f', 'o', 'o', 0x5f, 0x62, + + // map_b["key_bar"] = "bar_b", key = 2 + 0x12, 0x10, + 0x0a, 0x07, 'k', 'e', 'y', '_', 'b', 'a', 'r', + 0x12, 0x05, 'b', 'a', 'r', 0x5f, 0x62, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + { + // Parse field 'map_a' + constexpr uint32_t kFieldNumber = 1; + StringMapParser<String> string_map = + parser.AsStringToStringMap(kFieldNumber); + + String value = string_map["key_foo"]; + ASSERT_OK(value.status()); + Result<bool> cmp = value.Equal("foo_a"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + value = string_map["key_bar"]; + ASSERT_OK(value.status()); + cmp = value.Equal("bar_a"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + // Non-existing key + value = string_map["non-existing"]; + ASSERT_EQ(value.status(), Status::NotFound()); + } + + { + // Parse field 'map_b' + constexpr uint32_t kFieldNumber = 2; + StringMapParser<String> string_map = + parser.AsStringToStringMap(kFieldNumber); + + String value = string_map["key_foo"]; + ASSERT_OK(value.status()); + Result<bool> cmp = value.Equal("foo_b"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + value = string_map["key_bar"]; + ASSERT_OK(value.status()); + cmp = value.Equal("bar_b"); + ASSERT_OK(cmp.status()); + ASSERT_TRUE(cmp.value()); + + // Non-existing key + value = string_map["non-existing"]; + ASSERT_EQ(value.status(), Status::NotFound()); + } +} + +TEST(ProtoHelper, AsStringToMessageMap) { + // message Contact { + // string number = 1; + // string email = 2; + // } + // + // message Contacts { + // map<string, Contact> staffs = 1; + // } + // clang-format off + std::uint8_t encoded_proto[] = { + // staffs['bar'] = {.number = '456, .email = "bar@email.com"} + 0x0a, 0x1b, + 0x0a, 0x03, 0x62, 0x61, 0x72, + 0x12, 0x14, 0x0a, 0x03, 0x34, 0x35, 0x36, 0x12, 0x0d, 0x62, 0x61, 0x72, 0x40, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x2e, 0x63, 0x6f, 0x6d, + + // staffs['foo'] = {.number = '123', .email = "foo@email.com"} + 0x0a, 0x1b, + 0x0a, 0x03, 0x66, 0x6f, 0x6f, + 0x12, 0x14, 0x0a, 0x03, 0x31, 0x32, 0x33, 0x12, 0x0d, 0x66, 0x6f, 0x6f, 0x40, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x2e, 0x63, 0x6f, 0x6d, + }; + // clang-format on + constexpr uint32_t kStaffsFieldId = 1; + constexpr uint32_t kNumberFieldId = 1; + constexpr uint32_t kEmailFieldId = 2; + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + Message parser = Message(reader, sizeof(encoded_proto)); + + StringMapParser<Message> staffs = parser.AsStringToMessageMap(kStaffsFieldId); + ASSERT_OK(staffs.status()); + + Message foo_staff = staffs["foo"]; + ASSERT_OK(foo_staff.status()); + String foo_number = foo_staff.AsString(kNumberFieldId); + ASSERT_OK(foo_number.status()); + Result<bool> foo_number_cmp = foo_number.Equal("123"); + ASSERT_OK(foo_number_cmp.status()); + ASSERT_TRUE(foo_number_cmp.value()); + String foo_email = foo_staff.AsString(kEmailFieldId); + ASSERT_OK(foo_email.status()); + Result<bool> foo_email_cmp = foo_email.Equal("foo@email.com"); + ASSERT_OK(foo_email_cmp.status()); + ASSERT_TRUE(foo_email_cmp.value()); + + Message bar_staff = staffs["bar"]; + ASSERT_OK(bar_staff.status()); + String bar_number = bar_staff.AsString(kNumberFieldId); + ASSERT_OK(bar_number.status()); + Result<bool> bar_number_cmp = bar_number.Equal("456"); + ASSERT_OK(bar_number_cmp.status()); + ASSERT_TRUE(bar_number_cmp.value()); + String bar_email = bar_staff.AsString(kEmailFieldId); + ASSERT_OK(bar_email.status()); + Result<bool> bar_email_cmp = bar_email.Equal("bar@email.com"); + ASSERT_OK(bar_email_cmp.status()); + ASSERT_TRUE(bar_email_cmp.value()); +} + +} // namespace pw::protobuf diff --git a/pw_protobuf/public/pw_protobuf/internal/proto_integer_base.h b/pw_protobuf/public/pw_protobuf/internal/proto_integer_base.h new file mode 100644 index 000000000..c50a15d52 --- /dev/null +++ b/pw_protobuf/public/pw_protobuf/internal/proto_integer_base.h @@ -0,0 +1,39 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +// +// The header provides a set of helper utils for protobuf related operations. +// The APIs may not be finalized yet. + +#pragma once + +#include "pw_result/result.h" + +namespace pw::protobuf::internal { + +// A base class for representing parsed proto integer types or an error code +// to indicate parsing failure. +template <typename Integer> +class ProtoIntegerBase { + public: + constexpr ProtoIntegerBase(Result<Integer> value) : value_(value) {} + constexpr ProtoIntegerBase(Status status) : value_(status) {} + bool ok() { return value_.ok(); } + Status status() { return value_.status(); } + Integer value() { return value_.value(); } + + private: + Result<Integer> value_ = 0; +}; + +} // namespace pw::protobuf::internal diff --git a/pw_protobuf/public/pw_protobuf/helpers.h b/pw_protobuf/public/pw_protobuf/map_utils.h index 3c4c74f98..258a1b218 100644 --- a/pw_protobuf/public/pw_protobuf/helpers.h +++ b/pw_protobuf/public/pw_protobuf/map_utils.h @@ -18,8 +18,12 @@ #pragma once #include <cstddef> +#include <string_view> +#include "pw_assert/check.h" +#include "pw_protobuf/stream_decoder.h" #include "pw_status/status.h" +#include "pw_status/try.h" #include "pw_stream/stream.h" namespace pw::protobuf { diff --git a/pw_protobuf/public/pw_protobuf/message.h b/pw_protobuf/public/pw_protobuf/message.h new file mode 100644 index 000000000..dedf24c01 --- /dev/null +++ b/pw_protobuf/public/pw_protobuf/message.h @@ -0,0 +1,554 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +// +// The header provides a set of helper utils for protobuf related operations. +// The APIs may not be finalized yet. + +#pragma once + +#include <cstddef> +#include <string_view> + +#include "pw_assert/check.h" +#include "pw_protobuf/internal/proto_integer_base.h" +#include "pw_protobuf/stream_decoder.h" +#include "pw_status/status.h" +#include "pw_status/try.h" +#include "pw_stream/interval_reader.h" +#include "pw_stream/stream.h" + +namespace pw::protobuf { + +// The following defines classes that represent various parsed proto integer +// types or an error code to indicate parsing failure. +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. + +class Uint32 : public internal::ProtoIntegerBase<uint32_t> { + public: + using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; +}; + +class Int32 : public internal::ProtoIntegerBase<int32_t> { + public: + using ProtoIntegerBase<int32_t>::ProtoIntegerBase; +}; + +class Sint32 : public internal::ProtoIntegerBase<int32_t> { + public: + using ProtoIntegerBase<int32_t>::ProtoIntegerBase; +}; + +class Fixed32 : public internal::ProtoIntegerBase<uint32_t> { + public: + using ProtoIntegerBase<uint32_t>::ProtoIntegerBase; +}; + +class Sfixed32 : public internal::ProtoIntegerBase<int32_t> { + public: + using ProtoIntegerBase<int32_t>::ProtoIntegerBase; +}; + +class Uint64 : public internal::ProtoIntegerBase<uint64_t> { + public: + using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; +}; + +class Int64 : public internal::ProtoIntegerBase<int64_t> { + public: + using ProtoIntegerBase<int64_t>::ProtoIntegerBase; +}; + +class Sint64 : public internal::ProtoIntegerBase<int64_t> { + public: + using ProtoIntegerBase<int64_t>::ProtoIntegerBase; +}; + +class Fixed64 : public internal::ProtoIntegerBase<uint64_t> { + public: + using ProtoIntegerBase<uint64_t>::ProtoIntegerBase; +}; + +class Sfixed64 : public internal::ProtoIntegerBase<int64_t> { + public: + using ProtoIntegerBase<int64_t>::ProtoIntegerBase; +}; + +// An object that represents a parsed `bytes` field or an error code. The +// bytes are available via an stream::IntervalReader by GetBytesReader(). +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. +class Bytes { + public: + Bytes() = default; + Bytes(Status status) : reader_(status) {} + Bytes(stream::IntervalReader reader) : reader_(reader) {} + const stream::IntervalReader& GetBytesReader() { return reader_; } + bool ok() { return reader_.ok(); } + Status status() { return reader_.status(); } + + // Check whether the bytes value equals the given `bytes`. + Result<bool> Equal(ConstByteSpan bytes); + + private: + stream::IntervalReader reader_; +}; + +// An object that represents a parsed `string` field or an error code. The +// string value is available via an stream::IntervalReader by +// GetBytesReader(). +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. +class String : public Bytes { + public: + using Bytes::Bytes; + + // Check whether the string value equals the given `str` + Result<bool> Equal(std::string_view str); +}; + +// Forward declaration of parser classes. +template <typename FieldType> +class RepeatedFieldParser; +template <typename FieldType> +class StringMapEntryParser; +template <typename FieldType> +class StringMapParser; +class Message; + +using RepeatedBytes = RepeatedFieldParser<Bytes>; +using RepeatedStrings = RepeatedFieldParser<String>; +using RepeatedMessages = RepeatedFieldParser<Message>; +using StringToBytesMapEntry = StringMapEntryParser<Bytes>; +using StringToStringMapEntry = StringMapEntryParser<String>; +using StringToMessageMapEntry = StringMapEntryParser<Message>; +using StringToBytesMap = StringMapParser<Bytes>; +using StringToStringMap = StringMapParser<String>; +using StringToMessageMap = StringMapParser<Message>; + +// Message - A helper class for parsing a proto message. +// +// Examples: +// +// message Nested { +// string nested_str = 1; +// bytes nested_bytes = 2; +// } +// +// message { +// string str = 1; +// bytes bytes = 2; +// uint32 integer = 3 +// repeated string rep_str = 4; +// map<string, bytes> str_to_bytes = 5; +// Nested nested = 6; +// } +// +// // Given a seekable `reader` that reads the top-level proto message, and +// // a <size> that gives the size of the proto message: +// Message message(reader, <size>); +// +// // Prase simple basic value fields +// String str = message.AsString(1); // string +// Bytes bytes = message.AsBytes(2); // bytes +// Uint32 integer = messasge_parser.AsUint32(3); // uint32 integer +// +// // Parse repeated field `repeated string rep_str = 4;` +// RepeatedStrings rep_str = message.AsRepeatedString(4); +// // Iterate through the entries +// for (String str : rep_str) { +// ... +// } +// +// // Parse map field `map<string, bytes> str_to_bytes = 5;` +// StringToBytesMap str_to_bytes = message.AsStringToBytesMap(5); +// +// // Access the entry by a given key value +// Bytes bytes_for_key = str_to_bytes["key"]; +// +// // Or iterate through map entries +// for (StringToBytesMapEntry entry : str_to_bytes) { +// String key = entry.Key(); +// Bytes value = entry.Value(); +// ... +// } +// +// // Parse nested message `Nested nested = 6;` +// Message nested = message.AsMessage(6). +// String nested_str = nested.AsString(1); +// Bytes nested_bytes = nested.AsBytes(2); +// +// // The `AsXXX()` methods above internally traverse all the fields to find +// // the one with the give field number. This can be expensive if called +// // multiple times. Therefore, whenever possible, it is recommended to use +// // the following iteration to iterate and process each field directly. +// for (Message::Field field : message) { +// if (field.field_number() == 1) { +// String str = field.As<String>(); +// ... +// } else if (field.field_number() == 2) { +// Bytes bytes = field.As<Bytes>(); +// ... +// } else if (field.field_number() == 6) { +// Message nested = field.As<Message>(); +// ... +// } +// } +// +// All parser objects created above internally hold the same reference +// to `reader`. Therefore it needs to maintain valid lifespan throughout the +// operations. The parser objects can work independently and without blocking +// each other. All method calls and for-iterations above are re-enterable. +class Message { + public: + class Field { + public: + uint32_t field_number() { return field_number_; } + const stream::IntervalReader& field_reader() { return field_reader_; } + + // Create a helper parser type of `FieldType` for the field. + // The default implementation below assumes the field is a length-delimited + // field. Other cases such as primitive integer uint32 will be handled by + // template specialization. + template <typename FieldType> + FieldType As() { + protobuf::StreamDecoder decoder(field_reader_.Reset()); + PW_TRY(decoder.Next()); + Result<protobuf::StreamDecoder::Bounds> payload_bounds = + decoder.GetLengthDelimitedPayloadBounds(); + PW_TRY(payload_bounds.status()); + // The bounds is relative to the given stream::IntervalReader. Convert + // it to the interval relative to the source_reader. + return FieldType(stream::IntervalReader( + field_reader_.source_reader(), + payload_bounds.value().low + field_reader_.start(), + payload_bounds.value().high + field_reader_.start())); + } + + private: + Field() = default; + Field(stream::IntervalReader reader, uint32_t field_number) + : field_reader_(reader), field_number_(field_number) {} + + stream::IntervalReader field_reader_; + uint32_t field_number_; + + friend class Message; + }; + + class iterator { + public: + iterator& operator++(); + + iterator operator++(int) { + iterator iter = *this; + this->operator++(); + return iter; + } + + Field operator*() { return current_; } + Field* operator->() { return ¤t_; } + bool operator!=(const iterator& other) const { return !(*this == other); } + + bool operator==(const iterator& other) const { + return eof_ == other.eof_ && reader_ == other.reader_; + } + + private: + stream::IntervalReader reader_; + bool eof_ = false; + Field current_; + + iterator(stream::IntervalReader reader) : reader_(reader) { + this->operator++(); + } + + friend class Message; + }; + + Message() = default; + Message(Status status) : reader_(status) {} + Message(stream::IntervalReader reader) : reader_(reader) {} + Message(stream::SeekableReader& proto_source, size_t size) + : reader_(proto_source, 0, size) {} + + // Parse a sub-field in the message given by `field_number` as bytes. + Bytes AsBytes(uint32_t field_number) { return As<Bytes>(field_number); } + + // Parse a sub-field in the message given by `field_number` as string. + String AsString(uint32_t field_number) { return As<String>(field_number); } + + // Parse a sub-field in the message given by `field_number` as one of the + // proto integer type. + Int32 AsInt32(uint32_t field_number) { return As<Int32>(field_number); } + Sint32 AsSint32(uint32_t field_number) { return As<Sint32>(field_number); } + Uint32 AsUint32(uint32_t field_number) { return As<Uint32>(field_number); } + Fixed32 AsFixed32(uint32_t field_number) { return As<Fixed32>(field_number); } + Int64 AsInt64(uint32_t field_number) { return As<Int64>(field_number); } + Sint64 AsSint64(uint32_t field_number) { return As<Sint64>(field_number); } + Uint64 AsUint64(uint32_t field_number) { return As<Uint64>(field_number); } + Fixed64 AsFixed64(uint32_t field_number) { return As<Fixed64>(field_number); } + + Sfixed32 AsSfixed32(uint32_t field_number) { + return As<Sfixed32>(field_number); + } + + Sfixed64 AsSfixed64(uint32_t field_number) { + return As<Sfixed64>(field_number); + } + + // Parse a sub-field in the message given by `field_number` as another + // message. + Message AsMessage(uint32_t field_number) { return As<Message>(field_number); } + + // Parse a sub-field in the message given by `field_number` as `repeated + // string`. + RepeatedBytes AsRepeatedBytes(uint32_t field_number); + + // Parse a sub-field in the message given by `field_number` as `repeated + // string`. + RepeatedStrings AsRepeatedStrings(uint32_t field_number); + + // Parse a sub-field in the message given by `field_number` as `repeated + // message`. + RepeatedMessages AsRepeatedMessages(uint32_t field_number); + + // Parse a sub-field in the message given by `field_number` as `map<string, + // message>`. + StringToMessageMap AsStringToMessageMap(uint32_t field_number); + + // Parse a sub-field in the message given by `field_number` as + // `map<string, bytes>`. + StringToBytesMap AsStringToBytesMap(uint32_t field_number); + + // Parse a sub-field in the message given by `field_number` as + // `map<string, string>`. + StringToStringMap AsStringToStringMap(uint32_t field_number); + + // Convert the message to a Bytes that represents the raw bytes of this + // message. This can be used to obatained the serialized wire-format of the + // message. + Bytes ToBytes() { return Bytes(reader_.Reset()); } + + bool ok() { return reader_.ok(); } + Status status() { return reader_.status(); } + iterator begin(); + iterator end(); + + // Parse a field given by `field_number` as the target parser type + // `FieldType`. + // + // Note: This method assumes that the message has only 1 field with the given + // <field_number>. It returns the first matching it find. It does not perform + // value overridding or string concatenation for multiple fields with the same + // <field_number>. + // + // Since the method needs to traverse all fields, it can be inefficient if + // called multiple times exepcially on slow reader. + template <typename FieldType> + FieldType As(uint32_t field_number) { + for (Field field : *this) { + if (field.field_number() == field_number) { + return field.As<FieldType>(); + } + } + + return FieldType(Status::NotFound()); + } + + template <typename FieldType> + RepeatedFieldParser<FieldType> AsRepeated(uint32_t field_number) { + return RepeatedFieldParser<FieldType>(*this, field_number); + } + + template <typename FieldParser> + StringMapParser<FieldParser> AsStringMap(uint32_t field_number) { + return StringMapParser<FieldParser>(*this, field_number); + } + + private: + stream::IntervalReader reader_; + + // Consume the current field. If the field has already been processed, i.e. + // by calling one of the Read..() method, nothing is done. After calling this + // method, the reader will be pointing either to the start of the next + // field (i.e. the starting offset of the field key), or the end of the + // stream. The method is for use by Message for computing field interval. + static Status ConsumeCurrentField(StreamDecoder& decoder) { + return decoder.field_consumed_ ? OkStatus() : decoder.SkipField(); + } +}; + +// The following are template specialization for proto integer types. +template <> +Uint32 Message::Field::As<Uint32>(); + +template <> +Int32 Message::Field::As<Int32>(); + +template <> +Sint32 Message::Field::As<Sint32>(); + +template <> +Fixed32 Message::Field::As<Fixed32>(); + +template <> +Sfixed32 Message::Field::As<Sfixed32>(); + +template <> +Uint64 Message::Field::As<Uint64>(); + +template <> +Int64 Message::Field::As<Int64>(); + +template <> +Sint64 Message::Field::As<Sint64>(); + +template <> +Fixed64 Message::Field::As<Fixed64>(); + +template <> +Sfixed64 Message::Field::As<Sfixed64>(); + +// A helper for parsing `repeated` field. It implements an iterator interface +// that only iterates through the fields of a given `field_number`. +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. +template <typename FieldType> +class RepeatedFieldParser { + public: + class iterator { + public: + // Precondition: iter_ is not pointing to the end. + iterator& operator++() { + iter_++; + MoveToNext(); + return *this; + } + + iterator operator++(int) { + iterator iter = *this; + this->operator++(); + return iter; + } + + FieldType operator*() { return current_; } + FieldType* operator->() { return ¤t_; } + bool operator!=(const iterator& other) const { return !(*this == other); } + bool operator==(const iterator& other) const { + return &host_ == &other.host_ && iter_ == other.iter_; + } + + private: + RepeatedFieldParser& host_; + Message::iterator iter_; + bool eof_ = false; + FieldType current_ = FieldType(Status::Unavailable()); + + iterator(RepeatedFieldParser& host, Message::iterator init_iter) + : host_(host), iter_(init_iter), current_(Status::Unavailable()) { + // Move to the first element of the target field number. + MoveToNext(); + } + + void MoveToNext() { + // Move the iterator to the next element with the target field number + for (; iter_ != host_.message_.end(); ++iter_) { + if (iter_->field_number() == host_.field_number_) { + current_ = iter_->As<FieldType>(); + break; + } + } + } + + friend class RepeatedFieldParser; + }; + + // `message` -- The containing message. + // `field_number` -- The field number of the repeated field. + RepeatedFieldParser(Message& message, uint32_t field_number) + : message_(message), field_number_(field_number) {} + + bool ok() { return message_.ok(); } + Status status() { return message_.status(); } + iterator begin() { return iterator(*this, message_.begin()); } + iterator end() { return iterator(*this, message_.end()); } + + private: + Message& message_; + uint32_t field_number_ = 0; +}; + +// A helper for pasring the entry type of map<string, <value>>. +// An entry for a proto map is essentially a message of a key(k=1) and +// value(k=2) field, i.e.: +// +// message Entry { +// string key = 1; +// bytes value = 2; +// } +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. +template <typename ValueParser> +class StringMapEntryParser { + public: + StringMapEntryParser(Status status) : entry_(status) {} + StringMapEntryParser(stream::IntervalReader reader) : entry_(reader) {} + String Key() { return entry_.AsString(kMapKeyFieldNumber); } + ValueParser Value() { return entry_.As<ValueParser>(kMapValueFieldNumber); } + + private: + static constexpr uint32_t kMapKeyFieldNumber = 1; + static constexpr uint32_t kMapValueFieldNumber = 2; + Message entry_; +}; + +// A helper class for parsing a string-keyed map field. i.e. map<string, +// <value>>. The template argument `ValueParser` indicates the type the value +// will be parsed as, i.e. String, Bytes, Uint32, Message etc. +// +// For normal uses, the class should be created from `class Message`. See +// comment for `class Message` for usage. +template <typename ValueParser> +class StringMapParser + : public RepeatedFieldParser<StringMapEntryParser<ValueParser>> { + public: + using RepeatedFieldParser< + StringMapEntryParser<ValueParser>>::RepeatedFieldParser; + + // Operator overload for value access of a given key. + ValueParser operator[](std::string_view target) { + // Iterate over all entries and find the one whose key matches `target` + for (StringMapEntryParser<ValueParser> entry : *this) { + String key = entry.Key(); + PW_TRY(key.status()); + + // Compare key value with the given string + Result<bool> cmp_res = key.Equal(target); + PW_TRY(cmp_res.status()); + if (cmp_res.value()) { + return entry.Value(); + } + } + + return ValueParser(Status::NotFound()); + } +}; + +} // namespace pw::protobuf diff --git a/pw_protobuf/public/pw_protobuf/stream_decoder.h b/pw_protobuf/public/pw_protobuf/stream_decoder.h index dffd6ffc4..41987efec 100644 --- a/pw_protobuf/public/pw_protobuf/stream_decoder.h +++ b/pw_protobuf/public/pw_protobuf/stream_decoder.h @@ -101,6 +101,7 @@ class StreamDecoder { stream_bounds_({0, std::numeric_limits<size_t>::max()}), current_field_(kInitialFieldKey), delimited_field_size_(0), + delimited_field_offset_(0), parent_(nullptr), field_consumed_(true), nested_reader_open_(false), @@ -283,6 +284,16 @@ class StreamDecoder { // See the example in GetBytesReader() above for RAII semantics and usage. StreamDecoder GetNestedDecoder(); + struct Bounds { + size_t low; + size_t high; + }; + + // Get the interval of the payload part of a length-delimited field. That is, + // the interval exluding the field key and the length prefix. The bounds are + // relative to the given reader. + Result<Bounds> GetLengthDelimitedPayloadBounds(); + private: friend class BytesReader; @@ -301,6 +312,7 @@ class StreamDecoder { stream_bounds_({low, high}), current_field_(kInitialFieldKey), delimited_field_size_(0), + delimited_field_offset_(0), parent_(parent), field_consumed_(true), nested_reader_open_(false), @@ -315,6 +327,7 @@ class StreamDecoder { stream_bounds_({0, std::numeric_limits<size_t>::max()}), current_field_(kInitialFieldKey), delimited_field_size_(0), + delimited_field_offset_(0), parent_(parent), field_consumed_(true), nested_reader_open_(false), @@ -352,13 +365,11 @@ class StreamDecoder { Status CheckOkToRead(WireType type); stream::SeekableReader& reader_; - struct { - size_t low; - size_t high; - } stream_bounds_; + Bounds stream_bounds_; FieldKey current_field_; size_t delimited_field_size_; + size_t delimited_field_offset_; StreamDecoder* parent_; @@ -366,6 +377,8 @@ class StreamDecoder { bool nested_reader_open_; Status status_; + + friend class Message; }; } // namespace pw::protobuf diff --git a/pw_protobuf/stream_decoder.cc b/pw_protobuf/stream_decoder.cc index f9a8cefcb..6670dcf02 100644 --- a/pw_protobuf/stream_decoder.cc +++ b/pw_protobuf/stream_decoder.cc @@ -241,12 +241,19 @@ Status StreamDecoder::ReadFieldKey() { } delimited_field_size_ = varint; + delimited_field_offset_ = reader_.Tell(); } field_consumed_ = false; return OkStatus(); } +Result<StreamDecoder::Bounds> StreamDecoder::GetLengthDelimitedPayloadBounds() { + PW_TRY(CheckOkToRead(WireType::kDelimited)); + return StreamDecoder::Bounds{delimited_field_offset_, + delimited_field_size_ + delimited_field_offset_}; +} + // Consumes the current protobuf field, advancing the stream to the key of the // next field (if one exists). Status StreamDecoder::SkipField() { diff --git a/pw_protobuf/stream_decoder_test.cc b/pw_protobuf/stream_decoder_test.cc index 8a93306a0..6f4804479 100644 --- a/pw_protobuf/stream_decoder_test.cc +++ b/pw_protobuf/stream_decoder_test.cc @@ -514,5 +514,37 @@ TEST(StreamDecoder, Decode_BytesReader_InvalidField) { EXPECT_EQ(decoder.Next(), Status::DataLoss()); } +TEST(StreamDecoder, GetLengthDelimitedPayloadBounds) { + // clang-format off + constexpr uint8_t encoded_proto[] = { + // bytes key=1, length=14 + 0x0a, 0x0e, + + 0x00, 0x01, 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, + // End bytes + + // type=sint32, k=2, v=-13 + 0x10, 0x19, + }; + // clang-format on + + stream::MemoryReader reader(std::as_bytes(std::span(encoded_proto))); + StreamDecoder decoder(reader); + + ASSERT_EQ(OkStatus(), decoder.Next()); + Result<StreamDecoder::Bounds> field_bound = + decoder.GetLengthDelimitedPayloadBounds(); + ASSERT_EQ(OkStatus(), field_bound.status()); + ASSERT_EQ(field_bound.value().low, 2ULL); + ASSERT_EQ(field_bound.value().high, 16ULL); + + ASSERT_EQ(OkStatus(), decoder.Next()); + ASSERT_EQ(Status::NotFound(), + decoder.GetLengthDelimitedPayloadBounds().status()); +} + } // namespace } // namespace pw::protobuf diff --git a/pw_stream/BUILD.bazel b/pw_stream/BUILD.bazel index 938103643..a6ce477fe 100644 --- a/pw_stream/BUILD.bazel +++ b/pw_stream/BUILD.bazel @@ -70,6 +70,13 @@ pw_cc_library( deps = [":pw_stream"], ) +pw_cc_library( + name = "interval_reader", + srcs = ["interval_reader.cc"], + hdrs = ["public/pw_stream/interval_reader.h"], + deps = [":pw_stream"], +) + pw_cc_test( name = "memory_stream_test", srcs = ["memory_stream_test.cc"], @@ -96,3 +103,12 @@ pw_cc_test( "//pw_unit_test", ], ) + +pw_cc_test( + name = "interval_reader_test", + srcs = ["interval_reader_test.cc"], + deps = [ + ":interval_reader", + "//pw_unit_test", + ], +) diff --git a/pw_stream/BUILD.gn b/pw_stream/BUILD.gn index 5e34c4b5d..1f6a9da06 100644 --- a/pw_stream/BUILD.gn +++ b/pw_stream/BUILD.gn @@ -65,12 +65,24 @@ pw_source_set("std_file_stream") { sources = [ "std_file_stream.cc" ] } +pw_source_set("interval_reader") { + public_configs = [ ":public_include_path" ] + public_deps = [ + ":pw_stream", + dir_pw_assert, + dir_pw_status, + ] + public = [ "public/pw_stream/interval_reader.h" ] + sources = [ "interval_reader.cc" ] +} + pw_doc_group("docs") { sources = [ "docs.rst" ] } pw_test_group("tests") { tests = [ + ":interval_reader_test", ":memory_stream_test", ":seek_test", ":stream_test", @@ -91,3 +103,8 @@ pw_test("stream_test") { sources = [ "stream_test.cc" ] deps = [ ":pw_stream" ] } + +pw_test("interval_reader_test") { + sources = [ "interval_reader_test.cc" ] + deps = [ ":interval_reader" ] +} diff --git a/pw_stream/interval_reader.cc b/pw_stream/interval_reader.cc new file mode 100644 index 000000000..06abf6e4c --- /dev/null +++ b/pw_stream/interval_reader.cc @@ -0,0 +1,83 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_stream/interval_reader.h" + +#include "pw_assert/check.h" + +namespace pw::stream { + +void IntervalReader::Check() { + PW_CHECK(ok(), "internal::IntervalReader is in an invalid state"); + PW_CHECK_NOTNULL(source_reader_); +} + +StatusWithSize IntervalReader::DoRead(ByteSpan destination) { + Check(); + if (current_ == end_) { + return StatusWithSize::OutOfRange(); + } + + // Seek the source reader to the `current_` offset of this IntervalReader + // before reading. + Status status = source_reader_->Seek(current_, Whence::kBeginning); + if (!status.ok()) { + return StatusWithSize(status, 0); + } + + size_t to_read = std::min(destination.size(), end_ - current_); + Result<ByteSpan> res = source_reader_->Read(destination.first(to_read)); + if (!res.ok()) { + return StatusWithSize(res.status(), 0); + } + + current_ += res.value().size(); + return StatusWithSize(res.value().size()); +} + +Status IntervalReader::DoSeek(ssize_t offset, Whence origin) { + Check(); + + ssize_t absolute_position = std::numeric_limits<size_t>::max(); + + // Convert from the position within the interval to the position within the + // source reader stream. + switch (origin) { + case Whence::kBeginning: + absolute_position = offset + start_; + break; + + case Whence::kCurrent: + absolute_position = current_ + offset; + break; + + case Whence::kEnd: + absolute_position = end_ + offset; + break; + } + + if (absolute_position < 0) { + return Status::InvalidArgument(); + } + + if (static_cast<size_t>(absolute_position) < start_ || + static_cast<size_t>(absolute_position) > end_) { + return Status::InvalidArgument(); + } + + current_ = absolute_position; + return OkStatus(); +} + +}; // namespace pw::stream diff --git a/pw_stream/interval_reader_test.cc b/pw_stream/interval_reader_test.cc new file mode 100644 index 000000000..663e45ba8 --- /dev/null +++ b/pw_stream/interval_reader_test.cc @@ -0,0 +1,89 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_stream/interval_reader.h" + +#include "gtest/gtest.h" +#include "pw_result/result.h" +#include "pw_stream/memory_stream.h" + +namespace pw::stream { +namespace { + +TEST(IntervalReader, IntervalReaderRead) { + std::uint8_t data[] = {0, 1, 2, 3, 4, 5, 6, 7, 9, 10}; + stream::MemoryReader reader(std::as_bytes(std::span(data))); + IntervalReader reader_first_half(reader, 0, 5); + IntervalReader reader_second_half(reader, 5, 10); + + // Read second half + std::byte read_buf[5]; + Result<ByteSpan> res = reader_second_half.Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data + 5, 5), 0); + ASSERT_EQ(reader_second_half.Read(read_buf).status(), Status::OutOfRange()); + + // Read first half. They should be independent and do not interfere each + // other. + res = reader_first_half.Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data, 5), 0); + ASSERT_EQ(reader_first_half.Read(read_buf).status(), Status::OutOfRange()); + + // Reset the cursor and the reader should read from the beginning. + res = reader_second_half.Reset().Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data + 5, 5), 0); + ASSERT_EQ(reader_second_half.Read(read_buf).status(), Status::OutOfRange()); +} + +TEST(IntervalReader, IntervalReaderSeek) { + std::uint8_t data[] = {0, 1, 2, 3, 4, 5, 6, 7, 9, 10}; + stream::MemoryReader reader(std::as_bytes(std::span(data))); + IntervalReader interval_reader(reader, 0, 10); + + // Absolute seeking. + std::byte read_buf[5]; + ASSERT_EQ(interval_reader.Seek(5), OkStatus()); + Result<ByteSpan> res = interval_reader.Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data + 5, 5), 0); + ASSERT_EQ(interval_reader.Read(read_buf).status(), Status::OutOfRange()); + + // Relative seek. + ASSERT_EQ(interval_reader.Seek(-10, stream::Stream::kCurrent), OkStatus()); + res = interval_reader.Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data, 5), 0); + + // Seeking from the end. + ASSERT_EQ(interval_reader.Seek(-5, stream::Stream::kEnd), OkStatus()); + res = interval_reader.Read(read_buf); + ASSERT_EQ(res.status(), OkStatus()); + ASSERT_EQ(res.value().size(), sizeof(read_buf)); + ASSERT_EQ(memcmp(read_buf, data + 5, 5), 0); + ASSERT_EQ(interval_reader.Read(read_buf).status(), Status::OutOfRange()); + + // Seeking to the end is allowed + ASSERT_EQ(interval_reader.Seek(0, stream::Stream::kEnd), OkStatus()); + ASSERT_EQ(interval_reader.Read(read_buf).status(), Status::OutOfRange()); +} + +} // namespace +} // namespace pw::stream diff --git a/pw_stream/public/pw_stream/interval_reader.h b/pw_stream/public/pw_stream/interval_reader.h new file mode 100644 index 000000000..c1ef2b089 --- /dev/null +++ b/pw_stream/public/pw_stream/interval_reader.h @@ -0,0 +1,93 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +// +// The header provides a set of helper utils for protobuf related operations. +// The APIs may not be finalized yet. + +#pragma once + +#include <cstddef> +#include <string_view> + +#include "pw_assert/assert.h" +#include "pw_status/status.h" +#include "pw_stream/stream.h" + +namespace pw::stream { + +// A reader wrapper that reads from a sub-interval of a given seekable +// source reader. The IntervalReader tracks and maintains its own read offset. +// It seeks the source reader to its current read offset before reading. In +// this way, multiple IntervalReader can share the same source reader without +// interfereing each other. +// +// The reader additionally embedds a `Status` to indicate whether itself +// is valid. This is a workaround for Reader not being compatibile with +// Result<>. +class IntervalReader : public SeekableReader { + public: + constexpr IntervalReader() : status_(Status::Unavailable()) {} + + // Create an IntervalReader with an error status. + constexpr IntervalReader(Status status) : status_(status) { + PW_ASSERT(!status.ok()); + } + + // source_reader -- The source reader to read from. + // start -- starting offset to read in `source_reader` + // end -- ending offset in `source_reader`. + constexpr IntervalReader(SeekableReader& source_reader, + size_t start, + size_t end) + : source_reader_(&source_reader), + start_(start), + end_(end), + current_(start) {} + + // Reset the read offset to the start of the interval + IntervalReader& Reset() { + Check(); + current_ = start_; + return *this; + } + + // Get a reference to the source reader. + SeekableReader& source_reader() { return *source_reader_; } + size_t start() const { return start_; } + size_t end() const { return end_; } + size_t current() const { return current_; } + size_t interval_size() const { return end_ - start_; } + bool ok() const { return status_.ok(); } + Status status() const { return status_; } + + // For iterator comparison in Message. + bool operator==(const IntervalReader& other) const { + return source_reader_ == other.source_reader_ && start_ == other.start_ && + end_ == other.end_ && current_ == other.current_; + } + + private: + StatusWithSize DoRead(ByteSpan destination) final; + Status DoSeek(ssize_t offset, Whence origin) final; + size_t DoTell() const final { return current_ - start_; } + void Check(); + + SeekableReader* source_reader_ = nullptr; + size_t start_ = 0; + size_t end_ = 0; + size_t current_ = 0; + Status status_ = OkStatus(); +}; + +} // namespace pw::stream |