aboutsummaryrefslogtreecommitdiff
path: root/pw_rpc/java/test
diff options
context:
space:
mode:
Diffstat (limited to 'pw_rpc/java/test')
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel47
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java104
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java209
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java217
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java1
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java3
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java149
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java148
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java164
-rw-r--r--pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java98
10 files changed, 741 insertions, 399 deletions
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel b/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel
index 3f63f7a7b..532255da9 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel
@@ -20,7 +20,10 @@ java_library(
name = "test_client",
testonly = True,
srcs = ["TestClient.java"],
- visibility = ["__pkg__"],
+ visibility = [
+ "__pkg__",
+ "//pw_transfer/java/test/dev/pigweed/pw_transfer:__pkg__",
+ ],
deps = [
"//pw_rpc:packet_proto_java_lite",
"//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
@@ -48,44 +51,60 @@ java_test(
)
java_test(
- name = "IdsTest",
+ name = "EndpointTest",
size = "small",
- srcs = ["IdsTest.java"],
- test_class = "dev.pigweed.pw_rpc.IdsTest",
+ srcs = ["EndpointTest.java"],
+ test_class = "dev.pigweed.pw_rpc.EndpointTest",
deps = [
+ ":test_proto_java_proto_lite",
+ "//pw_rpc:packet_proto_java_lite",
"//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
+ "@com_google_protobuf//java/lite",
"@maven//:com_google_flogger_flogger_system_backend",
"@maven//:com_google_truth_truth",
+ "@maven//:org_mockito_mockito_core",
],
)
java_test(
- name = "PacketsTest",
+ name = "FutureCallTest",
size = "small",
- srcs = ["PacketsTest.java"],
- test_class = "dev.pigweed.pw_rpc.PacketsTest",
+ srcs = ["FutureCallTest.java"],
+ test_class = "dev.pigweed.pw_rpc.FutureCallTest",
deps = [
+ ":test_proto_java_proto_lite",
"//pw_rpc:packet_proto_java_lite",
"//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
"@com_google_protobuf//java/lite",
"@maven//:com_google_flogger_flogger_system_backend",
"@maven//:com_google_truth_truth",
+ "@maven//:org_mockito_mockito_core",
],
)
java_test(
- name = "RpcManagerTest",
+ name = "IdsTest",
size = "small",
- srcs = ["RpcManagerTest.java"],
- test_class = "dev.pigweed.pw_rpc.RpcManagerTest",
+ srcs = ["IdsTest.java"],
+ test_class = "dev.pigweed.pw_rpc.IdsTest",
+ deps = [
+ "//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
+ "@maven//:com_google_flogger_flogger_system_backend",
+ "@maven//:com_google_truth_truth",
+ ],
+)
+
+java_test(
+ name = "PacketsTest",
+ size = "small",
+ srcs = ["PacketsTest.java"],
+ test_class = "dev.pigweed.pw_rpc.PacketsTest",
deps = [
- ":test_proto_java_proto_lite",
"//pw_rpc:packet_proto_java_lite",
"//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
"@com_google_protobuf//java/lite",
"@maven//:com_google_flogger_flogger_system_backend",
"@maven//:com_google_truth_truth",
- "@maven//:org_mockito_mockito_core",
],
)
@@ -111,6 +130,7 @@ java_test(
test_class = "dev.pigweed.pw_rpc.StreamObserverMethodClientTest",
deps = [
":test_proto_java_proto_lite",
+ "//pw_rpc:packet_proto_java_lite",
"//pw_rpc/java/main/dev/pigweed/pw_rpc:client",
"@com_google_protobuf//java/lite",
"@maven//:com_google_flogger_flogger_system_backend",
@@ -123,9 +143,10 @@ test_suite(
name = "pw_rpc",
tests = [
":ClientTest",
+ ":EndpointTest",
+ ":FutureCallTest",
":IdsTest",
":PacketsTest",
- ":RpcManagerTest",
":StreamObserverCallTest",
":StreamObserverMethodClientTest",
],
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java
index affffa598..b3e31984b 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java
@@ -20,6 +20,7 @@ import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import com.google.common.collect.ImmutableList;
@@ -33,7 +34,6 @@ import java.util.List;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnit;
@@ -123,22 +123,40 @@ public final class ClientTest {
}
@Test
- public void method_unknownMethod() {
+ public void method_invalidFormat() {
assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, ""));
assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "one"));
assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "hello"));
+ }
+
+ @Test
+ public void method_unknownService() {
assertThrows(
- IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "abc.Service/Method"));
- assertThrows(IllegalArgumentException.class,
- () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService/NotAnRpc").method());
+ InvalidRpcServiceException.class, () -> client.method(CHANNEL_ID, "abc.Service/Method"));
+
+ Service service = new Service("throwaway.NotRealService",
+ Service.unaryMethod("NotAnRpc", SomeMessage.class, AnotherMessage.class));
+ assertThrows(InvalidRpcServiceException.class,
+ () -> client.method(CHANNEL_ID, service.method("NotAnRpc")));
+ }
+
+ @Test
+ public void method_unknownMethodInKnownService() {
+ assertThrows(InvalidRpcServiceMethodException.class,
+ () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService/NotAnRpc"));
+ assertThrows(InvalidRpcServiceMethodException.class,
+ () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "NotAnRpc"));
}
@Test
public void method_unknownChannel() {
- assertThrows(IllegalArgumentException.class,
- () -> client.method(0, "pw.rpc.test1.TheTestService/SomeUnary"));
- assertThrows(IllegalArgumentException.class,
- () -> client.method(999, "pw.rpc.test1.TheTestService/SomeUnary"));
+ MethodClient methodClient0 = client.method(0, "pw.rpc.test1.TheTestService/SomeUnary");
+ assertThrows(InvalidRpcChannelException.class,
+ () -> methodClient0.invokeUnary(SomeMessage.getDefaultInstance()));
+
+ MethodClient methodClient999 = client.method(999, "pw.rpc.test1.TheTestService/SomeUnary");
+ assertThrows(InvalidRpcChannelException.class,
+ () -> methodClient999.invokeUnary(SomeMessage.getDefaultInstance()));
}
@Test
@@ -178,13 +196,22 @@ public final class ClientTest {
}
@Test
+ public void method_accessFromMethodInstance() {
+ assertThat(client.method(CHANNEL_ID, UNARY_METHOD).method()).isSameInstanceAs(UNARY_METHOD);
+ assertThat(client.method(CHANNEL_ID, SERVER_STREAMING_METHOD).method())
+ .isSameInstanceAs(SERVER_STREAMING_METHOD);
+ assertThat(client.method(CHANNEL_ID, CLIENT_STREAMING_METHOD).method())
+ .isSameInstanceAs(CLIENT_STREAMING_METHOD);
+ }
+
+ @Test
public void processPacket_emptyPacket_isNotProcessed() {
assertThat(client.processPacket(new byte[] {})).isFalse();
}
@Test
public void processPacket_invalidPacket_isNotProcessed() {
- assertThat(client.processPacket("This is definitely not a packet!".getBytes(UTF_8))).isFalse();
+ assertThat(client.processPacket("\uffff\uffff\uffffNot a packet!".getBytes(UTF_8))).isFalse();
}
@Test
@@ -323,7 +350,7 @@ public final class ClientTest {
}
@Test
- @SuppressWarnings("unchecked") // No idea why, but this test causes "unchecked" warnings
+ @SuppressWarnings("unchecked")
public void streamObserverClient_create_invokeMethod() throws Exception {
Channel.Output mockChannelOutput = Mockito.mock(Channel.Output.class);
Client client = Client.create(ImmutableList.of(new Channel(1, mockChannelOutput)),
@@ -335,4 +362,59 @@ public final class ClientTest {
verify(mockChannelOutput)
.send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", payload).toByteArray());
}
+
+ @Test
+ public void closeChannel_abortsExisting() throws Exception {
+ MethodClient serverStreamMethod =
+ client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeServerStreaming");
+
+ Call call1 = serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer);
+ Call call2 = client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeClientStreaming")
+ .invokeClientStreaming(observer);
+ assertThat(call1.active()).isTrue();
+ assertThat(call2.active()).isTrue();
+
+ assertThat(client.closeChannel(CHANNEL_ID)).isTrue();
+
+ assertThat(call1.active()).isFalse();
+ assertThat(call2.active()).isFalse();
+
+ verify(observer, times(2)).onError(Status.ABORTED);
+
+ assertThrows(InvalidRpcChannelException.class,
+ () -> serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer));
+ }
+
+ @Test
+ public void closeChannel_noCalls() {
+ assertThat(client.closeChannel(CHANNEL_ID)).isTrue();
+ }
+
+ @Test
+ public void closeChannel_knownChannel() {
+ assertThat(client.closeChannel(CHANNEL_ID + 100)).isFalse();
+ }
+
+ @Test
+ public void openChannel_uniqueChannel() throws Exception {
+ int newChannelId = CHANNEL_ID + 100;
+ Channel.Output channelOutput = Mockito.mock(Channel.Output.class);
+ client.openChannel(new Channel(newChannelId, channelOutput));
+
+ client.method(newChannelId, "pw.rpc.test1.TheTestService", "SomeUnary")
+ .invokeUnary(REQUEST_PAYLOAD, observer);
+
+ verify(channelOutput)
+ .send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", REQUEST_PAYLOAD)
+ .toBuilder()
+ .setChannelId(newChannelId)
+ .build()
+ .toByteArray());
+ }
+
+ @Test
+ public void openChannel_alreadyExists_throwsException() {
+ assertThrows(InvalidRpcChannelException.class,
+ () -> client.openChannel(new Channel(CHANNEL_ID, packet -> {})));
+ }
}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java
new file mode 100644
index 000000000..91722ee70
--- /dev/null
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java
@@ -0,0 +1,209 @@
+// Copyright 2021 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+package dev.pigweed.pw_rpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.MessageLite;
+import dev.pigweed.pw_rpc.internal.Packet.PacketType;
+import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
+import org.junit.Rule;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+public final class EndpointTest {
+ @Rule public final MockitoRule mockito = MockitoJUnit.rule();
+
+ private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService",
+ Service.unaryMethod("SomeUnary", SomeMessage.class, SomeMessage.class),
+ Service.serverStreamingMethod("SomeServerStreaming", SomeMessage.class, SomeMessage.class),
+ Service.clientStreamingMethod("SomeClientStreaming", SomeMessage.class, SomeMessage.class),
+ Service.bidirectionalStreamingMethod(
+ "SomeBidiStreaming", SomeMessage.class, SomeMessage.class));
+
+ private static final Method METHOD = SERVICE.method("SomeUnary");
+
+ private static final SomeMessage REQUEST_PAYLOAD =
+ SomeMessage.newBuilder().setMagicNumber(1337).build();
+ private static final byte[] REQUEST = request(REQUEST_PAYLOAD);
+ private static final AnotherMessage RESPONSE_PAYLOAD =
+ AnotherMessage.newBuilder().setPayload("hello").build();
+ private static final int CHANNEL_ID = 555;
+
+ @Mock private Channel.Output mockOutput;
+ @Mock private StreamObserver<MessageLite> callEvents;
+
+ private final Channel channel = new Channel(CHANNEL_ID, bytes -> mockOutput.send(bytes));
+ private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel));
+
+ private static byte[] request(MessageLite payload) {
+ return packetBuilder()
+ .setType(PacketType.REQUEST)
+ .setPayload(payload.toByteString())
+ .build()
+ .toByteArray();
+ }
+
+ private static byte[] cancel() {
+ return packetBuilder()
+ .setType(PacketType.CLIENT_ERROR)
+ .setStatus(Status.CANCELLED.code())
+ .build()
+ .toByteArray();
+ }
+
+ private static RpcPacket.Builder packetBuilder() {
+ return RpcPacket.newBuilder()
+ .setChannelId(CHANNEL_ID)
+ .setServiceId(SERVICE.id())
+ .setMethodId(METHOD.id());
+ }
+
+ private AbstractCall<MessageLite, MessageLite> createCall(Endpoint endpoint, PendingRpc rpc) {
+ return StreamObserverCall.getFactory(callEvents).apply(endpoint, rpc);
+ }
+
+ @Test
+ public void start_succeeds_rpcIsPending() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
+
+ verify(mockOutput).send(REQUEST);
+ assertThat(endpoint.abandon(call)).isTrue();
+ }
+
+ @Test
+ public void start_sendingFails_callsHandleError() throws Exception {
+ doThrow(new ChannelOutputException()).when(mockOutput).send(any());
+
+ assertThrows(ChannelOutputException.class,
+ () -> endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD));
+
+ verify(mockOutput).send(REQUEST);
+ }
+
+ @Test
+ public void abandon_rpcNoLongerPending() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
+ assertThat(endpoint.abandon(call)).isTrue();
+
+ assertThat(endpoint.abandon(call)).isFalse();
+ }
+
+ @Test
+ public void abandon_sendsNoPackets() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
+ verify(mockOutput).send(REQUEST);
+ verifyNoMoreInteractions(mockOutput);
+
+ assertThat(endpoint.abandon(call)).isTrue();
+ }
+
+ @Test
+ public void cancel_rpcNoLongerPending() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
+ assertThat(endpoint.cancel(call)).isTrue();
+
+ assertThat(endpoint.abandon(call)).isFalse();
+ }
+
+ @Test
+ public void cancel_sendsCancelPacket() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD);
+ assertThat(endpoint.cancel(call)).isTrue();
+
+ verify(mockOutput).send(cancel());
+ }
+
+ @Test
+ public void open_sendsNoPacketsButRpcIsPending() {
+ AbstractCall<MessageLite, MessageLite> call =
+ endpoint.openRpc(CHANNEL_ID, METHOD, this::createCall);
+
+ assertThat(call.active()).isTrue();
+ assertThat(endpoint.abandon(call)).isTrue();
+ verifyNoInteractions(mockOutput);
+ }
+
+ @Test
+ public void ignoresActionsIfCallIsNotPending() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ createCall(endpoint, PendingRpc.create(channel, METHOD));
+
+ assertThat(endpoint.cancel(call)).isFalse();
+ assertThat(endpoint.abandon(call)).isFalse();
+ assertThat(endpoint.clientStream(call, REQUEST_PAYLOAD)).isFalse();
+ assertThat(endpoint.clientStreamEnd(call)).isFalse();
+ }
+
+ @Test
+ public void ignoresPacketsIfCallIsNotPending() throws Exception {
+ AbstractCall<MessageLite, MessageLite> call =
+ createCall(endpoint, PendingRpc.create(channel, METHOD));
+
+ assertThat(endpoint.cancel(call)).isFalse();
+ assertThat(endpoint.abandon(call)).isFalse();
+
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder()
+ .setType(PacketType.SERVER_STREAM)
+ .setPayload(RESPONSE_PAYLOAD.toByteString())
+ .build()))
+ .isTrue();
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder()
+ .setType(PacketType.RESPONSE)
+ .setPayload(RESPONSE_PAYLOAD.toByteString())
+ .build()))
+ .isTrue();
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder()
+ .setType(PacketType.SERVER_ERROR)
+ .setStatus(Status.ABORTED.code())
+ .build()))
+ .isTrue();
+
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder()
+ .setType(PacketType.CLIENT_STREAM)
+ .setPayload(REQUEST_PAYLOAD.toByteString())
+ .build()))
+ .isTrue();
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder().setType(PacketType.CLIENT_STREAM_END).build()))
+ .isTrue();
+ assertThat(endpoint.processClientPacket(call.rpc().method(),
+ packetBuilder()
+ .setType(PacketType.CLIENT_ERROR)
+ .setStatus(Status.ABORTED.code())
+ .build()))
+ .isTrue();
+
+ verifyNoInteractions(callEvents);
+ }
+}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java
new file mode 100644
index 000000000..a1f11997a
--- /dev/null
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java
@@ -0,0 +1,217 @@
+// Copyright 2021 The Pigweed Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not
+// use this file except in compliance with the License. You may obtain a copy of
+// the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+// License for the specific language governing permissions and limitations under
+// the License.
+
+package dev.pigweed.pw_rpc;
+
+import static com.google.common.truth.Truth.assertThat;
+import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+import dev.pigweed.pw_rpc.Call.UnaryFuture;
+import dev.pigweed.pw_rpc.FutureCall.StreamResponseFuture;
+import dev.pigweed.pw_rpc.FutureCall.UnaryResponseFuture;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import org.junit.Rule;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnit;
+import org.mockito.junit.MockitoRule;
+
+public final class FutureCallTest {
+ @Rule public final MockitoRule mockito = MockitoJUnit.rule();
+
+ private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService",
+ Service.unaryMethod("SomeUnary", SomeMessage.class, AnotherMessage.class),
+ Service.clientStreamingMethod("SomeClient", SomeMessage.class, AnotherMessage.class),
+ Service.bidirectionalStreamingMethod(
+ "SomeBidirectional", SomeMessage.class, AnotherMessage.class));
+ private static final Method METHOD = SERVICE.method("SomeUnary");
+ private static final int CHANNEL_ID = 555;
+
+ @Mock private Channel.Output mockOutput;
+
+ private final Channel channel = new Channel(CHANNEL_ID, packet -> mockOutput.send(packet));
+ private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel));
+ private final PendingRpc rpc = PendingRpc.create(channel, METHOD);
+
+ @Test
+ public void unaryFuture_response_setsValue() throws Exception {
+ UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build();
+ call.handleUnaryCompleted(response.toByteString(), Status.CANCELLED);
+
+ assertThat(call.isDone()).isTrue();
+ assertThat(call.get()).isEqualTo(UnaryResult.create(response, Status.CANCELLED));
+ }
+
+ @Test
+ public void unaryFuture_serverError_setsException() throws Exception {
+ UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ call.handleError(Status.NOT_FOUND);
+
+ assertThat(call.isDone()).isTrue();
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
+
+ RpcError error = (RpcError) exception.getCause();
+ assertThat(error).isNotNull();
+ assertThat(error.rpc()).isEqualTo(rpc);
+ assertThat(error.status()).isEqualTo(Status.NOT_FOUND);
+ }
+
+ @Test
+ public void unaryFuture_cancelOnCall_cancelsTheCallAndFuture() throws Exception {
+ UnaryFuture<SomeMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+ assertThat(call.cancel()).isTrue();
+ assertThat(call.isCancelled()).isTrue();
+
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
+
+ RpcError error = (RpcError) exception.getCause();
+ assertThat(error).isNotNull();
+ assertThat(error.rpc()).isEqualTo(rpc);
+ assertThat(error.status()).isEqualTo(Status.CANCELLED);
+ }
+
+ @Test
+ public void unaryFuture_cancelOnFuture_cancelsTheCallAndFuture() throws Exception {
+ UnaryFuture<SomeMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+ assertThat(call.cancel(true)).isTrue();
+ assertThat(call.isCancelled()).isTrue();
+
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
+
+ RpcError error = (RpcError) exception.getCause();
+ assertThat(error).isNotNull();
+ assertThat(error.rpc()).isEqualTo(rpc);
+ assertThat(error.status()).isEqualTo(Status.CANCELLED);
+ }
+
+ @Test
+ public void unaryFuture_cancelOnFutureSendFails_cancelsTheCallAndFuture() throws Exception {
+ UnaryFuture<SomeMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ doThrow(new ChannelOutputException()).when(mockOutput).send(any());
+
+ assertThat(call.cancel(true)).isTrue();
+ assertThat(call.isCancelled()).isTrue();
+
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
+
+ RpcError error = (RpcError) exception.getCause();
+ assertThat(error).isNotNull();
+ assertThat(error.rpc()).isEqualTo(rpc);
+ assertThat(error.status()).isEqualTo(Status.CANCELLED);
+ }
+
+ @Test
+ public void unaryFuture_multipleResponses_setsException() throws Exception {
+ UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build();
+ call.doHandleNext(response);
+ call.handleUnaryCompleted(ByteString.EMPTY, Status.OK);
+
+ assertThat(call.isDone()).isTrue();
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class);
+ }
+
+ @Test
+ public void unaryFuture_addListener_calledOnCompletion() throws Exception {
+ UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ Runnable listener = mock(Runnable.class);
+ call.addListener(listener, directExecutor());
+
+ AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build();
+ call.handleUnaryCompleted(response.toByteString(), Status.OK);
+
+ verify(listener, times(1)).run();
+ }
+
+ @Test
+ public void unaryFuture_exceptionDuringStart() throws Exception {
+ ChannelOutputException exceptionToThrow = new ChannelOutputException();
+ doThrow(exceptionToThrow).when(mockOutput).send(any());
+
+ UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(
+ CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance());
+
+ assertThat(call.error()).isEqualTo(Status.ABORTED);
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(ChannelOutputException.class);
+
+ assertThat(exception.getCause()).isSameInstanceAs(exceptionToThrow);
+ }
+
+ @Test
+ public void bidirectionalStreamingFuture_responses_setsValue() throws Exception {
+ List<AnotherMessage> responses = new ArrayList<>();
+ StreamResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(CHANNEL_ID,
+ METHOD,
+ StreamResponseFuture.getFactory(responses::add),
+ SomeMessage.getDefaultInstance());
+
+ AnotherMessage message = AnotherMessage.newBuilder().setResultValue(1138).build();
+ call.doHandleNext(message);
+ call.doHandleNext(message);
+ assertThat(call.isDone()).isFalse();
+ call.handleStreamCompleted(Status.OK);
+
+ assertThat(call.isDone()).isTrue();
+ assertThat(call.get()).isEqualTo(Status.OK);
+ assertThat(responses).containsExactly(message, message);
+ }
+
+ @Test
+ public void bidirectionalStreamingFuture_serverError_setsException() throws Exception {
+ StreamResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(CHANNEL_ID,
+ METHOD,
+ StreamResponseFuture.getFactory(msg -> {}),
+ SomeMessage.getDefaultInstance());
+
+ call.handleError(Status.NOT_FOUND);
+
+ assertThat(call.isDone()).isTrue();
+ ExecutionException exception = assertThrows(ExecutionException.class, call::get);
+ assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
+
+ RpcError error = (RpcError) exception.getCause();
+ assertThat(error).isNotNull();
+ assertThat(error.rpc()).isEqualTo(rpc);
+ assertThat(error.status()).isEqualTo(Status.NOT_FOUND);
+ }
+}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java
index 1440c6dbf..f8f99762b 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java
@@ -17,7 +17,6 @@ package dev.pigweed.pw_rpc;
import static com.google.common.truth.Truth.assertThat;
import org.junit.Test;
-import org.junit.runner.RunWith;
public final class IdsTest {
@Test
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java
index 6a1e771f8..5f75aacf1 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java
@@ -20,14 +20,13 @@ import com.google.protobuf.ExtensionRegistryLite;
import dev.pigweed.pw_rpc.internal.Packet.PacketType;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
import org.junit.Test;
-import org.junit.runner.RunWith;
public final class PacketsTest {
private static final Service SERVICE =
new Service("Greetings", Service.unaryMethod("Hello", RpcPacket.class, RpcPacket.class));
private static final PendingRpc RPC =
- PendingRpc.create(new Channel(123, null), SERVICE, SERVICE.method("Hello"));
+ PendingRpc.create(new Channel(123, null), SERVICE.method("Hello"));
private static final RpcPacket PACKET = RpcPacket.newBuilder()
.setChannelId(123)
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java
deleted file mode 100644
index f25c66283..000000000
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java
+++ /dev/null
@@ -1,149 +0,0 @@
-// Copyright 2021 The Pigweed Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License"); you may not
-// use this file except in compliance with the License. You may obtain a copy of
-// the License at
-//
-// https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-// License for the specific language governing permissions and limitations under
-// the License.
-
-package dev.pigweed.pw_rpc;
-
-import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertThrows;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
-
-import com.google.protobuf.MessageLite;
-import dev.pigweed.pw_rpc.internal.Packet.PacketType;
-import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnit;
-import org.mockito.junit.MockitoRule;
-
-public final class RpcManagerTest {
- @Rule public final MockitoRule mockito = MockitoJUnit.rule();
-
- private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService",
- Service.unaryMethod("SomeUnary", SomeMessage.class, SomeMessage.class),
- Service.serverStreamingMethod("SomeServerStreaming", SomeMessage.class, SomeMessage.class),
- Service.clientStreamingMethod("SomeClientStreaming", SomeMessage.class, SomeMessage.class),
- Service.bidirectionalStreamingMethod(
- "SomeBidiStreaming", SomeMessage.class, SomeMessage.class));
-
- private static final Method METHOD = SERVICE.method("SomeUnary");
-
- private static final SomeMessage REQUEST_PAYLOAD =
- SomeMessage.newBuilder().setMagicNumber(1337).build();
- private static final byte[] REQUEST = request(REQUEST_PAYLOAD);
- private static final int CHANNEL_ID = 555;
-
- @Mock private Channel.Output mockOutput;
- @Mock private StreamObserverCall<MessageLite, MessageLite> call;
-
- private PendingRpc rpc;
- private RpcManager manager;
-
- @Before
- public void setup() {
- rpc = PendingRpc.create(new Channel(CHANNEL_ID, mockOutput), SERVICE, METHOD);
- manager = new RpcManager();
- }
-
- private static byte[] request(MessageLite payload) {
- return packetBuilder()
- .setType(PacketType.REQUEST)
- .setPayload(payload.toByteString())
- .build()
- .toByteArray();
- }
-
- private static byte[] cancel() {
- return packetBuilder()
- .setType(PacketType.CLIENT_ERROR)
- .setStatus(Status.CANCELLED.code())
- .build()
- .toByteArray();
- }
-
- private static RpcPacket.Builder packetBuilder() {
- return RpcPacket.newBuilder()
- .setChannelId(CHANNEL_ID)
- .setServiceId(SERVICE.id())
- .setMethodId(METHOD.id());
- }
-
- @Test
- public void start_sendingFails_rpcNotPending() throws Exception {
- doThrow(new ChannelOutputException()).when(mockOutput).send(any());
-
- assertThrows(ChannelOutputException.class, () -> manager.start(rpc, call, REQUEST_PAYLOAD));
-
- verify(mockOutput).send(REQUEST);
- assertThat(manager.getPending(rpc)).isNull();
- }
-
- @Test
- public void start_succeeds_rpcIsPending() throws Exception {
- assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull();
-
- assertThat(manager.getPending(rpc)).isSameInstanceAs(call);
- }
-
- @Test
- public void startThenCancel_rpcNotPending() throws Exception {
- assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull();
- assertThat(manager.cancel(rpc)).isSameInstanceAs(call);
-
- assertThat(manager.getPending(rpc)).isNull();
- }
-
- @Test
- public void startThenCancel_sendsCancelPacket() throws Exception {
- assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull();
- assertThat(manager.cancel(rpc)).isEqualTo(call);
-
- verify(mockOutput).send(cancel());
- }
-
- @Test
- public void startThenClear_sendsNothing() throws Exception {
- verifyNoMoreInteractions(mockOutput);
-
- assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull();
- assertThat(manager.clear(rpc)).isEqualTo(call);
- }
-
- @Test
- public void clear_notPending_returnsNull() {
- assertThat(manager.clear(rpc)).isNull();
- }
-
- @Test
- public void open_sendingFails_rpcIsPending() throws Exception {
- doThrow(new ChannelOutputException()).when(mockOutput).send(any());
-
- assertThat(manager.open(rpc, call, REQUEST_PAYLOAD)).isNull();
-
- verify(mockOutput).send(REQUEST);
- assertThat(manager.getPending(rpc)).isSameInstanceAs(call);
- }
-
- @Test
- public void open_success_rpcIsPending() {
- assertThat(manager.open(rpc, call, REQUEST_PAYLOAD)).isNull();
-
- assertThat(manager.getPending(rpc)).isSameInstanceAs(call);
- }
-}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java
index f4eba688b..76d86c162 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java
@@ -15,22 +15,15 @@
package dev.pigweed.pw_rpc;
import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertThrows;
-import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
-import dev.pigweed.pw_rpc.StreamObserverCall.StreamResponseFuture;
-import dev.pigweed.pw_rpc.StreamObserverCall.UnaryResponseFuture;
+import com.google.common.collect.ImmutableList;
import dev.pigweed.pw_rpc.internal.Packet.PacketType;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ExecutionException;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -49,9 +42,9 @@ public final class StreamObserverCallTest {
@Mock private StreamObserver<AnotherMessage> observer;
@Mock private Channel.Output mockOutput;
- private final RpcManager rpcManager = new RpcManager();
+ private final Channel channel = new Channel(CHANNEL_ID, packet -> mockOutput.send(packet));
+ private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel));
private StreamObserverCall<SomeMessage, AnotherMessage> streamObserverCall;
- private PendingRpc rpc;
private static byte[] cancel() {
return packetBuilder()
@@ -70,9 +63,8 @@ public final class StreamObserverCallTest {
@Before
public void createCall() throws Exception {
- rpc = PendingRpc.create(new Channel(CHANNEL_ID, mockOutput), SERVICE, METHOD);
- streamObserverCall = StreamObserverCall.start(rpcManager, rpc, observer, null);
- rpcManager.start(rpc, streamObserverCall, SomeMessage.getDefaultInstance());
+ streamObserverCall =
+ endpoint.invokeRpc(CHANNEL_ID, METHOD, StreamObserverCall.getFactory(observer), null);
}
@Test
@@ -100,9 +92,23 @@ public final class StreamObserverCallTest {
}
@Test
+ public void abandon_doesNotSendCancelPacket() throws Exception {
+ streamObserverCall.abandon();
+
+ verify(mockOutput, never()).send(cancel());
+ }
+
+ @Test
+ public void abandon_deactivates() {
+ streamObserverCall.abandon();
+
+ assertThat(streamObserverCall.active()).isFalse();
+ }
+
+ @Test
public void send_sendsClientStreamPacket() throws Exception {
SomeMessage request = SomeMessage.newBuilder().setMagicNumber(123).build();
- streamObserverCall.send(request);
+ streamObserverCall.write(request);
verify(mockOutput)
.send(packetBuilder()
@@ -116,9 +122,7 @@ public final class StreamObserverCallTest {
public void send_raisesExceptionIfClosed() throws Exception {
streamObserverCall.cancel();
- RpcError thrown = assertThrows(
- RpcError.class, () -> streamObserverCall.send(SomeMessage.getDefaultInstance()));
- assertThat(thrown.status()).isSameInstanceAs(Status.CANCELLED);
+ assertThat(streamObserverCall.write(SomeMessage.getDefaultInstance())).isFalse();
}
@Test
@@ -131,29 +135,21 @@ public final class StreamObserverCallTest {
@Test
public void onNext_callsObserverIfActive() {
- streamObserverCall.onNext(AnotherMessage.getDefaultInstance().toByteString());
+ streamObserverCall.handleNext(AnotherMessage.getDefaultInstance().toByteString());
verify(observer).onNext(AnotherMessage.getDefaultInstance());
}
@Test
- public void onNext_ignoresIfNotActive() throws Exception {
- streamObserverCall.cancel();
- streamObserverCall.onNext(AnotherMessage.getDefaultInstance().toByteString());
-
- verify(observer, never()).onNext(any());
- }
-
- @Test
public void callDispatcher_onCompleted_callsObserver() {
- streamObserverCall.onCompleted(Status.ABORTED);
+ streamObserverCall.handleStreamCompleted(Status.ABORTED);
verify(observer).onCompleted(Status.ABORTED);
}
@Test
public void callDispatcher_onCompleted_setsActiveAndStatus() {
- streamObserverCall.onCompleted(Status.ABORTED);
+ streamObserverCall.handleStreamCompleted(Status.ABORTED);
verify(observer).onCompleted(Status.ABORTED);
assertThat(streamObserverCall.active()).isFalse();
@@ -162,109 +158,17 @@ public final class StreamObserverCallTest {
@Test
public void callDispatcher_onError_callsObserver() {
- streamObserverCall.onError(Status.NOT_FOUND);
+ streamObserverCall.handleError(Status.NOT_FOUND);
verify(observer).onError(Status.NOT_FOUND);
}
@Test
public void callDispatcher_onError_deactivates() {
- streamObserverCall.onError(Status.ABORTED);
+ streamObserverCall.handleError(Status.ABORTED);
verify(observer).onError(Status.ABORTED);
assertThat(streamObserverCall.active()).isFalse();
assertThat(streamObserverCall.status()).isNull();
}
-
- @Test
- public void unaryFuture_response_setsValue() throws Exception {
- UnaryResponseFuture<SomeMessage, AnotherMessage> call =
- new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance());
-
- AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build();
- call.onNext(response);
- assertThat(call.isDone()).isFalse();
- call.onCompleted(Status.CANCELLED);
-
- assertThat(call.isDone()).isTrue();
- assertThat(call.get()).isEqualTo(UnaryResult.create(response, Status.CANCELLED));
- }
-
- @Test
- public void unaryFuture_serverError_setsException() throws Exception {
- UnaryResponseFuture<SomeMessage, AnotherMessage> call =
- new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance());
-
- call.onError(Status.NOT_FOUND);
-
- assertThat(call.isDone()).isTrue();
- ExecutionException exception = assertThrows(ExecutionException.class, call::get);
- assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
-
- RpcError error = (RpcError) exception.getCause();
- assertThat(error).isNotNull();
- assertThat(error.rpc()).isEqualTo(rpc);
- assertThat(error.status()).isEqualTo(Status.NOT_FOUND);
- }
-
- @Test
- public void unaryFuture_noMessage_setsException() throws Exception {
- UnaryResponseFuture<SomeMessage, AnotherMessage> call =
- new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance());
-
- call.onCompleted(Status.OK);
-
- assertThat(call.isDone()).isTrue();
- ExecutionException exception = assertThrows(ExecutionException.class, call::get);
- assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class);
- }
-
- @Test
- public void unaryFuture_multipleResponses_setsException() throws Exception {
- UnaryResponseFuture<SomeMessage, AnotherMessage> call =
- new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance());
-
- AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build();
- call.onNext(response);
- call.onNext(response);
- call.onCompleted(Status.OK);
-
- assertThat(call.isDone()).isTrue();
- ExecutionException exception = assertThrows(ExecutionException.class, call::get);
- assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class);
- }
-
- @Test
- public void bidirectionalStreamingfuture_responses_setsValue() throws Exception {
- List<AnotherMessage> responses = new ArrayList<>();
- StreamResponseFuture<SomeMessage, AnotherMessage> call =
- new StreamResponseFuture<>(rpcManager, rpc, responses::add, null);
-
- AnotherMessage message = AnotherMessage.newBuilder().setResultValue(1138).build();
- call.onNext(message);
- call.onNext(message);
- assertThat(call.isDone()).isFalse();
- call.onCompleted(Status.OK);
-
- assertThat(call.isDone()).isTrue();
- assertThat(call.get()).isEqualTo(Status.OK);
- assertThat(responses).containsExactly(message, message);
- }
-
- @Test
- public void bidirectionalStreamingfuture_serverError_setsException() throws Exception {
- StreamResponseFuture<SomeMessage, AnotherMessage> call =
- new StreamResponseFuture<>(rpcManager, rpc, (msg) -> {}, null);
-
- call.onError(Status.NOT_FOUND);
-
- assertThat(call.isDone()).isTrue();
- ExecutionException exception = assertThrows(ExecutionException.class, call::get);
- assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class);
-
- RpcError error = (RpcError) exception.getCause();
- assertThat(error).isNotNull();
- assertThat(error.rpc()).isEqualTo(rpc);
- assertThat(error.status()).isEqualTo(Status.NOT_FOUND);
- }
}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java
index 55fd88aa5..36102f447 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java
@@ -16,14 +16,18 @@ package dev.pigweed.pw_rpc;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertThrows;
-import static org.mockito.Mockito.mock;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import com.google.common.collect.ImmutableList;
import com.google.protobuf.MessageLite;
+import dev.pigweed.pw_rpc.internal.Packet.PacketType;
+import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -36,22 +40,24 @@ public final class StreamObserverMethodClientTest {
Service.bidirectionalStreamingMethod(
"SomeBidirectionalStreaming", SomeMessage.class, AnotherMessage.class));
- private static final Channel CHANNEL = new Channel(1, (bytes) -> {});
-
- private static final PendingRpc UNARY_RPC =
- PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeUnary"));
- private static final PendingRpc SERVER_STREAMING_RPC =
- PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeServerStreaming"));
- private static final PendingRpc CLIENT_STREAMING_RPC =
- PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeClientStreaming"));
- private static final PendingRpc BIDIRECTIONAL_STREAMING_RPC =
- PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeBidirectionalStreaming"));
-
@Rule public final MockitoRule mockito = MockitoJUnit.rule();
@Mock private StreamObserver<MessageLite> defaultObserver;
+ @Mock private StreamObserver<AnotherMessage> observer;
+ @Mock private Channel.Output channelOutput;
+
+ // Wrap Channel.Output since channelOutput will be null when the channel is initialized.
+ private final Channel channel = new Channel(1, bytes -> channelOutput.send(bytes));
- private final RpcManager rpcManager = new RpcManager();
+ private final PendingRpc unary_rpc = PendingRpc.create(channel, SERVICE.method("SomeUnary"));
+ private final PendingRpc server_streaming_rpc =
+ PendingRpc.create(channel, SERVICE.method("SomeServerStreaming"));
+ private final PendingRpc client_streaming_rpc =
+ PendingRpc.create(channel, SERVICE.method("SomeClientStreaming"));
+ private final PendingRpc bidirectional_streaming_rpc =
+ PendingRpc.create(channel, SERVICE.method("SomeBidirectionalStreaming"));
+
+ private final Client client = Client.create(ImmutableList.of(channel), ImmutableList.of(SERVICE));
private MethodClient unaryMethodClient;
private MethodClient serverStreamingMethodClient;
private MethodClient clientStreamingMethodClient;
@@ -59,33 +65,30 @@ public final class StreamObserverMethodClientTest {
@Before
public void createMethodClient() {
- unaryMethodClient = new MethodClient(rpcManager, UNARY_RPC, defaultObserver);
+ unaryMethodClient = new MethodClient(client, channel.id(), unary_rpc.method(), defaultObserver);
serverStreamingMethodClient =
- new MethodClient(rpcManager, SERVER_STREAMING_RPC, defaultObserver);
+ new MethodClient(client, channel.id(), server_streaming_rpc.method(), defaultObserver);
clientStreamingMethodClient =
- new MethodClient(rpcManager, CLIENT_STREAMING_RPC, defaultObserver);
- bidirectionalStreamingMethodClient =
- new MethodClient(rpcManager, BIDIRECTIONAL_STREAMING_RPC, defaultObserver);
+ new MethodClient(client, channel.id(), client_streaming_rpc.method(), defaultObserver);
+ bidirectionalStreamingMethodClient = new MethodClient(
+ client, channel.id(), bidirectional_streaming_rpc.method(), defaultObserver);
}
@Test
public void invokeWithNoObserver_usesDefaultObserver() throws Exception {
unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance());
AnotherMessage reply = AnotherMessage.newBuilder().setPayload("yo").build();
- rpcManager.getPending(UNARY_RPC).onNext(reply.toByteString());
+
+ assertThat(client.processPacket(responsePacket(unary_rpc, reply))).isTrue();
verify(defaultObserver).onNext(reply);
}
@Test
public void invoke_usesProvidedObserver() throws Exception {
- @SuppressWarnings("unchecked")
- StreamObserver<AnotherMessage> observer =
- (StreamObserver<AnotherMessage>) mock(StreamObserver.class);
-
unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance(), observer);
AnotherMessage reply = AnotherMessage.newBuilder().setPayload("yo").build();
- rpcManager.getPending(UNARY_RPC).onNext(reply.toByteString());
+ assertThat(client.processPacket(responsePacket(unary_rpc, reply))).isTrue();
verify(observer).onNext(reply);
}
@@ -93,75 +96,86 @@ public final class StreamObserverMethodClientTest {
@Test
public void invokeUnary_startsRpc() throws Exception {
Call call = unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance());
- assertThat(rpcManager.getPending(UNARY_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void openUnary_startsRpc() {
- Call call = unaryMethodClient.openUnary(SomeMessage.getDefaultInstance(), defaultObserver);
- assertThat(rpcManager.getPending(UNARY_RPC)).isSameInstanceAs(call);
+ public void openUnary_startsRpc() throws Exception {
+ Call call = unaryMethodClient.openUnary(defaultObserver);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, never()).send(any());
}
@Test
public void invokeServerStreaming_startsRpc() throws Exception {
Call call = serverStreamingMethodClient.invokeServerStreaming(SomeMessage.getDefaultInstance());
- assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void openServerStreaming_startsRpc() {
- Call call = serverStreamingMethodClient.openServerStreaming(
- SomeMessage.getDefaultInstance(), defaultObserver);
- assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isSameInstanceAs(call);
+ public void openServerStreaming_startsRpc() throws Exception {
+ Call call = serverStreamingMethodClient.openServerStreaming(defaultObserver);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, never()).send(any());
}
@Test
public void invokeClientStreaming_startsRpc() throws Exception {
Call call = clientStreamingMethodClient.invokeClientStreaming();
- assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void openClientStreaming_startsRpc() {
+ public void openClientStreaming_startsRpc() throws Exception {
Call call = clientStreamingMethodClient.openClientStreaming(defaultObserver);
- assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, never()).send(any());
}
@Test
public void invokeBidirectionalStreaming_startsRpc() throws Exception {
Call call = bidirectionalStreamingMethodClient.invokeBidirectionalStreaming();
- assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void openBidirectionalStreaming_startsRpc() {
+ public void openBidirectionalStreaming_startsRpc() throws Exception {
Call call = bidirectionalStreamingMethodClient.openBidirectionalStreaming(defaultObserver);
- assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isSameInstanceAs(call);
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, never()).send(any());
}
@Test
- public void invokeUnaryFuture_startsRpc() {
- unaryMethodClient.invokeUnaryFuture(SomeMessage.getDefaultInstance());
- assertThat(rpcManager.getPending(UNARY_RPC)).isNotNull();
+ public void invokeUnaryFuture_startsRpc() throws Exception {
+ Call call = unaryMethodClient.invokeUnaryFuture(SomeMessage.getDefaultInstance());
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void invokeServerStreamingFuture_startsRpc() {
- serverStreamingMethodClient.invokeServerStreamingFuture(
+ public void invokeServerStreamingFuture_startsRpc() throws Exception {
+ Call call = serverStreamingMethodClient.invokeServerStreamingFuture(
SomeMessage.getDefaultInstance(), (msg) -> {});
- assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isNotNull();
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void invokeClientStreamingFuture_startsRpc() {
- clientStreamingMethodClient.invokeClientStreamingFuture();
- assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isNotNull();
+ public void invokeClientStreamingFuture_startsRpc() throws Exception {
+ Call call = clientStreamingMethodClient.invokeClientStreamingFuture();
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
- public void invokeBidirectionalStreamingFuture_startsRpc() {
- bidirectionalStreamingMethodClient.invokeBidirectionalStreamingFuture((msg) -> {});
- assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isNotNull();
+ public void invokeBidirectionalStreamingFuture_startsRpc() throws Exception {
+ Call call = bidirectionalStreamingMethodClient.invokeBidirectionalStreamingFuture((msg) -> {});
+ assertThat(call.active()).isTrue();
+ verify(channelOutput, times(1)).send(any());
}
@Test
@@ -187,4 +201,50 @@ public final class StreamObserverMethodClientTest {
assertThrows(UnsupportedOperationException.class,
() -> clientStreamingMethodClient.invokeBidirectionalStreaming());
}
+
+ @Test
+ public void invalidChannel_throwsException() {
+ MethodClient methodClient =
+ new MethodClient(client, 999, client_streaming_rpc.method(), defaultObserver);
+ assertThrows(InvalidRpcChannelException.class, methodClient::invokeClientStreaming);
+ }
+
+ @Test
+ public void invalidService_throwsException() {
+ Service otherService = new Service("something.Else",
+ Service.clientStreamingMethod("ClientStream", SomeMessage.class, AnotherMessage.class));
+
+ MethodClient methodClient = new MethodClient(
+ client, channel.id(), otherService.method("ClientStream"), defaultObserver);
+ assertThrows(InvalidRpcServiceException.class, methodClient::invokeClientStreaming);
+ }
+
+ @Test
+ public void invalidMethod_throwsException() {
+ Service serviceWithDifferentUnaryMethod = new Service("pw.rpc.test1.TheTestService",
+ Service.unaryMethod("SomeUnary", AnotherMessage.class, AnotherMessage.class),
+ Service.serverStreamingMethod(
+ "SomeServerStreaming", SomeMessage.class, AnotherMessage.class),
+ Service.clientStreamingMethod(
+ "SomeClientStreaming", SomeMessage.class, AnotherMessage.class),
+ Service.bidirectionalStreamingMethod(
+ "SomeBidirectionalStreaming", SomeMessage.class, AnotherMessage.class));
+
+ MethodClient methodClient = new MethodClient(
+ client, 999, serviceWithDifferentUnaryMethod.method("SomeUnary"), defaultObserver);
+ assertThrows(InvalidRpcServiceMethodException.class,
+ () -> methodClient.invokeUnary(AnotherMessage.getDefaultInstance()));
+ }
+
+ private static byte[] responsePacket(PendingRpc rpc, MessageLite payload) {
+ return RpcPacket.newBuilder()
+ .setChannelId(1)
+ .setServiceId(rpc.service().id())
+ .setMethodId(rpc.method().id())
+ .setType(PacketType.RESPONSE)
+ .setStatus(Status.OK.code())
+ .setPayload(payload.toByteString())
+ .build()
+ .toByteArray();
+ }
}
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java
index 6688e2422..a58d51901 100644
--- a/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java
+++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java
@@ -14,17 +14,19 @@
package dev.pigweed.pw_rpc;
-import static java.util.Arrays.stream;
-
import com.google.common.collect.ImmutableList;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.MessageLite;
+import com.google.protobuf.MessageLiteOrBuilder;
import dev.pigweed.pw_rpc.internal.Packet.PacketType;
import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
+import java.util.Queue;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -38,29 +40,35 @@ public class TestClient {
private final Client client;
private final List<RpcPacket> sentPackets = new ArrayList<>();
- private final List<RpcPacket> enqueuedPackets = new ArrayList<>();
- private int receiveEnqueuedPacketsAfter = 1;
- final Map<PacketType, Integer> sentPayloadIndices = new EnumMap<>(PacketType.class);
+ private final Queue<EnqueuedPackets> enqueuedPackets = new ArrayDeque<>();
+ private final Map<PacketType, Integer> sentPayloadIndices = new EnumMap<>(PacketType.class);
+
+ @Nullable private ChannelOutputException channelOutputException = null;
+
+ private static class EnqueuedPackets {
+ private int processAfterSentPackets;
+ private final List<RpcPacket> packets;
- @Nullable ChannelOutputException channelOutputException = null;
+ private EnqueuedPackets(int processAfterSentPackets, List<RpcPacket> packets) {
+ this.processAfterSentPackets = processAfterSentPackets;
+ this.packets = packets;
+ }
+
+ private boolean shouldProcessEnqueuedPackets() {
+ return processAfterSentPackets-- <= 1;
+ }
+ }
public TestClient(List<Service> services) {
Channel.Output channelOutput = packet -> {
- if (channelOutputException == null) {
- sentPackets.add(parsePacket(packet));
- } else {
+ if (channelOutputException != null) {
throw channelOutputException;
}
+ sentPackets.add(parsePacket(packet));
- // Process any enqueued packets.
- if (receiveEnqueuedPacketsAfter > 1) {
- receiveEnqueuedPacketsAfter -= 1;
- return;
- }
- if (!enqueuedPackets.isEmpty()) {
- List<RpcPacket> packetsToProcess = new ArrayList<>(enqueuedPackets);
- enqueuedPackets.clear();
- packetsToProcess.forEach(this::processPacket);
+ if (!enqueuedPackets.isEmpty() && enqueuedPackets.peek().shouldProcessEnqueuedPackets()) {
+ // Process any enqueued packets.
+ enqueuedPackets.remove().packets.forEach(this::processPacket);
}
};
client = Client.create(ImmutableList.of(new Channel(CHANNEL_ID, channelOutput)), services);
@@ -86,52 +94,34 @@ public class TestClient {
}
/** Simulates receiving SERVER_STREAM packets from the server. */
- public void receiveServerStream(String service, String method, MessageLite... payloads) {
+ public void receiveServerStream(String service, String method, MessageLiteOrBuilder... payloads) {
RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build();
- for (MessageLite payload : payloads) {
- processPacket(RpcPacket.newBuilder(base).setPayload(payload.toByteString()));
+ for (MessageLiteOrBuilder payload : payloads) {
+ processPacket(RpcPacket.newBuilder(base).setPayload(getMessage(payload).toByteString()));
}
}
- public void receiveServerStream(String service, String method, MessageLite.Builder... builders) {
- receiveServerStream(service,
- method,
- stream(builders).map(MessageLite.Builder::build).toArray(MessageLite[] ::new));
- }
-
/**
* Enqueues a SERVER_STREAM packet so that the client receives it after a packet is sent.
*
+ * This function may be called multiple times to create a queue of packets to process as different
+ * packets are sent.
+ *
* @param afterPackets Wait until this many packets have been sent before the client receives
- * these stream packets. The minimum value (and the default) is 1.
+ * these stream packets. The minimum value is 1. If multiple stream packets are queued,
+ * afterPackets is counted from the packet before it in the queue.
*/
public void enqueueServerStream(
- String service, String method, int afterPackets, MessageLite... payloads) {
+ String service, String method, int afterPackets, MessageLiteOrBuilder... payloads) {
if (afterPackets < 1) {
throw new IllegalArgumentException("afterPackets must be at least 1");
}
- if (afterPackets != 1 && receiveEnqueuedPacketsAfter != 1) {
- throw new AssertionError(
- "May only set afterPackets once before enqueued packets are processed");
- }
- receiveEnqueuedPacketsAfter = afterPackets;
RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build();
- for (MessageLite payload : payloads) {
- enqueuedPackets.add(RpcPacket.newBuilder(base).setPayload(payload.toByteString()).build());
- }
- }
-
- public void enqueueServerStream(String service, String method, MessageLite... payloads) {
- enqueueServerStream(service, method, 1, payloads);
- }
-
- public void enqueueServerStream(
- String service, String method, int afterPackets, MessageLite.Builder... builders) {
- enqueueServerStream(service,
- method,
- afterPackets,
- stream(builders).map(MessageLite.Builder::build).toArray(MessageLite[] ::new));
+ enqueuedPackets.add(new EnqueuedPackets(afterPackets,
+ Arrays.stream(payloads)
+ .map(m -> RpcPacket.newBuilder(base).setPayload(getMessage(m).toByteString()).build())
+ .collect(Collectors.toList())));
}
/** Simulates receiving a SERVER_ERROR packet from the server. */
@@ -192,4 +182,14 @@ public class TestClient {
throw new AssertionError("Decoding sent packet payload failed", e);
}
}
+
+ private MessageLite getMessage(MessageLiteOrBuilder messageOrBuilder) {
+ if (messageOrBuilder instanceof MessageLite.Builder) {
+ return ((MessageLite.Builder) messageOrBuilder).build();
+ }
+ if (messageOrBuilder instanceof MessageLite) {
+ return (MessageLite) messageOrBuilder;
+ }
+ throw new AssertionError("Unexpected MessageLiteOrBuilder class");
+ }
}