diff options
Diffstat (limited to 'pw_rpc/java/test')
10 files changed, 741 insertions, 399 deletions
diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel b/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel index 3f63f7a7b..532255da9 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/BUILD.bazel @@ -20,7 +20,10 @@ java_library( name = "test_client", testonly = True, srcs = ["TestClient.java"], - visibility = ["__pkg__"], + visibility = [ + "__pkg__", + "//pw_transfer/java/test/dev/pigweed/pw_transfer:__pkg__", + ], deps = [ "//pw_rpc:packet_proto_java_lite", "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", @@ -48,44 +51,60 @@ java_test( ) java_test( - name = "IdsTest", + name = "EndpointTest", size = "small", - srcs = ["IdsTest.java"], - test_class = "dev.pigweed.pw_rpc.IdsTest", + srcs = ["EndpointTest.java"], + test_class = "dev.pigweed.pw_rpc.EndpointTest", deps = [ + ":test_proto_java_proto_lite", + "//pw_rpc:packet_proto_java_lite", "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", + "@com_google_protobuf//java/lite", "@maven//:com_google_flogger_flogger_system_backend", "@maven//:com_google_truth_truth", + "@maven//:org_mockito_mockito_core", ], ) java_test( - name = "PacketsTest", + name = "FutureCallTest", size = "small", - srcs = ["PacketsTest.java"], - test_class = "dev.pigweed.pw_rpc.PacketsTest", + srcs = ["FutureCallTest.java"], + test_class = "dev.pigweed.pw_rpc.FutureCallTest", deps = [ + ":test_proto_java_proto_lite", "//pw_rpc:packet_proto_java_lite", "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", "@com_google_protobuf//java/lite", "@maven//:com_google_flogger_flogger_system_backend", "@maven//:com_google_truth_truth", + "@maven//:org_mockito_mockito_core", ], ) java_test( - name = "RpcManagerTest", + name = "IdsTest", size = "small", - srcs = ["RpcManagerTest.java"], - test_class = "dev.pigweed.pw_rpc.RpcManagerTest", + srcs = ["IdsTest.java"], + test_class = "dev.pigweed.pw_rpc.IdsTest", + deps = [ + "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", + "@maven//:com_google_flogger_flogger_system_backend", + "@maven//:com_google_truth_truth", + ], +) + +java_test( + name = "PacketsTest", + size = "small", + srcs = ["PacketsTest.java"], + test_class = "dev.pigweed.pw_rpc.PacketsTest", deps = [ - ":test_proto_java_proto_lite", "//pw_rpc:packet_proto_java_lite", "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", "@com_google_protobuf//java/lite", "@maven//:com_google_flogger_flogger_system_backend", "@maven//:com_google_truth_truth", - "@maven//:org_mockito_mockito_core", ], ) @@ -111,6 +130,7 @@ java_test( test_class = "dev.pigweed.pw_rpc.StreamObserverMethodClientTest", deps = [ ":test_proto_java_proto_lite", + "//pw_rpc:packet_proto_java_lite", "//pw_rpc/java/main/dev/pigweed/pw_rpc:client", "@com_google_protobuf//java/lite", "@maven//:com_google_flogger_flogger_system_backend", @@ -123,9 +143,10 @@ test_suite( name = "pw_rpc", tests = [ ":ClientTest", + ":EndpointTest", + ":FutureCallTest", ":IdsTest", ":PacketsTest", - ":RpcManagerTest", ":StreamObserverCallTest", ":StreamObserverMethodClientTest", ], diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java index affffa598..b3e31984b 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/ClientTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; @@ -33,7 +34,6 @@ import java.util.List; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; @@ -123,22 +123,40 @@ public final class ClientTest { } @Test - public void method_unknownMethod() { + public void method_invalidFormat() { assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "")); assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "one")); assertThrows(IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "hello")); + } + + @Test + public void method_unknownService() { assertThrows( - IllegalArgumentException.class, () -> client.method(CHANNEL_ID, "abc.Service/Method")); - assertThrows(IllegalArgumentException.class, - () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService/NotAnRpc").method()); + InvalidRpcServiceException.class, () -> client.method(CHANNEL_ID, "abc.Service/Method")); + + Service service = new Service("throwaway.NotRealService", + Service.unaryMethod("NotAnRpc", SomeMessage.class, AnotherMessage.class)); + assertThrows(InvalidRpcServiceException.class, + () -> client.method(CHANNEL_ID, service.method("NotAnRpc"))); + } + + @Test + public void method_unknownMethodInKnownService() { + assertThrows(InvalidRpcServiceMethodException.class, + () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService/NotAnRpc")); + assertThrows(InvalidRpcServiceMethodException.class, + () -> client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "NotAnRpc")); } @Test public void method_unknownChannel() { - assertThrows(IllegalArgumentException.class, - () -> client.method(0, "pw.rpc.test1.TheTestService/SomeUnary")); - assertThrows(IllegalArgumentException.class, - () -> client.method(999, "pw.rpc.test1.TheTestService/SomeUnary")); + MethodClient methodClient0 = client.method(0, "pw.rpc.test1.TheTestService/SomeUnary"); + assertThrows(InvalidRpcChannelException.class, + () -> methodClient0.invokeUnary(SomeMessage.getDefaultInstance())); + + MethodClient methodClient999 = client.method(999, "pw.rpc.test1.TheTestService/SomeUnary"); + assertThrows(InvalidRpcChannelException.class, + () -> methodClient999.invokeUnary(SomeMessage.getDefaultInstance())); } @Test @@ -178,13 +196,22 @@ public final class ClientTest { } @Test + public void method_accessFromMethodInstance() { + assertThat(client.method(CHANNEL_ID, UNARY_METHOD).method()).isSameInstanceAs(UNARY_METHOD); + assertThat(client.method(CHANNEL_ID, SERVER_STREAMING_METHOD).method()) + .isSameInstanceAs(SERVER_STREAMING_METHOD); + assertThat(client.method(CHANNEL_ID, CLIENT_STREAMING_METHOD).method()) + .isSameInstanceAs(CLIENT_STREAMING_METHOD); + } + + @Test public void processPacket_emptyPacket_isNotProcessed() { assertThat(client.processPacket(new byte[] {})).isFalse(); } @Test public void processPacket_invalidPacket_isNotProcessed() { - assertThat(client.processPacket("This is definitely not a packet!".getBytes(UTF_8))).isFalse(); + assertThat(client.processPacket("\uffff\uffff\uffffNot a packet!".getBytes(UTF_8))).isFalse(); } @Test @@ -323,7 +350,7 @@ public final class ClientTest { } @Test - @SuppressWarnings("unchecked") // No idea why, but this test causes "unchecked" warnings + @SuppressWarnings("unchecked") public void streamObserverClient_create_invokeMethod() throws Exception { Channel.Output mockChannelOutput = Mockito.mock(Channel.Output.class); Client client = Client.create(ImmutableList.of(new Channel(1, mockChannelOutput)), @@ -335,4 +362,59 @@ public final class ClientTest { verify(mockChannelOutput) .send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", payload).toByteArray()); } + + @Test + public void closeChannel_abortsExisting() throws Exception { + MethodClient serverStreamMethod = + client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeServerStreaming"); + + Call call1 = serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer); + Call call2 = client.method(CHANNEL_ID, "pw.rpc.test1.TheTestService", "SomeClientStreaming") + .invokeClientStreaming(observer); + assertThat(call1.active()).isTrue(); + assertThat(call2.active()).isTrue(); + + assertThat(client.closeChannel(CHANNEL_ID)).isTrue(); + + assertThat(call1.active()).isFalse(); + assertThat(call2.active()).isFalse(); + + verify(observer, times(2)).onError(Status.ABORTED); + + assertThrows(InvalidRpcChannelException.class, + () -> serverStreamMethod.invokeServerStreaming(REQUEST_PAYLOAD, observer)); + } + + @Test + public void closeChannel_noCalls() { + assertThat(client.closeChannel(CHANNEL_ID)).isTrue(); + } + + @Test + public void closeChannel_knownChannel() { + assertThat(client.closeChannel(CHANNEL_ID + 100)).isFalse(); + } + + @Test + public void openChannel_uniqueChannel() throws Exception { + int newChannelId = CHANNEL_ID + 100; + Channel.Output channelOutput = Mockito.mock(Channel.Output.class); + client.openChannel(new Channel(newChannelId, channelOutput)); + + client.method(newChannelId, "pw.rpc.test1.TheTestService", "SomeUnary") + .invokeUnary(REQUEST_PAYLOAD, observer); + + verify(channelOutput) + .send(requestPacket("pw.rpc.test1.TheTestService", "SomeUnary", REQUEST_PAYLOAD) + .toBuilder() + .setChannelId(newChannelId) + .build() + .toByteArray()); + } + + @Test + public void openChannel_alreadyExists_throwsException() { + assertThrows(InvalidRpcChannelException.class, + () -> client.openChannel(new Channel(CHANNEL_ID, packet -> {}))); + } } diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java new file mode 100644 index 000000000..91722ee70 --- /dev/null +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/EndpointTest.java @@ -0,0 +1,209 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +package dev.pigweed.pw_rpc; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.MessageLite; +import dev.pigweed.pw_rpc.internal.Packet.PacketType; +import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public final class EndpointTest { + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + + private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService", + Service.unaryMethod("SomeUnary", SomeMessage.class, SomeMessage.class), + Service.serverStreamingMethod("SomeServerStreaming", SomeMessage.class, SomeMessage.class), + Service.clientStreamingMethod("SomeClientStreaming", SomeMessage.class, SomeMessage.class), + Service.bidirectionalStreamingMethod( + "SomeBidiStreaming", SomeMessage.class, SomeMessage.class)); + + private static final Method METHOD = SERVICE.method("SomeUnary"); + + private static final SomeMessage REQUEST_PAYLOAD = + SomeMessage.newBuilder().setMagicNumber(1337).build(); + private static final byte[] REQUEST = request(REQUEST_PAYLOAD); + private static final AnotherMessage RESPONSE_PAYLOAD = + AnotherMessage.newBuilder().setPayload("hello").build(); + private static final int CHANNEL_ID = 555; + + @Mock private Channel.Output mockOutput; + @Mock private StreamObserver<MessageLite> callEvents; + + private final Channel channel = new Channel(CHANNEL_ID, bytes -> mockOutput.send(bytes)); + private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel)); + + private static byte[] request(MessageLite payload) { + return packetBuilder() + .setType(PacketType.REQUEST) + .setPayload(payload.toByteString()) + .build() + .toByteArray(); + } + + private static byte[] cancel() { + return packetBuilder() + .setType(PacketType.CLIENT_ERROR) + .setStatus(Status.CANCELLED.code()) + .build() + .toByteArray(); + } + + private static RpcPacket.Builder packetBuilder() { + return RpcPacket.newBuilder() + .setChannelId(CHANNEL_ID) + .setServiceId(SERVICE.id()) + .setMethodId(METHOD.id()); + } + + private AbstractCall<MessageLite, MessageLite> createCall(Endpoint endpoint, PendingRpc rpc) { + return StreamObserverCall.getFactory(callEvents).apply(endpoint, rpc); + } + + @Test + public void start_succeeds_rpcIsPending() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD); + + verify(mockOutput).send(REQUEST); + assertThat(endpoint.abandon(call)).isTrue(); + } + + @Test + public void start_sendingFails_callsHandleError() throws Exception { + doThrow(new ChannelOutputException()).when(mockOutput).send(any()); + + assertThrows(ChannelOutputException.class, + () -> endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD)); + + verify(mockOutput).send(REQUEST); + } + + @Test + public void abandon_rpcNoLongerPending() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD); + assertThat(endpoint.abandon(call)).isTrue(); + + assertThat(endpoint.abandon(call)).isFalse(); + } + + @Test + public void abandon_sendsNoPackets() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD); + verify(mockOutput).send(REQUEST); + verifyNoMoreInteractions(mockOutput); + + assertThat(endpoint.abandon(call)).isTrue(); + } + + @Test + public void cancel_rpcNoLongerPending() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD); + assertThat(endpoint.cancel(call)).isTrue(); + + assertThat(endpoint.abandon(call)).isFalse(); + } + + @Test + public void cancel_sendsCancelPacket() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + endpoint.invokeRpc(CHANNEL_ID, METHOD, this::createCall, REQUEST_PAYLOAD); + assertThat(endpoint.cancel(call)).isTrue(); + + verify(mockOutput).send(cancel()); + } + + @Test + public void open_sendsNoPacketsButRpcIsPending() { + AbstractCall<MessageLite, MessageLite> call = + endpoint.openRpc(CHANNEL_ID, METHOD, this::createCall); + + assertThat(call.active()).isTrue(); + assertThat(endpoint.abandon(call)).isTrue(); + verifyNoInteractions(mockOutput); + } + + @Test + public void ignoresActionsIfCallIsNotPending() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + createCall(endpoint, PendingRpc.create(channel, METHOD)); + + assertThat(endpoint.cancel(call)).isFalse(); + assertThat(endpoint.abandon(call)).isFalse(); + assertThat(endpoint.clientStream(call, REQUEST_PAYLOAD)).isFalse(); + assertThat(endpoint.clientStreamEnd(call)).isFalse(); + } + + @Test + public void ignoresPacketsIfCallIsNotPending() throws Exception { + AbstractCall<MessageLite, MessageLite> call = + createCall(endpoint, PendingRpc.create(channel, METHOD)); + + assertThat(endpoint.cancel(call)).isFalse(); + assertThat(endpoint.abandon(call)).isFalse(); + + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder() + .setType(PacketType.SERVER_STREAM) + .setPayload(RESPONSE_PAYLOAD.toByteString()) + .build())) + .isTrue(); + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder() + .setType(PacketType.RESPONSE) + .setPayload(RESPONSE_PAYLOAD.toByteString()) + .build())) + .isTrue(); + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder() + .setType(PacketType.SERVER_ERROR) + .setStatus(Status.ABORTED.code()) + .build())) + .isTrue(); + + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder() + .setType(PacketType.CLIENT_STREAM) + .setPayload(REQUEST_PAYLOAD.toByteString()) + .build())) + .isTrue(); + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder().setType(PacketType.CLIENT_STREAM_END).build())) + .isTrue(); + assertThat(endpoint.processClientPacket(call.rpc().method(), + packetBuilder() + .setType(PacketType.CLIENT_ERROR) + .setStatus(Status.ABORTED.code()) + .build())) + .isTrue(); + + verifyNoInteractions(callEvents); + } +} diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java new file mode 100644 index 000000000..a1f11997a --- /dev/null +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/FutureCallTest.java @@ -0,0 +1,217 @@ +// Copyright 2021 The Pigweed Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +package dev.pigweed.pw_rpc; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; +import dev.pigweed.pw_rpc.Call.UnaryFuture; +import dev.pigweed.pw_rpc.FutureCall.StreamResponseFuture; +import dev.pigweed.pw_rpc.FutureCall.UnaryResponseFuture; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public final class FutureCallTest { + @Rule public final MockitoRule mockito = MockitoJUnit.rule(); + + private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService", + Service.unaryMethod("SomeUnary", SomeMessage.class, AnotherMessage.class), + Service.clientStreamingMethod("SomeClient", SomeMessage.class, AnotherMessage.class), + Service.bidirectionalStreamingMethod( + "SomeBidirectional", SomeMessage.class, AnotherMessage.class)); + private static final Method METHOD = SERVICE.method("SomeUnary"); + private static final int CHANNEL_ID = 555; + + @Mock private Channel.Output mockOutput; + + private final Channel channel = new Channel(CHANNEL_ID, packet -> mockOutput.send(packet)); + private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel)); + private final PendingRpc rpc = PendingRpc.create(channel, METHOD); + + @Test + public void unaryFuture_response_setsValue() throws Exception { + UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build(); + call.handleUnaryCompleted(response.toByteString(), Status.CANCELLED); + + assertThat(call.isDone()).isTrue(); + assertThat(call.get()).isEqualTo(UnaryResult.create(response, Status.CANCELLED)); + } + + @Test + public void unaryFuture_serverError_setsException() throws Exception { + UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + call.handleError(Status.NOT_FOUND); + + assertThat(call.isDone()).isTrue(); + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); + + RpcError error = (RpcError) exception.getCause(); + assertThat(error).isNotNull(); + assertThat(error.rpc()).isEqualTo(rpc); + assertThat(error.status()).isEqualTo(Status.NOT_FOUND); + } + + @Test + public void unaryFuture_cancelOnCall_cancelsTheCallAndFuture() throws Exception { + UnaryFuture<SomeMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + assertThat(call.cancel()).isTrue(); + assertThat(call.isCancelled()).isTrue(); + + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); + + RpcError error = (RpcError) exception.getCause(); + assertThat(error).isNotNull(); + assertThat(error.rpc()).isEqualTo(rpc); + assertThat(error.status()).isEqualTo(Status.CANCELLED); + } + + @Test + public void unaryFuture_cancelOnFuture_cancelsTheCallAndFuture() throws Exception { + UnaryFuture<SomeMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + assertThat(call.cancel(true)).isTrue(); + assertThat(call.isCancelled()).isTrue(); + + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); + + RpcError error = (RpcError) exception.getCause(); + assertThat(error).isNotNull(); + assertThat(error.rpc()).isEqualTo(rpc); + assertThat(error.status()).isEqualTo(Status.CANCELLED); + } + + @Test + public void unaryFuture_cancelOnFutureSendFails_cancelsTheCallAndFuture() throws Exception { + UnaryFuture<SomeMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + doThrow(new ChannelOutputException()).when(mockOutput).send(any()); + + assertThat(call.cancel(true)).isTrue(); + assertThat(call.isCancelled()).isTrue(); + + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); + + RpcError error = (RpcError) exception.getCause(); + assertThat(error).isNotNull(); + assertThat(error.rpc()).isEqualTo(rpc); + assertThat(error.status()).isEqualTo(Status.CANCELLED); + } + + @Test + public void unaryFuture_multipleResponses_setsException() throws Exception { + UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build(); + call.doHandleNext(response); + call.handleUnaryCompleted(ByteString.EMPTY, Status.OK); + + assertThat(call.isDone()).isTrue(); + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); + } + + @Test + public void unaryFuture_addListener_calledOnCompletion() throws Exception { + UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + Runnable listener = mock(Runnable.class); + call.addListener(listener, directExecutor()); + + AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build(); + call.handleUnaryCompleted(response.toByteString(), Status.OK); + + verify(listener, times(1)).run(); + } + + @Test + public void unaryFuture_exceptionDuringStart() throws Exception { + ChannelOutputException exceptionToThrow = new ChannelOutputException(); + doThrow(exceptionToThrow).when(mockOutput).send(any()); + + UnaryResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc( + CHANNEL_ID, METHOD, UnaryResponseFuture::new, SomeMessage.getDefaultInstance()); + + assertThat(call.error()).isEqualTo(Status.ABORTED); + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(ChannelOutputException.class); + + assertThat(exception.getCause()).isSameInstanceAs(exceptionToThrow); + } + + @Test + public void bidirectionalStreamingFuture_responses_setsValue() throws Exception { + List<AnotherMessage> responses = new ArrayList<>(); + StreamResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(CHANNEL_ID, + METHOD, + StreamResponseFuture.getFactory(responses::add), + SomeMessage.getDefaultInstance()); + + AnotherMessage message = AnotherMessage.newBuilder().setResultValue(1138).build(); + call.doHandleNext(message); + call.doHandleNext(message); + assertThat(call.isDone()).isFalse(); + call.handleStreamCompleted(Status.OK); + + assertThat(call.isDone()).isTrue(); + assertThat(call.get()).isEqualTo(Status.OK); + assertThat(responses).containsExactly(message, message); + } + + @Test + public void bidirectionalStreamingFuture_serverError_setsException() throws Exception { + StreamResponseFuture<SomeMessage, AnotherMessage> call = endpoint.invokeRpc(CHANNEL_ID, + METHOD, + StreamResponseFuture.getFactory(msg -> {}), + SomeMessage.getDefaultInstance()); + + call.handleError(Status.NOT_FOUND); + + assertThat(call.isDone()).isTrue(); + ExecutionException exception = assertThrows(ExecutionException.class, call::get); + assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); + + RpcError error = (RpcError) exception.getCause(); + assertThat(error).isNotNull(); + assertThat(error.rpc()).isEqualTo(rpc); + assertThat(error.status()).isEqualTo(Status.NOT_FOUND); + } +} diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java index 1440c6dbf..f8f99762b 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/IdsTest.java @@ -17,7 +17,6 @@ package dev.pigweed.pw_rpc; import static com.google.common.truth.Truth.assertThat; import org.junit.Test; -import org.junit.runner.RunWith; public final class IdsTest { @Test diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java index 6a1e771f8..5f75aacf1 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/PacketsTest.java @@ -20,14 +20,13 @@ import com.google.protobuf.ExtensionRegistryLite; import dev.pigweed.pw_rpc.internal.Packet.PacketType; import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; import org.junit.Test; -import org.junit.runner.RunWith; public final class PacketsTest { private static final Service SERVICE = new Service("Greetings", Service.unaryMethod("Hello", RpcPacket.class, RpcPacket.class)); private static final PendingRpc RPC = - PendingRpc.create(new Channel(123, null), SERVICE, SERVICE.method("Hello")); + PendingRpc.create(new Channel(123, null), SERVICE.method("Hello")); private static final RpcPacket PACKET = RpcPacket.newBuilder() .setChannelId(123) diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java deleted file mode 100644 index f25c66283..000000000 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/RpcManagerTest.java +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2021 The Pigweed Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); you may not -// use this file except in compliance with the License. You may obtain a copy of -// the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations under -// the License. - -package dev.pigweed.pw_rpc; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyNoMoreInteractions; - -import com.google.protobuf.MessageLite; -import dev.pigweed.pw_rpc.internal.Packet.PacketType; -import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -public final class RpcManagerTest { - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - - private static final Service SERVICE = new Service("pw.rpc.test1.TheTestService", - Service.unaryMethod("SomeUnary", SomeMessage.class, SomeMessage.class), - Service.serverStreamingMethod("SomeServerStreaming", SomeMessage.class, SomeMessage.class), - Service.clientStreamingMethod("SomeClientStreaming", SomeMessage.class, SomeMessage.class), - Service.bidirectionalStreamingMethod( - "SomeBidiStreaming", SomeMessage.class, SomeMessage.class)); - - private static final Method METHOD = SERVICE.method("SomeUnary"); - - private static final SomeMessage REQUEST_PAYLOAD = - SomeMessage.newBuilder().setMagicNumber(1337).build(); - private static final byte[] REQUEST = request(REQUEST_PAYLOAD); - private static final int CHANNEL_ID = 555; - - @Mock private Channel.Output mockOutput; - @Mock private StreamObserverCall<MessageLite, MessageLite> call; - - private PendingRpc rpc; - private RpcManager manager; - - @Before - public void setup() { - rpc = PendingRpc.create(new Channel(CHANNEL_ID, mockOutput), SERVICE, METHOD); - manager = new RpcManager(); - } - - private static byte[] request(MessageLite payload) { - return packetBuilder() - .setType(PacketType.REQUEST) - .setPayload(payload.toByteString()) - .build() - .toByteArray(); - } - - private static byte[] cancel() { - return packetBuilder() - .setType(PacketType.CLIENT_ERROR) - .setStatus(Status.CANCELLED.code()) - .build() - .toByteArray(); - } - - private static RpcPacket.Builder packetBuilder() { - return RpcPacket.newBuilder() - .setChannelId(CHANNEL_ID) - .setServiceId(SERVICE.id()) - .setMethodId(METHOD.id()); - } - - @Test - public void start_sendingFails_rpcNotPending() throws Exception { - doThrow(new ChannelOutputException()).when(mockOutput).send(any()); - - assertThrows(ChannelOutputException.class, () -> manager.start(rpc, call, REQUEST_PAYLOAD)); - - verify(mockOutput).send(REQUEST); - assertThat(manager.getPending(rpc)).isNull(); - } - - @Test - public void start_succeeds_rpcIsPending() throws Exception { - assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull(); - - assertThat(manager.getPending(rpc)).isSameInstanceAs(call); - } - - @Test - public void startThenCancel_rpcNotPending() throws Exception { - assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull(); - assertThat(manager.cancel(rpc)).isSameInstanceAs(call); - - assertThat(manager.getPending(rpc)).isNull(); - } - - @Test - public void startThenCancel_sendsCancelPacket() throws Exception { - assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull(); - assertThat(manager.cancel(rpc)).isEqualTo(call); - - verify(mockOutput).send(cancel()); - } - - @Test - public void startThenClear_sendsNothing() throws Exception { - verifyNoMoreInteractions(mockOutput); - - assertThat(manager.start(rpc, call, REQUEST_PAYLOAD)).isNull(); - assertThat(manager.clear(rpc)).isEqualTo(call); - } - - @Test - public void clear_notPending_returnsNull() { - assertThat(manager.clear(rpc)).isNull(); - } - - @Test - public void open_sendingFails_rpcIsPending() throws Exception { - doThrow(new ChannelOutputException()).when(mockOutput).send(any()); - - assertThat(manager.open(rpc, call, REQUEST_PAYLOAD)).isNull(); - - verify(mockOutput).send(REQUEST); - assertThat(manager.getPending(rpc)).isSameInstanceAs(call); - } - - @Test - public void open_success_rpcIsPending() { - assertThat(manager.open(rpc, call, REQUEST_PAYLOAD)).isNull(); - - assertThat(manager.getPending(rpc)).isSameInstanceAs(call); - } -} diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java index f4eba688b..76d86c162 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverCallTest.java @@ -15,22 +15,15 @@ package dev.pigweed.pw_rpc; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import dev.pigweed.pw_rpc.StreamObserverCall.StreamResponseFuture; -import dev.pigweed.pw_rpc.StreamObserverCall.UnaryResponseFuture; +import com.google.common.collect.ImmutableList; import dev.pigweed.pw_rpc.internal.Packet.PacketType; import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ExecutionException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -49,9 +42,9 @@ public final class StreamObserverCallTest { @Mock private StreamObserver<AnotherMessage> observer; @Mock private Channel.Output mockOutput; - private final RpcManager rpcManager = new RpcManager(); + private final Channel channel = new Channel(CHANNEL_ID, packet -> mockOutput.send(packet)); + private final Endpoint endpoint = new Endpoint(ImmutableList.of(channel)); private StreamObserverCall<SomeMessage, AnotherMessage> streamObserverCall; - private PendingRpc rpc; private static byte[] cancel() { return packetBuilder() @@ -70,9 +63,8 @@ public final class StreamObserverCallTest { @Before public void createCall() throws Exception { - rpc = PendingRpc.create(new Channel(CHANNEL_ID, mockOutput), SERVICE, METHOD); - streamObserverCall = StreamObserverCall.start(rpcManager, rpc, observer, null); - rpcManager.start(rpc, streamObserverCall, SomeMessage.getDefaultInstance()); + streamObserverCall = + endpoint.invokeRpc(CHANNEL_ID, METHOD, StreamObserverCall.getFactory(observer), null); } @Test @@ -100,9 +92,23 @@ public final class StreamObserverCallTest { } @Test + public void abandon_doesNotSendCancelPacket() throws Exception { + streamObserverCall.abandon(); + + verify(mockOutput, never()).send(cancel()); + } + + @Test + public void abandon_deactivates() { + streamObserverCall.abandon(); + + assertThat(streamObserverCall.active()).isFalse(); + } + + @Test public void send_sendsClientStreamPacket() throws Exception { SomeMessage request = SomeMessage.newBuilder().setMagicNumber(123).build(); - streamObserverCall.send(request); + streamObserverCall.write(request); verify(mockOutput) .send(packetBuilder() @@ -116,9 +122,7 @@ public final class StreamObserverCallTest { public void send_raisesExceptionIfClosed() throws Exception { streamObserverCall.cancel(); - RpcError thrown = assertThrows( - RpcError.class, () -> streamObserverCall.send(SomeMessage.getDefaultInstance())); - assertThat(thrown.status()).isSameInstanceAs(Status.CANCELLED); + assertThat(streamObserverCall.write(SomeMessage.getDefaultInstance())).isFalse(); } @Test @@ -131,29 +135,21 @@ public final class StreamObserverCallTest { @Test public void onNext_callsObserverIfActive() { - streamObserverCall.onNext(AnotherMessage.getDefaultInstance().toByteString()); + streamObserverCall.handleNext(AnotherMessage.getDefaultInstance().toByteString()); verify(observer).onNext(AnotherMessage.getDefaultInstance()); } @Test - public void onNext_ignoresIfNotActive() throws Exception { - streamObserverCall.cancel(); - streamObserverCall.onNext(AnotherMessage.getDefaultInstance().toByteString()); - - verify(observer, never()).onNext(any()); - } - - @Test public void callDispatcher_onCompleted_callsObserver() { - streamObserverCall.onCompleted(Status.ABORTED); + streamObserverCall.handleStreamCompleted(Status.ABORTED); verify(observer).onCompleted(Status.ABORTED); } @Test public void callDispatcher_onCompleted_setsActiveAndStatus() { - streamObserverCall.onCompleted(Status.ABORTED); + streamObserverCall.handleStreamCompleted(Status.ABORTED); verify(observer).onCompleted(Status.ABORTED); assertThat(streamObserverCall.active()).isFalse(); @@ -162,109 +158,17 @@ public final class StreamObserverCallTest { @Test public void callDispatcher_onError_callsObserver() { - streamObserverCall.onError(Status.NOT_FOUND); + streamObserverCall.handleError(Status.NOT_FOUND); verify(observer).onError(Status.NOT_FOUND); } @Test public void callDispatcher_onError_deactivates() { - streamObserverCall.onError(Status.ABORTED); + streamObserverCall.handleError(Status.ABORTED); verify(observer).onError(Status.ABORTED); assertThat(streamObserverCall.active()).isFalse(); assertThat(streamObserverCall.status()).isNull(); } - - @Test - public void unaryFuture_response_setsValue() throws Exception { - UnaryResponseFuture<SomeMessage, AnotherMessage> call = - new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance()); - - AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build(); - call.onNext(response); - assertThat(call.isDone()).isFalse(); - call.onCompleted(Status.CANCELLED); - - assertThat(call.isDone()).isTrue(); - assertThat(call.get()).isEqualTo(UnaryResult.create(response, Status.CANCELLED)); - } - - @Test - public void unaryFuture_serverError_setsException() throws Exception { - UnaryResponseFuture<SomeMessage, AnotherMessage> call = - new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance()); - - call.onError(Status.NOT_FOUND); - - assertThat(call.isDone()).isTrue(); - ExecutionException exception = assertThrows(ExecutionException.class, call::get); - assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); - - RpcError error = (RpcError) exception.getCause(); - assertThat(error).isNotNull(); - assertThat(error.rpc()).isEqualTo(rpc); - assertThat(error.status()).isEqualTo(Status.NOT_FOUND); - } - - @Test - public void unaryFuture_noMessage_setsException() throws Exception { - UnaryResponseFuture<SomeMessage, AnotherMessage> call = - new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance()); - - call.onCompleted(Status.OK); - - assertThat(call.isDone()).isTrue(); - ExecutionException exception = assertThrows(ExecutionException.class, call::get); - assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - } - - @Test - public void unaryFuture_multipleResponses_setsException() throws Exception { - UnaryResponseFuture<SomeMessage, AnotherMessage> call = - new UnaryResponseFuture<>(rpcManager, rpc, SomeMessage.getDefaultInstance()); - - AnotherMessage response = AnotherMessage.newBuilder().setResultValue(1138).build(); - call.onNext(response); - call.onNext(response); - call.onCompleted(Status.OK); - - assertThat(call.isDone()).isTrue(); - ExecutionException exception = assertThrows(ExecutionException.class, call::get); - assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - } - - @Test - public void bidirectionalStreamingfuture_responses_setsValue() throws Exception { - List<AnotherMessage> responses = new ArrayList<>(); - StreamResponseFuture<SomeMessage, AnotherMessage> call = - new StreamResponseFuture<>(rpcManager, rpc, responses::add, null); - - AnotherMessage message = AnotherMessage.newBuilder().setResultValue(1138).build(); - call.onNext(message); - call.onNext(message); - assertThat(call.isDone()).isFalse(); - call.onCompleted(Status.OK); - - assertThat(call.isDone()).isTrue(); - assertThat(call.get()).isEqualTo(Status.OK); - assertThat(responses).containsExactly(message, message); - } - - @Test - public void bidirectionalStreamingfuture_serverError_setsException() throws Exception { - StreamResponseFuture<SomeMessage, AnotherMessage> call = - new StreamResponseFuture<>(rpcManager, rpc, (msg) -> {}, null); - - call.onError(Status.NOT_FOUND); - - assertThat(call.isDone()).isTrue(); - ExecutionException exception = assertThrows(ExecutionException.class, call::get); - assertThat(exception).hasCauseThat().isInstanceOf(RpcError.class); - - RpcError error = (RpcError) exception.getCause(); - assertThat(error).isNotNull(); - assertThat(error.rpc()).isEqualTo(rpc); - assertThat(error.status()).isEqualTo(Status.NOT_FOUND); - } } diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java index 55fd88aa5..36102f447 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/StreamObserverMethodClientTest.java @@ -16,14 +16,18 @@ package dev.pigweed.pw_rpc; import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.mockito.Mockito.mock; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; import com.google.protobuf.MessageLite; +import dev.pigweed.pw_rpc.internal.Packet.PacketType; +import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -36,22 +40,24 @@ public final class StreamObserverMethodClientTest { Service.bidirectionalStreamingMethod( "SomeBidirectionalStreaming", SomeMessage.class, AnotherMessage.class)); - private static final Channel CHANNEL = new Channel(1, (bytes) -> {}); - - private static final PendingRpc UNARY_RPC = - PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeUnary")); - private static final PendingRpc SERVER_STREAMING_RPC = - PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeServerStreaming")); - private static final PendingRpc CLIENT_STREAMING_RPC = - PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeClientStreaming")); - private static final PendingRpc BIDIRECTIONAL_STREAMING_RPC = - PendingRpc.create(CHANNEL, SERVICE, SERVICE.method("SomeBidirectionalStreaming")); - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); @Mock private StreamObserver<MessageLite> defaultObserver; + @Mock private StreamObserver<AnotherMessage> observer; + @Mock private Channel.Output channelOutput; + + // Wrap Channel.Output since channelOutput will be null when the channel is initialized. + private final Channel channel = new Channel(1, bytes -> channelOutput.send(bytes)); - private final RpcManager rpcManager = new RpcManager(); + private final PendingRpc unary_rpc = PendingRpc.create(channel, SERVICE.method("SomeUnary")); + private final PendingRpc server_streaming_rpc = + PendingRpc.create(channel, SERVICE.method("SomeServerStreaming")); + private final PendingRpc client_streaming_rpc = + PendingRpc.create(channel, SERVICE.method("SomeClientStreaming")); + private final PendingRpc bidirectional_streaming_rpc = + PendingRpc.create(channel, SERVICE.method("SomeBidirectionalStreaming")); + + private final Client client = Client.create(ImmutableList.of(channel), ImmutableList.of(SERVICE)); private MethodClient unaryMethodClient; private MethodClient serverStreamingMethodClient; private MethodClient clientStreamingMethodClient; @@ -59,33 +65,30 @@ public final class StreamObserverMethodClientTest { @Before public void createMethodClient() { - unaryMethodClient = new MethodClient(rpcManager, UNARY_RPC, defaultObserver); + unaryMethodClient = new MethodClient(client, channel.id(), unary_rpc.method(), defaultObserver); serverStreamingMethodClient = - new MethodClient(rpcManager, SERVER_STREAMING_RPC, defaultObserver); + new MethodClient(client, channel.id(), server_streaming_rpc.method(), defaultObserver); clientStreamingMethodClient = - new MethodClient(rpcManager, CLIENT_STREAMING_RPC, defaultObserver); - bidirectionalStreamingMethodClient = - new MethodClient(rpcManager, BIDIRECTIONAL_STREAMING_RPC, defaultObserver); + new MethodClient(client, channel.id(), client_streaming_rpc.method(), defaultObserver); + bidirectionalStreamingMethodClient = new MethodClient( + client, channel.id(), bidirectional_streaming_rpc.method(), defaultObserver); } @Test public void invokeWithNoObserver_usesDefaultObserver() throws Exception { unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance()); AnotherMessage reply = AnotherMessage.newBuilder().setPayload("yo").build(); - rpcManager.getPending(UNARY_RPC).onNext(reply.toByteString()); + + assertThat(client.processPacket(responsePacket(unary_rpc, reply))).isTrue(); verify(defaultObserver).onNext(reply); } @Test public void invoke_usesProvidedObserver() throws Exception { - @SuppressWarnings("unchecked") - StreamObserver<AnotherMessage> observer = - (StreamObserver<AnotherMessage>) mock(StreamObserver.class); - unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance(), observer); AnotherMessage reply = AnotherMessage.newBuilder().setPayload("yo").build(); - rpcManager.getPending(UNARY_RPC).onNext(reply.toByteString()); + assertThat(client.processPacket(responsePacket(unary_rpc, reply))).isTrue(); verify(observer).onNext(reply); } @@ -93,75 +96,86 @@ public final class StreamObserverMethodClientTest { @Test public void invokeUnary_startsRpc() throws Exception { Call call = unaryMethodClient.invokeUnary(SomeMessage.getDefaultInstance()); - assertThat(rpcManager.getPending(UNARY_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void openUnary_startsRpc() { - Call call = unaryMethodClient.openUnary(SomeMessage.getDefaultInstance(), defaultObserver); - assertThat(rpcManager.getPending(UNARY_RPC)).isSameInstanceAs(call); + public void openUnary_startsRpc() throws Exception { + Call call = unaryMethodClient.openUnary(defaultObserver); + assertThat(call.active()).isTrue(); + verify(channelOutput, never()).send(any()); } @Test public void invokeServerStreaming_startsRpc() throws Exception { Call call = serverStreamingMethodClient.invokeServerStreaming(SomeMessage.getDefaultInstance()); - assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void openServerStreaming_startsRpc() { - Call call = serverStreamingMethodClient.openServerStreaming( - SomeMessage.getDefaultInstance(), defaultObserver); - assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isSameInstanceAs(call); + public void openServerStreaming_startsRpc() throws Exception { + Call call = serverStreamingMethodClient.openServerStreaming(defaultObserver); + assertThat(call.active()).isTrue(); + verify(channelOutput, never()).send(any()); } @Test public void invokeClientStreaming_startsRpc() throws Exception { Call call = clientStreamingMethodClient.invokeClientStreaming(); - assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void openClientStreaming_startsRpc() { + public void openClientStreaming_startsRpc() throws Exception { Call call = clientStreamingMethodClient.openClientStreaming(defaultObserver); - assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, never()).send(any()); } @Test public void invokeBidirectionalStreaming_startsRpc() throws Exception { Call call = bidirectionalStreamingMethodClient.invokeBidirectionalStreaming(); - assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void openBidirectionalStreaming_startsRpc() { + public void openBidirectionalStreaming_startsRpc() throws Exception { Call call = bidirectionalStreamingMethodClient.openBidirectionalStreaming(defaultObserver); - assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isSameInstanceAs(call); + assertThat(call.active()).isTrue(); + verify(channelOutput, never()).send(any()); } @Test - public void invokeUnaryFuture_startsRpc() { - unaryMethodClient.invokeUnaryFuture(SomeMessage.getDefaultInstance()); - assertThat(rpcManager.getPending(UNARY_RPC)).isNotNull(); + public void invokeUnaryFuture_startsRpc() throws Exception { + Call call = unaryMethodClient.invokeUnaryFuture(SomeMessage.getDefaultInstance()); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void invokeServerStreamingFuture_startsRpc() { - serverStreamingMethodClient.invokeServerStreamingFuture( + public void invokeServerStreamingFuture_startsRpc() throws Exception { + Call call = serverStreamingMethodClient.invokeServerStreamingFuture( SomeMessage.getDefaultInstance(), (msg) -> {}); - assertThat(rpcManager.getPending(SERVER_STREAMING_RPC)).isNotNull(); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void invokeClientStreamingFuture_startsRpc() { - clientStreamingMethodClient.invokeClientStreamingFuture(); - assertThat(rpcManager.getPending(CLIENT_STREAMING_RPC)).isNotNull(); + public void invokeClientStreamingFuture_startsRpc() throws Exception { + Call call = clientStreamingMethodClient.invokeClientStreamingFuture(); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test - public void invokeBidirectionalStreamingFuture_startsRpc() { - bidirectionalStreamingMethodClient.invokeBidirectionalStreamingFuture((msg) -> {}); - assertThat(rpcManager.getPending(BIDIRECTIONAL_STREAMING_RPC)).isNotNull(); + public void invokeBidirectionalStreamingFuture_startsRpc() throws Exception { + Call call = bidirectionalStreamingMethodClient.invokeBidirectionalStreamingFuture((msg) -> {}); + assertThat(call.active()).isTrue(); + verify(channelOutput, times(1)).send(any()); } @Test @@ -187,4 +201,50 @@ public final class StreamObserverMethodClientTest { assertThrows(UnsupportedOperationException.class, () -> clientStreamingMethodClient.invokeBidirectionalStreaming()); } + + @Test + public void invalidChannel_throwsException() { + MethodClient methodClient = + new MethodClient(client, 999, client_streaming_rpc.method(), defaultObserver); + assertThrows(InvalidRpcChannelException.class, methodClient::invokeClientStreaming); + } + + @Test + public void invalidService_throwsException() { + Service otherService = new Service("something.Else", + Service.clientStreamingMethod("ClientStream", SomeMessage.class, AnotherMessage.class)); + + MethodClient methodClient = new MethodClient( + client, channel.id(), otherService.method("ClientStream"), defaultObserver); + assertThrows(InvalidRpcServiceException.class, methodClient::invokeClientStreaming); + } + + @Test + public void invalidMethod_throwsException() { + Service serviceWithDifferentUnaryMethod = new Service("pw.rpc.test1.TheTestService", + Service.unaryMethod("SomeUnary", AnotherMessage.class, AnotherMessage.class), + Service.serverStreamingMethod( + "SomeServerStreaming", SomeMessage.class, AnotherMessage.class), + Service.clientStreamingMethod( + "SomeClientStreaming", SomeMessage.class, AnotherMessage.class), + Service.bidirectionalStreamingMethod( + "SomeBidirectionalStreaming", SomeMessage.class, AnotherMessage.class)); + + MethodClient methodClient = new MethodClient( + client, 999, serviceWithDifferentUnaryMethod.method("SomeUnary"), defaultObserver); + assertThrows(InvalidRpcServiceMethodException.class, + () -> methodClient.invokeUnary(AnotherMessage.getDefaultInstance())); + } + + private static byte[] responsePacket(PendingRpc rpc, MessageLite payload) { + return RpcPacket.newBuilder() + .setChannelId(1) + .setServiceId(rpc.service().id()) + .setMethodId(rpc.method().id()) + .setType(PacketType.RESPONSE) + .setStatus(Status.OK.code()) + .setPayload(payload.toByteString()) + .build() + .toByteArray(); + } } diff --git a/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java b/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java index 6688e2422..a58d51901 100644 --- a/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java +++ b/pw_rpc/java/test/dev/pigweed/pw_rpc/TestClient.java @@ -14,17 +14,19 @@ package dev.pigweed.pw_rpc; -import static java.util.Arrays.stream; - import com.google.common.collect.ImmutableList; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.MessageLite; +import com.google.protobuf.MessageLiteOrBuilder; import dev.pigweed.pw_rpc.internal.Packet.PacketType; import dev.pigweed.pw_rpc.internal.Packet.RpcPacket; +import java.util.ArrayDeque; import java.util.ArrayList; +import java.util.Arrays; import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.Queue; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -38,29 +40,35 @@ public class TestClient { private final Client client; private final List<RpcPacket> sentPackets = new ArrayList<>(); - private final List<RpcPacket> enqueuedPackets = new ArrayList<>(); - private int receiveEnqueuedPacketsAfter = 1; - final Map<PacketType, Integer> sentPayloadIndices = new EnumMap<>(PacketType.class); + private final Queue<EnqueuedPackets> enqueuedPackets = new ArrayDeque<>(); + private final Map<PacketType, Integer> sentPayloadIndices = new EnumMap<>(PacketType.class); + + @Nullable private ChannelOutputException channelOutputException = null; + + private static class EnqueuedPackets { + private int processAfterSentPackets; + private final List<RpcPacket> packets; - @Nullable ChannelOutputException channelOutputException = null; + private EnqueuedPackets(int processAfterSentPackets, List<RpcPacket> packets) { + this.processAfterSentPackets = processAfterSentPackets; + this.packets = packets; + } + + private boolean shouldProcessEnqueuedPackets() { + return processAfterSentPackets-- <= 1; + } + } public TestClient(List<Service> services) { Channel.Output channelOutput = packet -> { - if (channelOutputException == null) { - sentPackets.add(parsePacket(packet)); - } else { + if (channelOutputException != null) { throw channelOutputException; } + sentPackets.add(parsePacket(packet)); - // Process any enqueued packets. - if (receiveEnqueuedPacketsAfter > 1) { - receiveEnqueuedPacketsAfter -= 1; - return; - } - if (!enqueuedPackets.isEmpty()) { - List<RpcPacket> packetsToProcess = new ArrayList<>(enqueuedPackets); - enqueuedPackets.clear(); - packetsToProcess.forEach(this::processPacket); + if (!enqueuedPackets.isEmpty() && enqueuedPackets.peek().shouldProcessEnqueuedPackets()) { + // Process any enqueued packets. + enqueuedPackets.remove().packets.forEach(this::processPacket); } }; client = Client.create(ImmutableList.of(new Channel(CHANNEL_ID, channelOutput)), services); @@ -86,52 +94,34 @@ public class TestClient { } /** Simulates receiving SERVER_STREAM packets from the server. */ - public void receiveServerStream(String service, String method, MessageLite... payloads) { + public void receiveServerStream(String service, String method, MessageLiteOrBuilder... payloads) { RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build(); - for (MessageLite payload : payloads) { - processPacket(RpcPacket.newBuilder(base).setPayload(payload.toByteString())); + for (MessageLiteOrBuilder payload : payloads) { + processPacket(RpcPacket.newBuilder(base).setPayload(getMessage(payload).toByteString())); } } - public void receiveServerStream(String service, String method, MessageLite.Builder... builders) { - receiveServerStream(service, - method, - stream(builders).map(MessageLite.Builder::build).toArray(MessageLite[] ::new)); - } - /** * Enqueues a SERVER_STREAM packet so that the client receives it after a packet is sent. * + * This function may be called multiple times to create a queue of packets to process as different + * packets are sent. + * * @param afterPackets Wait until this many packets have been sent before the client receives - * these stream packets. The minimum value (and the default) is 1. + * these stream packets. The minimum value is 1. If multiple stream packets are queued, + * afterPackets is counted from the packet before it in the queue. */ public void enqueueServerStream( - String service, String method, int afterPackets, MessageLite... payloads) { + String service, String method, int afterPackets, MessageLiteOrBuilder... payloads) { if (afterPackets < 1) { throw new IllegalArgumentException("afterPackets must be at least 1"); } - if (afterPackets != 1 && receiveEnqueuedPacketsAfter != 1) { - throw new AssertionError( - "May only set afterPackets once before enqueued packets are processed"); - } - receiveEnqueuedPacketsAfter = afterPackets; RpcPacket base = startPacket(service, method, PacketType.SERVER_STREAM).build(); - for (MessageLite payload : payloads) { - enqueuedPackets.add(RpcPacket.newBuilder(base).setPayload(payload.toByteString()).build()); - } - } - - public void enqueueServerStream(String service, String method, MessageLite... payloads) { - enqueueServerStream(service, method, 1, payloads); - } - - public void enqueueServerStream( - String service, String method, int afterPackets, MessageLite.Builder... builders) { - enqueueServerStream(service, - method, - afterPackets, - stream(builders).map(MessageLite.Builder::build).toArray(MessageLite[] ::new)); + enqueuedPackets.add(new EnqueuedPackets(afterPackets, + Arrays.stream(payloads) + .map(m -> RpcPacket.newBuilder(base).setPayload(getMessage(m).toByteString()).build()) + .collect(Collectors.toList()))); } /** Simulates receiving a SERVER_ERROR packet from the server. */ @@ -192,4 +182,14 @@ public class TestClient { throw new AssertionError("Decoding sent packet payload failed", e); } } + + private MessageLite getMessage(MessageLiteOrBuilder messageOrBuilder) { + if (messageOrBuilder instanceof MessageLite.Builder) { + return ((MessageLite.Builder) messageOrBuilder).build(); + } + if (messageOrBuilder instanceof MessageLite) { + return (MessageLite) messageOrBuilder; + } + throw new AssertionError("Unexpected MessageLiteOrBuilder class"); + } } |