diff options
author | Alexei Frolov <frolv@google.com> | 2020-09-28 16:23:06 -0700 |
---|---|---|
committer | CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com> | 2020-10-02 16:21:48 +0000 |
commit | 062ed18abbf7d2555c84bd6ffe2399c207a250ed (patch) | |
tree | 01b28871736d3462c8a1255c764c5bbe37813548 /pw_rpc | |
parent | da651f169f97f85b63c675b5f5479a7868179da1 (diff) | |
download | pigweed-062ed18abbf7d2555c84bd6ffe2399c207a250ed.tar.gz |
pw_rpc: Send CLIENT_ERROR on unexpected packet
This updates the RPC client to send back a CLIENT_ERROR to the server if
it receives a packet it was not expecting. The server is updated to
handle client errors by ending the active RPC, if applicable.
Change-Id: Ia9a2f36571fd3f91e28a50393c531df66ba11f09
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/19180
Commit-Queue: Alexei Frolov <frolv@google.com>
Reviewed-by: Wyatt Hepler <hepler@google.com>
Diffstat (limited to 'pw_rpc')
-rw-r--r-- | pw_rpc/base_server_writer.cc | 12 | ||||
-rw-r--r-- | pw_rpc/client.cc | 10 | ||||
-rw-r--r-- | pw_rpc/client_test.cc | 12 | ||||
-rw-r--r-- | pw_rpc/public/pw_rpc/internal/base_server_writer.h | 13 | ||||
-rw-r--r-- | pw_rpc/public/pw_rpc/internal/packet.h | 15 | ||||
-rw-r--r-- | pw_rpc/public/pw_rpc/server.h | 1 | ||||
-rw-r--r-- | pw_rpc/server.cc | 18 | ||||
-rw-r--r-- | pw_rpc/server_test.cc | 10 |
8 files changed, 82 insertions, 9 deletions
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc index 1e3c39680..942cf326e 100644 --- a/pw_rpc/base_server_writer.cc +++ b/pw_rpc/base_server_writer.cc @@ -50,8 +50,7 @@ void BaseServerWriter::Finish(Status status) { return; } - call_.server().RemoveWriter(*this); - state_ = kClosed; + Close(); // Send a control packet indicating that the stream (and RPC) has terminated. call_.channel().Send(Packet(PacketType::SERVER_STREAM_END, @@ -79,6 +78,15 @@ Status BaseServerWriter::ReleasePayloadBuffer( return call_.channel().Send(response_, ResponsePacket(payload)); } +void BaseServerWriter::Close() { + if (!open()) { + return; + } + + call_.server().RemoveWriter(*this); + state_ = kClosed; +} + Packet BaseServerWriter::ResponsePacket( std::span<const std::byte> payload) const { return Packet(PacketType::RESPONSE, diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc index 73b2f0010..214a2085e 100644 --- a/pw_rpc/client.cc +++ b/pw_rpc/client.cc @@ -50,8 +50,18 @@ Status Client::ProcessPacket(ConstByteSpan data) { c.method_id() == packet.method_id(); }); + auto channel = std::find_if(channels_.begin(), channels_.end(), [&](auto& c) { + return c.id() == packet.channel_id(); + }); + + if (channel == channels_.end()) { + PW_LOG_WARN("RPC client received a packet for an unregistered channel"); + return Status::NotFound(); + } + if (call == calls_.end()) { PW_LOG_WARN("RPC client received a packet for a request it did not make"); + channel->Send(Packet::ClientError(packet, Status::FailedPrecondition())); return Status::NotFound(); } diff --git a/pw_rpc/client_test.cc b/pw_rpc/client_test.cc index 13a7cbe42..3174cfc0e 100644 --- a/pw_rpc/client_test.cc +++ b/pw_rpc/client_test.cc @@ -54,9 +54,19 @@ TEST(Client, ProcessPacket_InvokesARegisteredClientCall) { EXPECT_TRUE(call.invoked()); } -TEST(Client, ProcessPacket_ReturnsNotFoundOnUnregisteredCall) { +TEST(Client, ProcessPacket_SendsClientErrorOnUnregisteredCall) { ClientContextForTest context; + EXPECT_EQ(context.SendResponse(Status::OK, {}), Status::NotFound()); + + ASSERT_EQ(context.output().packet_count(), 1u); + const Packet& packet = context.output().sent_packet(); + EXPECT_EQ(packet.type(), PacketType::CLIENT_ERROR); + EXPECT_EQ(packet.channel_id(), context.kChannelId); + EXPECT_EQ(packet.service_id(), context.kServiceId); + EXPECT_EQ(packet.method_id(), context.kMethodId); + EXPECT_TRUE(packet.payload().empty()); + EXPECT_EQ(packet.status(), Status::FailedPrecondition()); } TEST(Client, ProcessPacket_ReturnsDataLossOnBadPacket) { diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h index 295cd4cea..5c7e4fd21 100644 --- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h +++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h @@ -24,7 +24,11 @@ #include "pw_rpc/service.h" #include "pw_status/status.h" -namespace pw::rpc::internal { +namespace pw::rpc { + +class Server; + +namespace internal { class Packet; @@ -67,6 +71,10 @@ class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item { Status ReleasePayloadBuffer(std::span<const std::byte> payload); private: + friend class rpc::Server; + + void Close(); + Packet ResponsePacket(std::span<const std::byte> payload = {}) const; ServerCall call_; @@ -74,4 +82,5 @@ class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item { enum { kClosed, kOpen } state_; }; -} // namespace pw::rpc::internal +} // namespace internal +} // namespace pw::rpc diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h index f88118744..db7ce57c1 100644 --- a/pw_rpc/public/pw_rpc/internal/packet.h +++ b/pw_rpc/public/pw_rpc/internal/packet.h @@ -42,8 +42,8 @@ class Packet { status); } - // Creates an ERROR packet with the channel, service, and method ID of the - // provided packet. + // Creates a SERVER_ERROR packet with the channel, service, and method ID of + // the provided packet. static constexpr Packet ServerError(const Packet& packet, Status status) { return Packet(PacketType::SERVER_ERROR, packet.channel_id(), @@ -53,6 +53,17 @@ class Packet { status); } + // Creates a CLIENT_ERROR packet with the channel, service, and method ID of + // the provided packet. + static constexpr Packet ClientError(const Packet& packet, Status status) { + return Packet(PacketType::CLIENT_ERROR, + packet.channel_id(), + packet.service_id(), + packet.method_id(), + {}, + status); + } + // Creates an empty packet. constexpr Packet() : Packet(PacketType{}, kUnassignedId, kUnassignedId, kUnassignedId) {} diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h index b88bdae89..51e3f365e 100644 --- a/pw_rpc/public/pw_rpc/server.h +++ b/pw_rpc/public/pw_rpc/server.h @@ -61,6 +61,7 @@ class Server { void HandleCancelPacket(const internal::Packet& request, internal::Channel& channel); + void HandleClientError(const internal::Packet& packet); internal::Channel* FindChannel(uint32_t id) const; internal::Channel* AssignChannel(uint32_t id, ChannelOutput& interface); diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc index f7afeeafc..e288bfbc1 100644 --- a/pw_rpc/server.cc +++ b/pw_rpc/server.cc @@ -105,8 +105,7 @@ Status Server::ProcessPacket(std::span<const byte> data, // TODO(hepler): Support client streaming RPCs. break; case PacketType::CLIENT_ERROR: - // TODO(hepler): Handle errors from the client. If the client wasn't - // expecting a response for a streaming RPC, cancel that RPC. + HandleClientError(packet); break; case PacketType::CANCEL_SERVER_STREAM: HandleCancelPacket(packet, *channel); @@ -149,6 +148,21 @@ void Server::HandleCancelPacket(const Packet& packet, } } +void Server::HandleClientError(const Packet& packet) { + // A client error indicates that the client received a packet that it did not + // expect. If the packet belongs to a streaming RPC, cancel the stream without + // sending a final SERVER_STREAM_END packet. + auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) { + return w.channel_id() == packet.channel_id() && + w.service_id() == packet.service_id() && + w.method_id() == packet.method_id(); + }); + + if (writer != writers_.end()) { + writer->Close(); + } +} + internal::Channel* Server::FindChannel(uint32_t id) const { for (internal::Channel& c : channels_) { if (c.id() == id) { diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc index 44a7888d7..81c9866e1 100644 --- a/pw_rpc/server_test.cc +++ b/pw_rpc/server_test.cc @@ -263,6 +263,16 @@ TEST_F(MethodPending, ProcessPacket_Cancel_SendsStreamEndPacket) { EXPECT_EQ(packet.status(), Status::Cancelled()); } +TEST_F(MethodPending, + ProcessPacket_ClientError_ClosesServerWriterWithoutStreamEnd) { + EXPECT_EQ(Status::OK, + server_.ProcessPacket( + EncodeRequest(PacketType::CLIENT_ERROR, 1, 42, 100), output_)); + + EXPECT_FALSE(writer_.open()); + EXPECT_EQ(output_.packet_count(), 0u); +} + TEST_F(MethodPending, ProcessPacket_Cancel_IncorrectChannel) { EXPECT_EQ(Status::Ok(), server_.ProcessPacket( |