aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java117
1 files changed, 80 insertions, 37 deletions
diff --git a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
index 810f1fc90..974a13204 100644
--- a/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
+++ b/pw_rpc/java/main/dev/pigweed/pw_rpc/Endpoint.java
@@ -21,6 +21,8 @@ import dev.pigweed.pw_rpc.internal.Packet.RpcPacket;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -30,12 +32,19 @@ import javax.annotation.Nullable;
*
* The RPC endpoint handles all RPC-related events and actions. It synchronizes interactions between
* the endpoint and any threads interacting with RPC call objects.
+ *
+ * The Endpoint's intrinsic lock is held when updating the channels or pending calls lists. Call
+ * objects only make updates to their own state through function calls made from the Endpoint, which
+ * ensures their states are also guarded by the Endpoint's lock. Updates to call objects are
+ * enqueued while the lock is held and processed after releasing the lock. This ensures updates
+ * occur in order without needing to hold the Endpoint's lock while possibly executing user code.
*/
class Endpoint {
private static final Logger logger = Logger.forClass(Endpoint.class);
private final Map<Integer, Channel> channels;
private final Map<PendingRpc, AbstractCall<?, ?>> pending = new HashMap<>();
+ private final BlockingQueue<Runnable> callUpdates = new LinkedBlockingQueue<>();
public Endpoint(List<Channel> channels) {
this.channels = channels.stream().collect(Collectors.toMap(Channel::id, c -> c));
@@ -102,24 +111,52 @@ class Endpoint {
pending.put(call.rpc(), call);
}
+ /** Enqueues call object updates to make after release the Endpoint's lock. */
+ private void enqueueCallUpdate(Runnable callUpdate) {
+ while (!callUpdates.add(callUpdate)) {
+ // Retry until added successfully
+ }
+ }
+
+ /** Processes all enqueued call updates; the lock must NOT be held when this is called. */
+ private void processCallUpdates() {
+ while (true) {
+ Runnable callUpdate = callUpdates.poll();
+ if (callUpdate == null) {
+ break;
+ }
+ callUpdate.run();
+ }
+ }
+
/** Cancels an ongoing RPC */
- public synchronized boolean cancel(AbstractCall<?, ?> call) throws ChannelOutputException {
- if (pending.remove(call.rpc()) == null) {
- return false;
+ public boolean cancel(AbstractCall<?, ?> call) throws ChannelOutputException {
+ try {
+ synchronized (this) {
+ if (pending.remove(call.rpc()) == null) {
+ return false;
+ }
+
+ enqueueCallUpdate(() -> call.handleError(Status.CANCELLED));
+ call.sendPacket(Packets.cancel(call.rpc()));
+ }
+ } finally {
+ logger.atFiner().log("Cancelling %s", call);
+ processCallUpdates();
}
- logger.atFiner().log("Cancelling %s", call);
- call.handleError(Status.CANCELLED);
- call.sendPacket(Packets.cancel(call.rpc()));
return true;
}
/** Cancels an ongoing RPC without sending a cancellation packet. */
- public synchronized boolean abandon(AbstractCall<?, ?> call) {
- if (pending.remove(call.rpc()) == null) {
- return false;
+ public boolean abandon(AbstractCall<?, ?> call) {
+ synchronized (this) {
+ if (pending.remove(call.rpc()) == null) {
+ return false;
+ }
+ enqueueCallUpdate(() -> call.handleError(Status.CANCELLED));
}
logger.atFiner().log("Abandoning %s", call);
- call.handleError(Status.CANCELLED);
+ processCallUpdates();
return true;
}
@@ -148,14 +185,16 @@ class Endpoint {
}
}
- public synchronized boolean closeChannel(int id) {
- if (channels.remove(id) == null) {
- return false;
+ public boolean closeChannel(int id) {
+ synchronized (this) {
+ if (channels.remove(id) == null) {
+ return false;
+ }
+ pending.values().stream().filter(call -> call.getChannelId() == id).forEach(call -> {
+ enqueueCallUpdate(() -> call.handleError(Status.ABORTED));
+ });
}
- pending.values()
- .stream()
- .filter(call -> call.getChannelId() == id)
- .forEach(call -> call.handleError(Status.ABORTED));
+ processCallUpdates();
return true;
}
@@ -164,8 +203,8 @@ class Endpoint {
if (call == null) {
return false;
}
- call.handleNext(payload);
logger.atFiner().log("%s received server stream with %d B payload", call, payload.size());
+ enqueueCallUpdate(() -> call.handleNext(payload));
return true;
}
@@ -174,9 +213,9 @@ class Endpoint {
if (call == null) {
return false;
}
- call.handleUnaryCompleted(payload, status);
logger.atFiner().log(
"%s completed with status %s and %d B payload", call, status, payload.size());
+ enqueueCallUpdate(() -> call.handleUnaryCompleted(payload, status));
return true;
}
@@ -185,8 +224,8 @@ class Endpoint {
if (call == null) {
return false;
}
- call.handleStreamCompleted(status);
logger.atFiner().log("%s completed with status %s", call, status);
+ enqueueCallUpdate(() -> call.handleStreamCompleted(status));
return true;
}
@@ -195,31 +234,35 @@ class Endpoint {
if (call == null) {
return false;
}
- call.handleError(status);
logger.atFiner().log("%s failed with error %s", call, status);
+ enqueueCallUpdate(() -> call.handleError(status));
return true;
}
- public synchronized boolean processClientPacket(@Nullable Method method, RpcPacket packet) {
- Channel channel = channels.get(packet.getChannelId());
- if (channel == null) {
- logger.atWarning().log("Received packet for unrecognized channel %d", packet.getChannelId());
- return false;
- }
+ public boolean processClientPacket(@Nullable Method method, RpcPacket packet) {
+ synchronized (this) {
+ Channel channel = channels.get(packet.getChannelId());
+ if (channel == null) {
+ logger.atWarning().log(
+ "Received packet for unrecognized channel %d", packet.getChannelId());
+ return false;
+ }
- if (method == null) {
- logger.atFine().log("Ignoring packet for unknown service method");
- sendError(channel, packet, Status.NOT_FOUND);
- return true; // true since the packet was handled, even though it was invalid.
- }
+ if (method == null) {
+ logger.atFine().log("Ignoring packet for unknown service method");
+ sendError(channel, packet, Status.NOT_FOUND);
+ return true; // true since the packet was handled, even though it was invalid.
+ }
- PendingRpc rpc = PendingRpc.create(channel, method);
- if (!updateCall(packet, rpc)) {
- logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc);
- sendError(channel, packet, Status.FAILED_PRECONDITION);
- return true;
+ PendingRpc rpc = PendingRpc.create(channel, method);
+ if (!updateCall(packet, rpc)) {
+ logger.atFine().log("Ignoring packet for %s, which isn't pending", rpc);
+ sendError(channel, packet, Status.FAILED_PRECONDITION);
+ return true;
+ }
}
+ processCallUpdates();
return true;
}