diff options
author | Alexei Frolov <frolv@google.com> | 2020-10-07 10:02:00 -0700 |
---|---|---|
committer | CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2020-10-21 17:47:44 +0000 |
commit | b70597117a2080b478ab959c1754889ff7345ec0 (patch) | |
tree | 2bffa55915e2ce7f749f606a99b14d45b12664eb /pw_rpc | |
parent | 62de81b268e81929b3fc5e5b38b24228d7332bf9 (diff) | |
download | pigweed-b70597117a2080b478ab959c1754889ff7345ec0.tar.gz |
pw_rpc: Raw method implementation
This change adds a new RPC method implementation which calls methods
with raw binary protobuf data. The structure largely follows that of
nanopb methods, with a similar API for generated code.
Change-Id: Ia3284f62a21b4c8c467109c9577b67bef1fc1cce
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/20120
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
Diffstat (limited to 'pw_rpc')
-rw-r--r-- | pw_rpc/BUILD.gn | 5 | ||||
-rw-r--r-- | pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h | 4 | ||||
-rw-r--r-- | pw_rpc/public/pw_rpc/internal/base_server_writer.h | 2 | ||||
-rw-r--r-- | pw_rpc/public/pw_rpc/internal/channel.h | 5 | ||||
-rw-r--r-- | pw_rpc/raw/BUILD | 49 | ||||
-rw-r--r-- | pw_rpc/raw/BUILD.gn | 49 | ||||
-rw-r--r-- | pw_rpc/raw/public/pw_rpc/internal/raw_method.h | 105 | ||||
-rw-r--r-- | pw_rpc/raw/raw_method.cc | 68 | ||||
-rw-r--r-- | pw_rpc/raw/raw_method_test.cc | 215 |
9 files changed, 499 insertions, 3 deletions
diff --git a/pw_rpc/BUILD.gn b/pw_rpc/BUILD.gn index ef1884f37..8c6779722 100644 --- a/pw_rpc/BUILD.gn +++ b/pw_rpc/BUILD.gn @@ -166,7 +166,10 @@ pw_test_group("tests") { ":server_test", ":service_test", ] - group_deps = [ "nanopb:tests" ] + group_deps = [ + "nanopb:tests", + "raw:tests", + ] } pw_proto_library("test_protos") { diff --git a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h index 3a97cff08..52b8dc5d3 100644 --- a/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h +++ b/pw_rpc/nanopb/public/pw_rpc/internal/nanopb_method.h @@ -164,10 +164,10 @@ class NanopbMethod : public Method { id, ServerStreamingInvoker<AllocateSpaceFor<Request<method>>()>, {.server_streaming = - [](ServerCall& call, const void* req, BaseServerWriter& resp) { + [](ServerCall& call, const void* req, BaseServerWriter& writer) { method(call, *static_cast<const Request<method>*>(req), - static_cast<ServerWriter<Response<method>>&>(resp)); + static_cast<ServerWriter<Response<method>>&>(writer)); }}, request, response); diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h index 5c7e4fd21..ef36d38d2 100644 --- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h +++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h @@ -66,6 +66,8 @@ class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item { const Channel& channel() const { return call_.channel(); } + constexpr const Channel::OutputBuffer& buffer() const { return response_; } + std::span<std::byte> AcquirePayloadBuffer(); Status ReleasePayloadBuffer(std::span<const std::byte> payload); diff --git a/pw_rpc/public/pw_rpc/internal/channel.h b/pw_rpc/public/pw_rpc/internal/channel.h index ba03f071f..870e42595 100644 --- a/pw_rpc/public/pw_rpc/internal/channel.h +++ b/pw_rpc/public/pw_rpc/internal/channel.h @@ -51,6 +51,11 @@ class Channel : public rpc::Channel { // Returns a portion of this OutputBuffer to use as the packet payload. std::span<std::byte> payload(const Packet& packet) const; + bool Contains(std::span<const std::byte> buffer) const { + return buffer.data() >= buffer_.data() && + buffer.data() + buffer.size() <= buffer_.data() + buffer_.size(); + } + private: friend class Channel; diff --git a/pw_rpc/raw/BUILD b/pw_rpc/raw/BUILD new file mode 100644 index 000000000..0d53511c6 --- /dev/null +++ b/pw_rpc/raw/BUILD @@ -0,0 +1,49 @@ +# Copyright 2020 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +load( + "//pw_build:pigweed.bzl", + "pw_cc_library", + "pw_cc_test", +) + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache License 2.0 + +pw_cc_library( + name = "method", + srcs = [ + "raw_method.cc", + ], + hdrs = [ + "public/pw_rpc/internal/raw_method.h", + ], + deps = [ + "//pw_bytes", + "//pw_rpc:server", + ] +) + +pw_cc_test( + name = "raw_method_test", + srcs = [ + "raw_method_test.cc", + ], + deps = [ + ":method", + "//pw_protobuf", + "//pw_rpc:internal_test_utils", + ], +) diff --git a/pw_rpc/raw/BUILD.gn b/pw_rpc/raw/BUILD.gn new file mode 100644 index 000000000..1bf926e91 --- /dev/null +++ b/pw_rpc/raw/BUILD.gn @@ -0,0 +1,49 @@ +# Copyright 2020 The Pigweed Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +# gn-format disable +import("//build_overrides/pigweed.gni") + +import("$dir_pw_build/target_types.gni") +import("$dir_pw_docgen/docs.gni") +import("$dir_pw_unit_test/test.gni") +config("public") { + include_dirs = [ "public" ] + visibility = [ ":*" ] +} + +pw_source_set("method") { + public_configs = [ ":public" ] + public = [ "public/pw_rpc/internal/raw_method.h" ] + sources = [ "raw_method.cc" ] + public_deps = [ + "..:server", + dir_pw_bytes, + ] + deps = [ dir_pw_log ] +} + +pw_test_group("tests") { + tests = [ ":raw_method_test" ] +} + +pw_test("raw_method_test") { + deps = [ + ":method", + "..:test_protos_pwpb", + "..:test_utils", + dir_pw_protobuf, + ] + sources = [ "raw_method_test.cc" ] +} diff --git a/pw_rpc/raw/public/pw_rpc/internal/raw_method.h b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h new file mode 100644 index 000000000..e0a514ff7 --- /dev/null +++ b/pw_rpc/raw/public/pw_rpc/internal/raw_method.h @@ -0,0 +1,105 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. +#pragma once + +#include "pw_bytes/span.h" +#include "pw_rpc/internal/base_server_writer.h" +#include "pw_rpc/internal/method.h" +#include "pw_rpc/internal/method_type.h" +#include "pw_status/status_with_size.h" + +namespace pw::rpc { + +class RawServerWriter : public internal::BaseServerWriter { + public: + RawServerWriter() = default; + + // Returns a buffer in which a response payload can be built. + ByteSpan PayloadBuffer() { return AcquirePayloadBuffer(); } + + // Sends a response packet with the given raw payload. The payload can either + // be in the buffer previously acquired from PayloadBuffer(), or an arbitrary + // external buffer. + Status Write(ConstByteSpan response); +}; + +namespace internal { + +// A RawMethod is a method invoker which does not perform any automatic protobuf +// serialization or deserialization. The implementer is given the raw binary +// payload of incoming requests, and is responsible for encoding responses to a +// provided buffer. This is intended for use in methods which would have large +// protobuf data structure overhead to lower stack usage, or in methods packing +// responses up to a channel's MTU. +class RawMethod : public Method { + public: + template <auto method> + constexpr static RawMethod Unary(uint32_t id) { + return RawMethod( + id, + UnaryInvoker, + {.unary = [](ServerCall& call, ConstByteSpan req, ByteSpan res) { + return method(call, req, res); + }}); + } + + template <auto method> + constexpr static RawMethod ServerStreaming(uint32_t id) { + return RawMethod(id, + ServerStreamingInvoker, + Function{.server_streaming = [](ServerCall& call, + ConstByteSpan req, + BaseServerWriter& writer) { + method(call, req, static_cast<RawServerWriter&>(writer)); + }}); + } + + private: + using UnaryFunction = StatusWithSize (*)(ServerCall&, + ConstByteSpan, + ByteSpan); + + using ServerStreamingFunction = void (*)(ServerCall&, + ConstByteSpan, + BaseServerWriter&); + union Function { + UnaryFunction unary; + ServerStreamingFunction server_streaming; + // TODO(frolv): Support client and bidirectional streaming. + }; + + constexpr RawMethod(uint32_t id, Invoker invoker, Function function) + : Method(id, invoker), function_(function) {} + + static void UnaryInvoker(const Method& method, + ServerCall& call, + const Packet& request) { + static_cast<const RawMethod&>(method).CallUnary(call, request); + } + + static void ServerStreamingInvoker(const Method& method, + ServerCall& call, + const Packet& request) { + static_cast<const RawMethod&>(method).CallServerStreaming(call, request); + } + + void CallUnary(ServerCall& call, const Packet& request) const; + void CallServerStreaming(ServerCall& call, const Packet& request) const; + + // Stores the user-defined RPC in a generic wrapper. + Function function_; +}; + +} // namespace internal +} // namespace pw::rpc diff --git a/pw_rpc/raw/raw_method.cc b/pw_rpc/raw/raw_method.cc new file mode 100644 index 000000000..1b2075f63 --- /dev/null +++ b/pw_rpc/raw/raw_method.cc @@ -0,0 +1,68 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/internal/raw_method.h" + +#include <cstring> + +#include "pw_log/log.h" +#include "pw_rpc/internal/packet.h" + +namespace pw::rpc { + +Status RawServerWriter::Write(ConstByteSpan response) { + if (buffer().Contains(response)) { + return ReleasePayloadBuffer(response); + } + + std::span<std::byte> buffer = AcquirePayloadBuffer(); + + if (response.size() > buffer.size()) { + ReleasePayloadBuffer({}); + return Status::OutOfRange(); + } + + std::memcpy(buffer.data(), response.data(), response.size()); + return ReleasePayloadBuffer(buffer.first(response.size())); +} + +namespace internal { + +void RawMethod::CallUnary(ServerCall& call, const Packet& request) const { + Channel::OutputBuffer response_buffer = call.channel().AcquireBuffer(); + std::span payload_buffer = response_buffer.payload(request); + + StatusWithSize sws = function_.unary(call, request.payload(), payload_buffer); + Packet response = Packet::Response(request); + + response.set_payload(payload_buffer.first(sws.size())); + response.set_status(sws.status()); + if (call.channel().Send(response_buffer, response).ok()) { + return; + } + + PW_LOG_WARN("Failed to send response packet for channel %u", + unsigned(call.channel().id())); + call.channel().Send(response_buffer, + Packet::ServerError(request, Status::Internal())); +} + +void RawMethod::CallServerStreaming(ServerCall& call, + const Packet& request) const { + internal::BaseServerWriter server_writer(call); + function_.server_streaming(call, request.payload(), server_writer); +} + +} // namespace internal +} // namespace pw::rpc diff --git a/pw_rpc/raw/raw_method_test.cc b/pw_rpc/raw/raw_method_test.cc new file mode 100644 index 000000000..a96f931ed --- /dev/null +++ b/pw_rpc/raw/raw_method_test.cc @@ -0,0 +1,215 @@ +// Copyright 2020 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +#include "pw_rpc/internal/raw_method.h" + +#include <array> + +#include "gtest/gtest.h" +#include "pw_bytes/array.h" +#include "pw_protobuf/decoder.h" +#include "pw_protobuf/encoder.h" +#include "pw_rpc/server_context.h" +#include "pw_rpc/service.h" +#include "pw_rpc_private/internal_test_utils.h" +#include "pw_rpc_test_protos/test.pwpb.h" + +namespace pw::rpc::internal { +namespace { + +template <typename Implementation> +class FakeGeneratedService : public Service { + public: + constexpr FakeGeneratedService(uint32_t id) : Service(id, kMethods) {} + + static StatusWithSize Invoke_DoNothing(ServerCall& call, + ConstByteSpan request, + ByteSpan response) { + return static_cast<Implementation&>(call.service()) + .DoNothing(call.context(), request, response); + } + + static StatusWithSize Invoke_AddFive(ServerCall& call, + ConstByteSpan request, + ByteSpan response) { + return static_cast<Implementation&>(call.service()) + .AddFive(call.context(), request, response); + } + + static void Invoke_StartStream(ServerCall& call, + ConstByteSpan request, + RawServerWriter& writer) { + static_cast<Implementation&>(call.service()) + .StartStream(call.context(), request, writer); + } + + static constexpr std::array<RawMethod, 3> kMethods = { + RawMethod::Unary<Invoke_DoNothing>(10u), + RawMethod::Unary<Invoke_AddFive>(11u), + RawMethod::ServerStreaming<Invoke_StartStream>(12u), + }; +}; + +struct { + int64_t integer; + uint32_t status_code; +} last_request; +RawServerWriter last_writer; + +class FakeGeneratedServiceImpl + : public FakeGeneratedService<FakeGeneratedServiceImpl> { + public: + FakeGeneratedServiceImpl(uint32_t id) : FakeGeneratedService(id) {} + + StatusWithSize DoNothing(ServerContext&, ConstByteSpan, ByteSpan) { + return StatusWithSize::Unknown(); + } + + StatusWithSize AddFive(ServerContext&, + ConstByteSpan request, + ByteSpan response) { + DecodeRawTestRequest(request); + + protobuf::NestedEncoder encoder(response); + test::TestResponse::Encoder test_response(&encoder); + test_response.WriteValue(last_request.integer + 5); + ConstByteSpan payload; + encoder.Encode(&payload); + + return StatusWithSize::Unauthenticated(payload.size()); + } + + void StartStream(ServerContext&, + ConstByteSpan request, + RawServerWriter& writer) { + DecodeRawTestRequest(request); + last_writer = std::move(writer); + } + + private: + void DecodeRawTestRequest(ConstByteSpan request) { + protobuf::Decoder decoder(request); + + while (decoder.Next().ok()) { + test::TestRequest::Fields field = + static_cast<test::TestRequest::Fields>(decoder.FieldNumber()); + + switch (field) { + case test::TestRequest::Fields::INTEGER: + decoder.ReadInt64(&last_request.integer); + break; + case test::TestRequest::Fields::STATUS_CODE: + decoder.ReadUint32(&last_request.status_code); + break; + } + } + } +}; + +TEST(RawMethod, UnaryRpc_SendsResponse) { + std::byte buffer[16]; + protobuf::NestedEncoder encoder(buffer); + test::TestRequest::Encoder test_request(&encoder); + test_request.WriteInteger(456); + test_request.WriteStatusCode(7); + + const RawMethod& method = std::get<1>(FakeGeneratedServiceImpl::kMethods); + ServerContextForTest<FakeGeneratedServiceImpl> context(method); + method.Invoke(context.get(), context.packet(encoder.Encode().value())); + + EXPECT_EQ(last_request.integer, 456); + EXPECT_EQ(last_request.status_code, 7u); + + const Packet& response = context.output().sent_packet(); + EXPECT_EQ(response.status(), Status::Unauthenticated()); + + protobuf::Decoder decoder(response.payload()); + ASSERT_TRUE(decoder.Next().ok()); + int64_t value; + EXPECT_EQ(decoder.ReadInt64(&value), Status::Ok()); + EXPECT_EQ(value, 461); +} + +TEST(RawMethod, ServerStreamingRpc_SendsNothingWhenInitiallyCalled) { + std::byte buffer[16]; + protobuf::NestedEncoder encoder(buffer); + test::TestRequest::Encoder test_request(&encoder); + test_request.WriteInteger(777); + test_request.WriteStatusCode(2); + + const RawMethod& method = std::get<2>(FakeGeneratedServiceImpl::kMethods); + ServerContextForTest<FakeGeneratedServiceImpl> context(method); + + method.Invoke(context.get(), context.packet(encoder.Encode().value())); + + EXPECT_EQ(0u, context.output().packet_count()); + EXPECT_EQ(777, last_request.integer); + EXPECT_EQ(2u, last_request.status_code); + EXPECT_TRUE(last_writer.open()); + last_writer.Finish(); +} + +TEST(RawServerWriter, Write_SendsPreviouslyAcquiredBuffer) { + const RawMethod& method = std::get<2>(FakeGeneratedServiceImpl::kMethods); + ServerContextForTest<FakeGeneratedServiceImpl> context(method); + + method.Invoke(context.get(), context.packet({})); + + auto buffer = last_writer.PayloadBuffer(); + + constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>(); + std::memcpy(buffer.data(), data.data(), data.size()); + + EXPECT_EQ(last_writer.Write(buffer.first(data.size())), Status::Ok()); + + const internal::Packet& packet = context.output().sent_packet(); + EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE); + EXPECT_EQ(packet.channel_id(), context.kChannelId); + EXPECT_EQ(packet.service_id(), context.kServiceId); + EXPECT_EQ(packet.method_id(), context.get().method().id()); + EXPECT_EQ(std::memcmp(packet.payload().data(), data.data(), data.size()), 0); + EXPECT_EQ(packet.status(), Status::Ok()); +} + +TEST(RawServerWriter, Write_SendsExternalBuffer) { + const RawMethod& method = std::get<2>(FakeGeneratedServiceImpl::kMethods); + ServerContextForTest<FakeGeneratedServiceImpl> context(method); + + method.Invoke(context.get(), context.packet({})); + + constexpr auto data = bytes::Array<0x0d, 0x06, 0xf0, 0x0d>(); + EXPECT_EQ(last_writer.Write(data), Status::Ok()); + + const internal::Packet& packet = context.output().sent_packet(); + EXPECT_EQ(packet.type(), internal::PacketType::RESPONSE); + EXPECT_EQ(packet.channel_id(), context.kChannelId); + EXPECT_EQ(packet.service_id(), context.kServiceId); + EXPECT_EQ(packet.method_id(), context.get().method().id()); + EXPECT_EQ(std::memcmp(packet.payload().data(), data.data(), data.size()), 0); + EXPECT_EQ(packet.status(), Status::Ok()); +} + +TEST(RawServerWriter, Write_BufferTooSmall_ReturnsOutOfRange) { + const RawMethod& method = std::get<2>(FakeGeneratedServiceImpl::kMethods); + ServerContextForTest<FakeGeneratedServiceImpl, 16> context(method); + + method.Invoke(context.get(), context.packet({})); + + constexpr auto data = + bytes::Array<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16>(); + EXPECT_EQ(last_writer.Write(data), Status::OutOfRange()); +} + +} // namespace +} // namespace pw::rpc::internal |