aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc
diff options
context:
space:
mode:
authorAlexei Frolov <frolv@google.com>2020-09-28 16:23:06 -0700
committerCQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com>2020-10-02 16:21:48 +0000
commit062ed18abbf7d2555c84bd6ffe2399c207a250ed (patch)
tree01b28871736d3462c8a1255c764c5bbe37813548 /pw_rpc
parentda651f169f97f85b63c675b5f5479a7868179da1 (diff)
downloadpigweed-062ed18abbf7d2555c84bd6ffe2399c207a250ed.tar.gz
pw_rpc: Send CLIENT_ERROR on unexpected packet
This updates the RPC client to send back a CLIENT_ERROR to the server if it receives a packet it was not expecting. The server is updated to handle client errors by ending the active RPC, if applicable. Change-Id: Ia9a2f36571fd3f91e28a50393c531df66ba11f09 Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/19180 Commit-Queue: Alexei Frolov <frolv@google.com> Reviewed-by: Wyatt Hepler <hepler@google.com>
Diffstat (limited to 'pw_rpc')
-rw-r--r--pw_rpc/base_server_writer.cc12
-rw-r--r--pw_rpc/client.cc10
-rw-r--r--pw_rpc/client_test.cc12
-rw-r--r--pw_rpc/public/pw_rpc/internal/base_server_writer.h13
-rw-r--r--pw_rpc/public/pw_rpc/internal/packet.h15
-rw-r--r--pw_rpc/public/pw_rpc/server.h1
-rw-r--r--pw_rpc/server.cc18
-rw-r--r--pw_rpc/server_test.cc10
8 files changed, 82 insertions, 9 deletions
diff --git a/pw_rpc/base_server_writer.cc b/pw_rpc/base_server_writer.cc
index 1e3c39680..942cf326e 100644
--- a/pw_rpc/base_server_writer.cc
+++ b/pw_rpc/base_server_writer.cc
@@ -50,8 +50,7 @@ void BaseServerWriter::Finish(Status status) {
return;
}
- call_.server().RemoveWriter(*this);
- state_ = kClosed;
+ Close();
// Send a control packet indicating that the stream (and RPC) has terminated.
call_.channel().Send(Packet(PacketType::SERVER_STREAM_END,
@@ -79,6 +78,15 @@ Status BaseServerWriter::ReleasePayloadBuffer(
return call_.channel().Send(response_, ResponsePacket(payload));
}
+void BaseServerWriter::Close() {
+ if (!open()) {
+ return;
+ }
+
+ call_.server().RemoveWriter(*this);
+ state_ = kClosed;
+}
+
Packet BaseServerWriter::ResponsePacket(
std::span<const std::byte> payload) const {
return Packet(PacketType::RESPONSE,
diff --git a/pw_rpc/client.cc b/pw_rpc/client.cc
index 73b2f0010..214a2085e 100644
--- a/pw_rpc/client.cc
+++ b/pw_rpc/client.cc
@@ -50,8 +50,18 @@ Status Client::ProcessPacket(ConstByteSpan data) {
c.method_id() == packet.method_id();
});
+ auto channel = std::find_if(channels_.begin(), channels_.end(), [&](auto& c) {
+ return c.id() == packet.channel_id();
+ });
+
+ if (channel == channels_.end()) {
+ PW_LOG_WARN("RPC client received a packet for an unregistered channel");
+ return Status::NotFound();
+ }
+
if (call == calls_.end()) {
PW_LOG_WARN("RPC client received a packet for a request it did not make");
+ channel->Send(Packet::ClientError(packet, Status::FailedPrecondition()));
return Status::NotFound();
}
diff --git a/pw_rpc/client_test.cc b/pw_rpc/client_test.cc
index 13a7cbe42..3174cfc0e 100644
--- a/pw_rpc/client_test.cc
+++ b/pw_rpc/client_test.cc
@@ -54,9 +54,19 @@ TEST(Client, ProcessPacket_InvokesARegisteredClientCall) {
EXPECT_TRUE(call.invoked());
}
-TEST(Client, ProcessPacket_ReturnsNotFoundOnUnregisteredCall) {
+TEST(Client, ProcessPacket_SendsClientErrorOnUnregisteredCall) {
ClientContextForTest context;
+
EXPECT_EQ(context.SendResponse(Status::OK, {}), Status::NotFound());
+
+ ASSERT_EQ(context.output().packet_count(), 1u);
+ const Packet& packet = context.output().sent_packet();
+ EXPECT_EQ(packet.type(), PacketType::CLIENT_ERROR);
+ EXPECT_EQ(packet.channel_id(), context.kChannelId);
+ EXPECT_EQ(packet.service_id(), context.kServiceId);
+ EXPECT_EQ(packet.method_id(), context.kMethodId);
+ EXPECT_TRUE(packet.payload().empty());
+ EXPECT_EQ(packet.status(), Status::FailedPrecondition());
}
TEST(Client, ProcessPacket_ReturnsDataLossOnBadPacket) {
diff --git a/pw_rpc/public/pw_rpc/internal/base_server_writer.h b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
index 295cd4cea..5c7e4fd21 100644
--- a/pw_rpc/public/pw_rpc/internal/base_server_writer.h
+++ b/pw_rpc/public/pw_rpc/internal/base_server_writer.h
@@ -24,7 +24,11 @@
#include "pw_rpc/service.h"
#include "pw_status/status.h"
-namespace pw::rpc::internal {
+namespace pw::rpc {
+
+class Server;
+
+namespace internal {
class Packet;
@@ -67,6 +71,10 @@ class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item {
Status ReleasePayloadBuffer(std::span<const std::byte> payload);
private:
+ friend class rpc::Server;
+
+ void Close();
+
Packet ResponsePacket(std::span<const std::byte> payload = {}) const;
ServerCall call_;
@@ -74,4 +82,5 @@ class BaseServerWriter : public IntrusiveList<BaseServerWriter>::Item {
enum { kClosed, kOpen } state_;
};
-} // namespace pw::rpc::internal
+} // namespace internal
+} // namespace pw::rpc
diff --git a/pw_rpc/public/pw_rpc/internal/packet.h b/pw_rpc/public/pw_rpc/internal/packet.h
index f88118744..db7ce57c1 100644
--- a/pw_rpc/public/pw_rpc/internal/packet.h
+++ b/pw_rpc/public/pw_rpc/internal/packet.h
@@ -42,8 +42,8 @@ class Packet {
status);
}
- // Creates an ERROR packet with the channel, service, and method ID of the
- // provided packet.
+ // Creates a SERVER_ERROR packet with the channel, service, and method ID of
+ // the provided packet.
static constexpr Packet ServerError(const Packet& packet, Status status) {
return Packet(PacketType::SERVER_ERROR,
packet.channel_id(),
@@ -53,6 +53,17 @@ class Packet {
status);
}
+ // Creates a CLIENT_ERROR packet with the channel, service, and method ID of
+ // the provided packet.
+ static constexpr Packet ClientError(const Packet& packet, Status status) {
+ return Packet(PacketType::CLIENT_ERROR,
+ packet.channel_id(),
+ packet.service_id(),
+ packet.method_id(),
+ {},
+ status);
+ }
+
// Creates an empty packet.
constexpr Packet()
: Packet(PacketType{}, kUnassignedId, kUnassignedId, kUnassignedId) {}
diff --git a/pw_rpc/public/pw_rpc/server.h b/pw_rpc/public/pw_rpc/server.h
index b88bdae89..51e3f365e 100644
--- a/pw_rpc/public/pw_rpc/server.h
+++ b/pw_rpc/public/pw_rpc/server.h
@@ -61,6 +61,7 @@ class Server {
void HandleCancelPacket(const internal::Packet& request,
internal::Channel& channel);
+ void HandleClientError(const internal::Packet& packet);
internal::Channel* FindChannel(uint32_t id) const;
internal::Channel* AssignChannel(uint32_t id, ChannelOutput& interface);
diff --git a/pw_rpc/server.cc b/pw_rpc/server.cc
index f7afeeafc..e288bfbc1 100644
--- a/pw_rpc/server.cc
+++ b/pw_rpc/server.cc
@@ -105,8 +105,7 @@ Status Server::ProcessPacket(std::span<const byte> data,
// TODO(hepler): Support client streaming RPCs.
break;
case PacketType::CLIENT_ERROR:
- // TODO(hepler): Handle errors from the client. If the client wasn't
- // expecting a response for a streaming RPC, cancel that RPC.
+ HandleClientError(packet);
break;
case PacketType::CANCEL_SERVER_STREAM:
HandleCancelPacket(packet, *channel);
@@ -149,6 +148,21 @@ void Server::HandleCancelPacket(const Packet& packet,
}
}
+void Server::HandleClientError(const Packet& packet) {
+ // A client error indicates that the client received a packet that it did not
+ // expect. If the packet belongs to a streaming RPC, cancel the stream without
+ // sending a final SERVER_STREAM_END packet.
+ auto writer = std::find_if(writers_.begin(), writers_.end(), [&](auto& w) {
+ return w.channel_id() == packet.channel_id() &&
+ w.service_id() == packet.service_id() &&
+ w.method_id() == packet.method_id();
+ });
+
+ if (writer != writers_.end()) {
+ writer->Close();
+ }
+}
+
internal::Channel* Server::FindChannel(uint32_t id) const {
for (internal::Channel& c : channels_) {
if (c.id() == id) {
diff --git a/pw_rpc/server_test.cc b/pw_rpc/server_test.cc
index 44a7888d7..81c9866e1 100644
--- a/pw_rpc/server_test.cc
+++ b/pw_rpc/server_test.cc
@@ -263,6 +263,16 @@ TEST_F(MethodPending, ProcessPacket_Cancel_SendsStreamEndPacket) {
EXPECT_EQ(packet.status(), Status::Cancelled());
}
+TEST_F(MethodPending,
+ ProcessPacket_ClientError_ClosesServerWriterWithoutStreamEnd) {
+ EXPECT_EQ(Status::OK,
+ server_.ProcessPacket(
+ EncodeRequest(PacketType::CLIENT_ERROR, 1, 42, 100), output_));
+
+ EXPECT_FALSE(writer_.open());
+ EXPECT_EQ(output_.packet_count(), 0u);
+}
+
TEST_F(MethodPending, ProcessPacket_Cancel_IncorrectChannel) {
EXPECT_EQ(Status::Ok(),
server_.ProcessPacket(