aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLarry Safran <lsafran@google.com>2023-11-09 13:46:52 -0800
committerGitHub <noreply@github.com>2023-11-09 13:46:52 -0800
commitdfdd50bc7905aa904d19be38f7d300b98c60cf6e (patch)
tree3ba4e7fa0b84e5975015af98dddaac0c6cf48bd4
parent0346b40e4e19e081e082288fc8af49169dd848ad (diff)
downloadgrpc-grpc-java-dfdd50bc7905aa904d19be38f7d300b98c60cf6e.tar.gz
xds:Make Ring Hash LB a petiole policy (#10610)
* Update picker logic per A61 that it no longer pays attention to the first 2 elements, but rather takes the first ring element not in TF and uses that. --------- Pulled in by rebase: Eric Anderson (android: Remove unneeded proguard rule 44723b6) Terry Wilson (stub: Deprecate StreamObservers b5434e8)
-rw-r--r--core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java1
-rw-r--r--util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java4
-rw-r--r--util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java178
-rw-r--r--util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java9
-rw-r--r--util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java14
-rw-r--r--util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java10
-rw-r--r--xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java25
-rw-r--r--xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java4
-rw-r--r--xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java610
-rw-r--r--xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java34
-rw-r--r--xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java3
-rw-r--r--xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java853
-rw-r--r--xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java28
13 files changed, 902 insertions, 871 deletions
diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java
index 13cbeed1d..acef79d3d 100644
--- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java
+++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java
@@ -102,6 +102,7 @@ final class PickFirstLoadBalancer extends LoadBalancer {
subchannel.shutdown();
subchannel = null;
}
+
// NB(lukaszx0) Whether we should propagate the error unconditionally is arguable. It's fine
// for time being.
updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error)));
diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java
index eab77d7cd..a07428a30 100644
--- a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java
+++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java
@@ -181,4 +181,8 @@ public final class GracefulSwitchLoadBalancer extends ForwardingLoadBalancer {
pendingLb.shutdown();
currentLb.shutdown();
}
+
+ public String delegateType() {
+ return delegate().getClass().getSimpleName();
+ }
}
diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
index a196b3859..306901250 100644
--- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
+++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java
@@ -26,6 +26,7 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
@@ -57,11 +58,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
private final Map<Object, ChildLbState> childLbStates = new LinkedHashMap<>();
private final Helper helper;
// Set to true if currently in the process of handling resolved addresses.
- @VisibleForTesting
protected boolean resolvingAddresses;
- protected final PickFirstLoadBalancerProvider pickFirstLbProvider =
- new PickFirstLoadBalancerProvider();
+ protected final LoadBalancerProvider pickFirstLbProvider = new PickFirstLoadBalancerProvider();
protected ConnectivityState currentConnectivityState;
@@ -85,6 +84,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* Generally, the only reason to override this is to expose it to a test of a LB in a different
* package.
*/
+ protected ImmutableMap<Object, ChildLbState> getImmutableChildMap() {
+ return ImmutableMap.copyOf(childLbStates);
+ }
+
@VisibleForTesting
protected Collection<ChildLbState> getChildLbStates() {
return childLbStates.values();
@@ -93,8 +96,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
/**
* Generally, the only reason to override this is to expose it to a test of a LB in a
* different package.
- */
-
+ */
protected ChildLbState getChildLbState(Object key) {
if (key == null) {
return null;
@@ -125,7 +127,8 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
if (existingChildLbState != null) {
childLbMap.put(endpoint, existingChildLbState);
} else {
- childLbMap.put(endpoint, createChildLbState(endpoint, null, getInitialPicker()));
+ childLbMap.put(endpoint,
+ createChildLbState(endpoint, null, getInitialPicker(), resolvedAddresses));
}
}
return childLbMap;
@@ -135,7 +138,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* Override to create an instance of a subclass.
*/
protected ChildLbState createChildLbState(Object key, Object policyConfig,
- SubchannelPicker initialPicker) {
+ SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) {
return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
@@ -146,7 +149,20 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
try {
resolvingAddresses = true;
- return acceptResolvedAddressesInternal(resolvedAddresses);
+
+ // process resolvedAddresses to update children
+ AcceptResolvedAddressRetVal acceptRetVal =
+ acceptResolvedAddressesInternal(resolvedAddresses);
+ if (!acceptRetVal.status.isOk()) {
+ return acceptRetVal.status;
+ }
+
+ // Update the picker and our connectivity state
+ updateOverallBalancingState();
+
+ // shutdown removed children
+ shutdownRemoved(acceptRetVal.removedChildren);
+ return acceptRetVal.status;
} finally {
resolvingAddresses = false;
}
@@ -161,15 +177,18 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
*/
protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses,
Object childConfig) {
+ Endpoint endpointKey;
if (key instanceof EquivalentAddressGroup) {
- key = new Endpoint((EquivalentAddressGroup) key);
+ endpointKey = new Endpoint((EquivalentAddressGroup) key);
+ } else {
+ checkArgument(key instanceof Endpoint, "key is wrong type");
+ endpointKey = (Endpoint) key;
}
- checkArgument(key instanceof Endpoint, "key is wrong type");
// Retrieve the non-stripped version
EquivalentAddressGroup eagToUse = null;
for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) {
- if (key.equals(new Endpoint(currEag))) {
+ if (endpointKey.equals(new Endpoint(currEag))) {
eagToUse = currEag;
break;
}
@@ -183,15 +202,21 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
.build();
}
- private Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) {
+ /**
+ * This does the work to update the child map and calculate which children have been removed.
+ * You must call {@link #updateOverallBalancingState} to update the picker
+ * and call {@link #shutdownRemoved(List)} to shutdown the endpoints that have been removed.
+ */
+ protected AcceptResolvedAddressRetVal acceptResolvedAddressesInternal(
+ ResolvedAddresses resolvedAddresses) {
logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses);
Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses);
if (newChildren.isEmpty()) {
Status unavailableStatus = Status.UNAVAILABLE.withDescription(
- "NameResolver returned no usable address. " + resolvedAddresses);
+ "NameResolver returned no usable address. " + resolvedAddresses);
handleNameResolutionError(unavailableStatus);
- return unavailableStatus;
+ return new AcceptResolvedAddressRetVal(unavailableStatus, null);
}
// Do adds and updates
@@ -204,33 +229,44 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
} else {
// Reuse the existing one
ChildLbState existingChildLbState = childLbStates.get(key);
- if (existingChildLbState.isDeactivated()) {
+ if (existingChildLbState.isDeactivated() && reactivateChildOnReuse()) {
existingChildLbState.reactivate(childPolicyProvider);
}
}
- LoadBalancer childLb = childLbStates.get(key).lb;
+ ChildLbState childLbState = childLbStates.get(key);
ResolvedAddresses childAddresses = getChildAddresses(key, resolvedAddresses, childConfig);
- childLbStates.get(key).setResolvedAddresses(childAddresses); // update child state
- childLb.handleResolvedAddresses(childAddresses); // update child LB
+ childLbStates.get(key).setResolvedAddresses(childAddresses); // update child
+ if (!childLbState.deactivated) {
+ childLbState.lb.handleResolvedAddresses(childAddresses); // update child LB
+ }
}
+ List<ChildLbState> removedChildren = new ArrayList<>();
// Do removals
for (Object key : ImmutableList.copyOf(childLbStates.keySet())) {
if (!newChildren.containsKey(key)) {
- childLbStates.get(key).deactivate();
+ ChildLbState childLbState = childLbStates.get(key);
+ childLbState.deactivate();
+ removedChildren.add(childLbState);
}
}
- // Must update channel picker before return so that new RPCs will not be routed to deleted
- // clusters and resolver can remove them in service config.
- updateOverallBalancingState();
- return Status.OK;
+
+ return new AcceptResolvedAddressRetVal(Status.OK, removedChildren);
+ }
+
+ protected void shutdownRemoved(List<ChildLbState> removedChildren) {
+ // Do shutdowns after updating picker to reduce the chance of failing an RPC by picking a
+ // subchannel that has been shutdown.
+ for (ChildLbState childLbState : removedChildren) {
+ childLbState.shutdown();
+ }
}
@Override
public void handleNameResolutionError(Status error) {
if (currentConnectivityState != READY) {
- updateHelperBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
+ helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error));
}
}
@@ -240,12 +276,22 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
/**
* If true, then when a subchannel state changes to idle, the corresponding child will
- * have requestConnection called on its LB.
+ * have requestConnection called on its LB. Also causes the PickFirstLB to be created when
+ * the child is created or reused.
*/
protected boolean reconnectOnIdle() {
return true;
}
+ /**
+ * If true, then when {@link #acceptResolvedAddresses} sees a key that was already part of the
+ * child map which is deactivated, it will call reactivate on the child.
+ * If false, it will leave it deactivated.
+ */
+ protected boolean reactivateChildOnReuse() {
+ return true;
+ }
+
@Override
public void shutdown() {
logger.log(Level.INFO, "Shutdown");
@@ -265,17 +311,13 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
childPickers.put(childLbState.key, childLbState.currentPicker);
overallState = aggregateState(overallState, childLbState.currentState);
}
+
if (overallState != null) {
helper.updateBalancingState(overallState, getSubchannelPicker(childPickers));
currentConnectivityState = overallState;
}
}
- protected final void updateHelperBalancingState(ConnectivityState newState,
- SubchannelPicker newPicker) {
- helper.updateBalancingState(newState, newPicker);
- }
-
@Nullable
protected static ConnectivityState aggregateState(
@Nullable ConnectivityState overallState, ConnectivityState childState) {
@@ -332,20 +374,31 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
private final Object key;
private ResolvedAddresses resolvedAddresses;
private final Object config;
+
private final GracefulSwitchLoadBalancer lb;
- private LoadBalancerProvider policyProvider;
- private ConnectivityState currentState = CONNECTING;
+ private final LoadBalancerProvider policyProvider;
+ private ConnectivityState currentState;
private SubchannelPicker currentPicker;
private boolean deactivated;
public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
SubchannelPicker initialPicker) {
+ this(key, policyProvider, childConfig, initialPicker, null, false);
+ }
+
+ public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig,
+ SubchannelPicker initialPicker, ResolvedAddresses resolvedAddrs, boolean deactivated) {
this.key = key;
this.policyProvider = policyProvider;
- lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
- lb.switchTo(policyProvider);
- currentPicker = initialPicker;
- config = childConfig;
+ this.deactivated = deactivated;
+ this.currentPicker = initialPicker;
+ this.config = childConfig;
+ this.lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper());
+ this.currentState = deactivated ? IDLE : CONNECTING;
+ this.resolvedAddresses = resolvedAddrs;
+ if (!deactivated) {
+ lb.switchTo(policyProvider);
+ }
}
@Override
@@ -365,6 +418,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return config;
}
+ protected GracefulSwitchLoadBalancer getLb() {
+ return lb;
+ }
+
public LoadBalancerProvider getPolicyProvider() {
return policyProvider;
}
@@ -399,34 +456,41 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
deactivated = true;
}
+ protected void markReactivated() {
+ deactivated = false;
+ }
+
protected void setResolvedAddresses(ResolvedAddresses newAddresses) {
checkNotNull(newAddresses, "Missing address list for child");
resolvedAddresses = newAddresses;
}
+ /**
+ * The default implementation. This not only marks the lb policy as not active, it also removes
+ * this child from the map of children maintained by the petiole policy.
+ *
+ * <p>Note that this does not explicitly shutdown this child. That will generally be done by
+ * acceptResolvedAddresses on the LB, but can also be handled by an override such as is done
+ * in <a href=" https://github.com/grpc/grpc-java/blob/master/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java">ClusterManagerLoadBalancer</a>.
+ *
+ * <p>If you plan to reactivate, you will probably want to override this to not call
+ * childLbStates.remove() and handle that cleanup another way.
+ */
protected void deactivate() {
if (deactivated) {
return;
}
- shutdown();
- childLbStates.remove(key);
+ childLbStates.remove(key); // This means it can't be reactivated again
deactivated = true;
logger.log(Level.FINE, "Child balancer {0} deactivated", key);
}
+ /**
+ * This base implementation does nothing but reset the flag. If you really want to both
+ * deactivate and reactivate you should override them both.
+ */
protected void reactivate(LoadBalancerProvider policyProvider) {
- if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) {
- Object[] objects = {
- key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()};
- logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects);
- lb.switchTo(policyProvider);
- this.policyProvider = policyProvider;
- } else {
- logger.log(Level.FINE, "Child balancer {0} reactivated", key);
- lb.acceptResolvedAddresses(resolvedAddresses);
- }
-
deactivated = false;
}
@@ -443,6 +507,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
* <p>The ChildLbState updates happen during updateBalancingState. Otherwise, it is doing
* simple forwarding.
*/
+ protected ResolvedAddresses getResolvedAddresses() {
+ return resolvedAddresses;
+ }
+
private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper {
@Override
@@ -482,7 +550,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
final String[] addrs;
final int hashCode;
- Endpoint(EquivalentAddressGroup eag) {
+ public Endpoint(EquivalentAddressGroup eag) {
checkNotNull(eag, "eag");
addrs = new String[eag.getAddresses().size()];
@@ -525,4 +593,14 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer {
return Arrays.toString(addrs);
}
}
+
+ protected static class AcceptResolvedAddressRetVal {
+ public final Status status;
+ public final List<ChildLbState> removedChildren;
+
+ public AcceptResolvedAddressRetVal(Status status, List<ChildLbState> removedChildren) {
+ this.status = status;
+ this.removedChildren = removedChildren;
+ }
+ }
}
diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
index 3bd83ffe1..f4db72849 100644
--- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
+++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java
@@ -16,6 +16,7 @@
package io.grpc.util;
+import static com.google.common.base.Preconditions.checkArgument;
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
@@ -25,9 +26,7 @@ import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
-import io.grpc.Attributes;
import io.grpc.ConnectivityState;
-import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.Internal;
import io.grpc.LoadBalancer;
@@ -48,10 +47,6 @@ import javax.annotation.Nonnull;
*/
@Internal
public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
- @VisibleForTesting
- static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
- Attributes.Key.create("state-info");
-
private final Random random;
protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK);
@@ -132,7 +127,7 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer {
private volatile int index;
public ReadyPicker(List<SubchannelPicker> list, int startIndex) {
- Preconditions.checkArgument(!list.isEmpty(), "empty list");
+ checkArgument(!list.isEmpty(), "empty list");
this.subchannelPickers = list;
this.index = startIndex - 1;
}
diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
index 7513e0da4..465bea907 100644
--- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
+++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java
@@ -74,6 +74,7 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -353,6 +354,19 @@ public class RoundRobinLoadBalancerTest {
}
@Test
+ public void removingAddressShutsdownSubchannel() {
+ acceptAddresses(servers, affinity);
+ final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2)));
+
+ InOrder inOrder = Mockito.inOrder(mockHelper, subchannel2);
+ // send LB only the first 2 addresses
+ List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1));
+ acceptAddresses(svs2, affinity);
+ inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any());
+ inOrder.verify(subchannel2).shutdown();
+ }
+
+ @Test
public void pickerRoundRobin() throws Exception {
Subchannel subchannel = mock(Subchannel.class);
Subchannel subchannel1 = mock(Subchannel.class);
diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java
index 2afb13387..2b17558c9 100644
--- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java
+++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java
@@ -52,6 +52,7 @@ import java.util.Map;
public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>();
+ protected final Map<Subchannel, Subchannel> realToMockSubChannelMap = new HashMap<>();
private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
Maps.newLinkedHashMap();
@@ -99,15 +100,20 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
public Subchannel createSubchannel(CreateSubchannelArgs args) {
Subchannel subchannel = getSubchannelMap().get(args.getAddresses());
if (subchannel == null) {
- TestSubchannel delegate = new TestSubchannel(args);
+ TestSubchannel delegate = createRealSubchannel(args);
subchannel = mock(Subchannel.class, delegatesTo(delegate));
getSubchannelMap().put(args.getAddresses(), subchannel);
getMockToRealSubChannelMap().put(subchannel, delegate);
+ realToMockSubChannelMap.put(delegate, subchannel);
}
return subchannel;
}
+ protected TestSubchannel createRealSubchannel(CreateSubchannelArgs args) {
+ return new TestSubchannel(args);
+ }
+
@Override
public void refreshNameResolution() {
// no-op
@@ -122,7 +128,7 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper {
return "Test Helper";
}
- private class TestSubchannel extends ForwardingSubchannel {
+ protected class TestSubchannel extends ForwardingSubchannel {
CreateSubchannelArgs args;
Channel channel;
diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
index 62fab0d12..a1dd479af 100644
--- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java
@@ -93,6 +93,31 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer {
return newChildPolicies;
}
+ /**
+ * This is like the parent except that it doesn't shutdown the removed children since we want that
+ * to be done by the timer.
+ */
+ @Override
+ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
+ try {
+ resolvingAddresses = true;
+
+ // process resolvedAddresses to update children
+ AcceptResolvedAddressRetVal acceptRetVal =
+ acceptResolvedAddressesInternal(resolvedAddresses);
+ if (!acceptRetVal.status.isOk()) {
+ return acceptRetVal.status;
+ }
+
+ // Update the picker
+ updateOverallBalancingState();
+
+ return acceptRetVal.status;
+ } finally {
+ resolvingAddresses = false;
+ }
+ }
+
@Override
protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
return new SubchannelPicker() {
diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java
index fe52238e4..127acd9b3 100644
--- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java
@@ -145,13 +145,13 @@ final class LeastRequestLoadBalancer extends MultiChildLoadBalancer {
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
- SubchannelPicker initialPicker) {
+ SubchannelPicker initialPicker, ResolvedAddresses unused) {
return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker);
}
private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) {
if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) {
- super.updateHelperBalancingState(state, picker);
+ getHelper().updateBalancingState(state, picker);
currentConnectivityState = state;
currentPicker = picker;
}
diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java
index 901c79c06..54461385f 100644
--- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java
@@ -27,27 +27,28 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import com.google.common.base.MoreObjects;
import com.google.common.collect.HashMultiset;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Multiset;
-import com.google.common.collect.Sets;
import com.google.common.primitives.UnsignedInteger;
import io.grpc.Attributes;
import io.grpc.ConnectivityState;
-import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InternalLogId;
import io.grpc.LoadBalancer;
+import io.grpc.LoadBalancerProvider;
import io.grpc.Status;
import io.grpc.SynchronizationContext;
+import io.grpc.util.GracefulSwitchLoadBalancer;
+import io.grpc.util.MultiChildLoadBalancer;
import io.grpc.xds.XdsLogger.XdsLogLevel;
import java.net.SocketAddress;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Random;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
@@ -60,9 +61,7 @@ import javax.annotation.Nullable;
* number of times proportional to its weight. With the ring partitioned appropriately, the
* addition or removal of one host from a set of N hosts will affect only 1/N requests.
*/
-final class RingHashLoadBalancer extends LoadBalancer {
- private static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO =
- Attributes.Key.create("state-info");
+final class RingHashLoadBalancer extends MultiChildLoadBalancer {
private static final Status RPC_HASH_NOT_FOUND =
Status.INTERNAL.withDescription("RPC hash not found. Probably a bug because xds resolver"
+ " config selector always generates a hash.");
@@ -70,16 +69,10 @@ final class RingHashLoadBalancer extends LoadBalancer {
private final XdsLogger logger;
private final SynchronizationContext syncContext;
- private final Map<EquivalentAddressGroup, Subchannel> subchannels = new HashMap<>();
- private final Helper helper;
-
private List<RingEntry> ring;
- private ConnectivityState currentState;
- private Iterator<Subchannel> connectionAttemptIterator = subchannels.values().iterator();
- private final Random random = new Random();
RingHashLoadBalancer(Helper helper) {
- this.helper = checkNotNull(helper, "helper");
+ super(helper);
syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority()));
logger.log(XdsLogLevel.INFO, "Created");
@@ -94,81 +87,157 @@ final class RingHashLoadBalancer extends LoadBalancer {
return addressValidityStatus;
}
- Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(addrList);
- Set<EquivalentAddressGroup> removedAddrs =
- Sets.newHashSet(Sets.difference(subchannels.keySet(), latestAddrs.keySet()));
-
- RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
- Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>();
- long totalWeight = 0L;
- for (EquivalentAddressGroup eag : addrList) {
- Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
- // Support two ways of server weighing: either multiple instances of the same address
- // or each address contains a per-address weight attribute. If a weight is not provided,
- // each occurrence of the address will be counted a weight value of one.
- if (weight == null) {
- weight = 1L;
- }
- totalWeight += weight;
- EquivalentAddressGroup addrKey = stripAttrs(eag);
- if (serverWeights.containsKey(addrKey)) {
- serverWeights.put(addrKey, serverWeights.get(addrKey) + weight);
- } else {
- serverWeights.put(addrKey, weight);
+ AcceptResolvedAddressRetVal acceptRetVal;
+ try {
+ resolvingAddresses = true;
+ // Update the child list by creating-adding, updating addresses, and removing
+ acceptRetVal = super.acceptResolvedAddressesInternal(resolvedAddresses);
+ if (!acceptRetVal.status.isOk()) {
+ addressValidityStatus = Status.UNAVAILABLE.withDescription(
+ "Ring hash lb error: EDS resolution was successful, but was not accepted by base class"
+ + " (" + acceptRetVal.status + ")");
+ handleNameResolutionError(addressValidityStatus);
+ return addressValidityStatus;
}
- Subchannel existingSubchannel = subchannels.get(addrKey);
- if (existingSubchannel != null) {
- existingSubchannel.updateAddresses(Collections.singletonList(eag));
- continue;
+ // Now do the ringhash specific logic with weights and building the ring
+ RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
+ if (config == null) {
+ throw new IllegalArgumentException("Missing RingHash configuration");
}
- Attributes attr = Attributes.newBuilder().set(
- STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE))).build();
- final Subchannel subchannel = helper.createSubchannel(
- CreateSubchannelArgs.newBuilder().setAddresses(eag).setAttributes(attr).build());
- subchannel.start(new SubchannelStateListener() {
- @Override
- public void onSubchannelState(ConnectivityStateInfo newState) {
- processSubchannelState(subchannel, newState);
+ Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>();
+ long totalWeight = 0L;
+ for (EquivalentAddressGroup eag : addrList) {
+ Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT);
+ // Support two ways of server weighing: either multiple instances of the same address
+ // or each address contains a per-address weight attribute. If a weight is not provided,
+ // each occurrence of the address will be counted a weight value of one.
+ if (weight == null) {
+ weight = 1L;
+ }
+ totalWeight += weight;
+ EquivalentAddressGroup addrKey = stripAttrs(eag);
+ if (serverWeights.containsKey(addrKey)) {
+ serverWeights.put(addrKey, serverWeights.get(addrKey) + weight);
+ } else {
+ serverWeights.put(addrKey, weight);
}
- });
- subchannels.put(addrKey, subchannel);
+ }
+ // Calculate scale
+ long minWeight = Collections.min(serverWeights.values());
+ double normalizedMinWeight = (double) minWeight / totalWeight;
+ // Scale up the number of hashes per host such that the least-weighted host gets a whole
+ // number of hashes on the the ring. Other hosts might not end up with whole numbers, and
+ // that's fine (the ring-building algorithm can handle this). This preserves the original
+ // implementation's behavior: when weights aren't provided, all hosts should get an equal
+ // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled
+ // back down to fit.
+ double scale = Math.min(
+ Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight,
+ (double) config.maxRingSize);
+
+ // Build the ring
+ ring = buildRing(serverWeights, totalWeight, scale);
+
+ // Must update channel picker before return so that new RPCs will not be routed to deleted
+ // clusters and resolver can remove them in service config.
+ updateOverallBalancingState();
+
+ shutdownRemoved(acceptRetVal.removedChildren);
+ } finally {
+ this.resolvingAddresses = false;
}
- long minWeight = Collections.min(serverWeights.values());
- double normalizedMinWeight = (double) minWeight / totalWeight;
- // Scale up the number of hashes per host such that the least-weighted host gets a whole
- // number of hashes on the the ring. Other hosts might not end up with whole numbers, and
- // that's fine (the ring-building algorithm can handle this). This preserves the original
- // implementation's behavior: when weights aren't provided, all hosts should get an equal
- // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled
- // back down to fit.
- double scale = Math.min(
- Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight,
- (double) config.maxRingSize);
- ring = buildRing(serverWeights, totalWeight, scale);
-
- // Shut down subchannels for delisted addresses.
- List<Subchannel> removedSubchannels = new ArrayList<>();
- for (EquivalentAddressGroup addr : removedAddrs) {
- removedSubchannels.add(subchannels.remove(addr));
+
+ return Status.OK;
+ }
+
+ /**
+ * Updates the overall balancing state by aggregating the connectivity states of all subchannels.
+ *
+ * <p>Aggregation rules (in order of dominance):
+ * <ol>
+ * <li>If there is at least one subchannel in READY state, overall state is READY</li>
+ * <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is
+ * TRANSIENT_FAILURE (to allow timely failover to another policy)</li>
+ * <li>If there is at least one subchannel in CONNECTING state, overall state is
+ * CONNECTING</li>
+ * <li> If there is one subchannel in TRANSIENT_FAILURE state and there is
+ * more than one subchannel, report CONNECTING </li>
+ * <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li>
+ * <li>Otherwise, overall state is TRANSIENT_FAILURE</li>
+ * </ol>
+ */
+ @Override
+ protected void updateOverallBalancingState() {
+ checkState(!getChildLbStates().isEmpty(), "no subchannel has been created");
+ if (this.currentConnectivityState == SHUTDOWN) {
+ // Ignore changes that happen after shutdown is called
+ logger.log(XdsLogLevel.DEBUG, "UpdateOverallBalancingState called after shutdown");
+ return;
}
- // If we need to proactively start connecting, iterate through all the subchannels, starting
- // at a random position.
- // Alternatively, we should better start at the same position.
- connectionAttemptIterator = subchannels.values().iterator();
- int randomAdvance = random.nextInt(subchannels.size());
- while (randomAdvance-- > 0) {
- connectionAttemptIterator.next();
+
+ // Calculate the current overall state to report
+ int numIdle = 0;
+ int numReady = 0;
+ int numConnecting = 0;
+ int numTF = 0;
+
+ forloop:
+ for (ChildLbState childLbState : getChildLbStates()) {
+ ConnectivityState state = childLbState.getCurrentState();
+ switch (state) {
+ case READY:
+ numReady++;
+ break forloop;
+ case CONNECTING:
+ numConnecting++;
+ break;
+ case IDLE:
+ numIdle++;
+ break;
+ case TRANSIENT_FAILURE:
+ numTF++;
+ break;
+ default:
+ // ignore it
+ }
}
- // Update the picker before shutting down the subchannels, to reduce the chance of race
- // between picking a subchannel and shutting it down.
- updateBalancingState();
- for (Subchannel subchann : removedSubchannels) {
- shutdownSubchannel(subchann);
+ ConnectivityState overallState;
+ if (numReady > 0) {
+ overallState = READY;
+ } else if (numTF >= 2) {
+ overallState = TRANSIENT_FAILURE;
+ } else if (numConnecting > 0) {
+ overallState = CONNECTING;
+ } else if (numTF == 1 && getChildLbStates().size() > 1) {
+ overallState = CONNECTING;
+ } else if (numIdle > 0) {
+ overallState = IDLE;
+ } else {
+ overallState = TRANSIENT_FAILURE;
}
- return Status.OK;
+ RingHashPicker picker = new RingHashPicker(syncContext, ring, getImmutableChildMap());
+ getHelper().updateBalancingState(overallState, picker);
+ this.currentConnectivityState = overallState;
+ }
+
+ @Override
+ protected boolean reconnectOnIdle() {
+ return false;
+ }
+
+ @Override
+ protected boolean reactivateChildOnReuse() {
+ return false;
+ }
+
+ @Override
+ protected ChildLbState createChildLbState(Object key, Object policyConfig,
+ SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) {
+ return new RingHashChildLbState((Endpoint)key,
+ getChildAddresses(key, resolvedAddresses, null));
}
private Status validateAddrList(List<EquivalentAddressGroup> addrList) {
@@ -197,7 +266,7 @@ final class RingHashLoadBalancer extends LoadBalancer {
if (weight < 0) {
Status unavailableStatus = Status.UNAVAILABLE.withDescription(
- String.format("Ring hash lb error: EDS resolution was successful, but returned a "
+ String.format("Ring hash lb error: EDS resolution was successful, but returned a "
+ "negative weight for %s.", stripAttrs(eag)));
handleNameResolutionError(unavailableStatus);
return unavailableStatus;
@@ -252,10 +321,10 @@ final class RingHashLoadBalancer extends LoadBalancer {
double currentHashes = 0.0;
double targetHashes = 0.0;
for (Map.Entry<EquivalentAddressGroup, Long> entry : serverWeights.entrySet()) {
- EquivalentAddressGroup addrKey = entry.getKey();
+ Endpoint endpoint = new Endpoint(entry.getKey());
double normalizedWeight = (double) entry.getValue() / totalWeight;
- // TODO(chengyuanzhang): is using the list of socket address correct?
- StringBuilder sb = new StringBuilder(addrKey.getAddresses().toString());
+ // Per GRFC A61 use the first address for the hash
+ StringBuilder sb = new StringBuilder(entry.getKey().getAddresses().get(0).toString());
sb.append('_');
int lengthWithoutCounter = sb.length();
targetHashes += scale * normalizedWeight;
@@ -263,7 +332,7 @@ final class RingHashLoadBalancer extends LoadBalancer {
while (currentHashes < targetHashes) {
sb.append(i);
long hash = hashFunc.hashAsciiString(sb.toString());
- ring.add(new RingEntry(hash, addrKey));
+ ring.add(new RingEntry(hash, endpoint));
i++;
currentHashes++;
sb.setLength(lengthWithoutCounter);
@@ -273,159 +342,14 @@ final class RingHashLoadBalancer extends LoadBalancer {
return Collections.unmodifiableList(ring);
}
- @Override
- public void handleNameResolutionError(Status error) {
- if (currentState != READY) {
- helper.updateBalancingState(
- TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error)));
- }
- }
-
- @Override
- public void shutdown() {
- logger.log(XdsLogLevel.INFO, "Shutdown");
- for (Subchannel subchannel : subchannels.values()) {
- shutdownSubchannel(subchannel);
+ @SuppressWarnings("ReferenceEquality")
+ public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
+ if (eag.getAttributes() == Attributes.EMPTY) {
+ return eag;
}
- subchannels.clear();
- }
-
- /**
- * Updates the overall balancing state by aggregating the connectivity states of all subchannels.
- *
- * <p>Aggregation rules (in order of dominance):
- * <ol>
- * <li>If there is at least one subchannel in READY state, overall state is READY</li>
- * <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is
- * TRANSIENT_FAILURE</li>
- * <li>If there is at least one subchannel in CONNECTING state, overall state is
- * CONNECTING</li>
- * <li> If there is one subchannel in TRANSIENT_FAILURE state and there is
- * more than one subchannel, report CONNECTING </li>
- * <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li>
- * <li>Otherwise, overall state is TRANSIENT_FAILURE</li>
- * </ol>
- */
- private void updateBalancingState() {
- checkState(!subchannels.isEmpty(), "no subchannel has been created");
- boolean startConnectionAttempt = false;
- int numIdle = 0;
- int numReady = 0;
- int numConnecting = 0;
- int numTransientFailure = 0;
- for (Subchannel subchannel : subchannels.values()) {
- ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState();
- if (state == READY) {
- numReady++;
- break;
- } else if (state == TRANSIENT_FAILURE) {
- numTransientFailure++;
- } else if (state == CONNECTING ) {
- numConnecting++;
- } else if (state == IDLE) {
- numIdle++;
- }
- }
- ConnectivityState overallState;
- if (numReady > 0) {
- overallState = READY;
- } else if (numTransientFailure >= 2) {
- overallState = TRANSIENT_FAILURE;
- startConnectionAttempt = (numConnecting == 0);
- } else if (numConnecting > 0) {
- overallState = CONNECTING;
- } else if (numTransientFailure == 1 && subchannels.size() > 1) {
- overallState = CONNECTING;
- startConnectionAttempt = true;
- } else if (numIdle > 0) {
- overallState = IDLE;
- } else {
- overallState = TRANSIENT_FAILURE;
- startConnectionAttempt = true;
- }
- RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels);
- // TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates
- helper.updateBalancingState(overallState, picker);
- currentState = overallState;
- // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will
- // not be getting any pick requests from the priority policy.
- // However, because the ring_hash policy does not attempt to
- // reconnect to subchannels unless it is getting pick requests,
- // it will need special handling to ensure that it will eventually
- // recover from TRANSIENT_FAILURE state once the problem is resolved.
- // Specifically, it will make sure that it is attempting to connect to
- // at least one subchannel at any given time. After a given subchannel
- // fails a connection attempt, it will move on to the next subchannel
- // in the ring. It will keep doing this until one of the subchannels
- // successfully connects, at which point it will report READY and stop
- // proactively trying to connect. The policy will remain in
- // TRANSIENT_FAILURE until at least one subchannel becomes connected,
- // even if subchannels are in state CONNECTING during that time.
- //
- // Note that we do the same thing when the policy is in state
- // CONNECTING, just to ensure that we don't remain in CONNECTING state
- // indefinitely if there are no new picks coming in.
- if (startConnectionAttempt) {
- if (!connectionAttemptIterator.hasNext()) {
- connectionAttemptIterator = subchannels.values().iterator();
- }
- connectionAttemptIterator.next().requestConnection();
- }
- }
-
- private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
- if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) {
- return;
- }
- if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) {
- helper.refreshNameResolution();
- }
- updateConnectivityState(subchannel, stateInfo);
- updateBalancingState();
- }
-
- private void updateConnectivityState(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
- Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel);
- ConnectivityState previousConnectivityState = subchannelStateRef.value.getState();
- // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected.
- // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in
- // TRANSIENT_FAILURE until it becomes READY.
- if (previousConnectivityState == TRANSIENT_FAILURE) {
- if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) {
- return;
- }
- }
- subchannelStateRef.value = stateInfo;
- }
-
- private static void shutdownSubchannel(Subchannel subchannel) {
- subchannel.shutdown();
- getSubchannelStateInfoRef(subchannel).value = ConnectivityStateInfo.forNonError(SHUTDOWN);
- }
-
- /**
- * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and
- * remove all attributes. The values are the original EAGs.
- */
- private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs(
- List<EquivalentAddressGroup> groupList) {
- Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs =
- new HashMap<>(groupList.size() * 2);
- for (EquivalentAddressGroup group : groupList) {
- addrs.put(stripAttrs(group), group);
- }
- return addrs;
- }
-
- private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) {
return new EquivalentAddressGroup(eag.getAddresses());
}
- private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef(
- Subchannel subchannel) {
- return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO");
- }
-
private static final class RingHashPicker extends SubchannelPicker {
private final SynchronizationContext syncContext;
private final List<RingEntry> ring;
@@ -433,38 +357,31 @@ final class RingHashLoadBalancer extends LoadBalancer {
// freeze picker's view of subchannel's connectivity state.
// TODO(chengyuanzhang): can be more performance-friendly with
// IdentityHashMap<Subchannel, ConnectivityStateInfo> and RingEntry contains Subchannel.
- private final Map<EquivalentAddressGroup, SubchannelView> pickableSubchannels; // read-only
+ private final Map<Endpoint, SubchannelView> pickableSubchannels; // read-only
private RingHashPicker(
SynchronizationContext syncContext, List<RingEntry> ring,
- Map<EquivalentAddressGroup, Subchannel> subchannels) {
+ ImmutableMap<Object, ChildLbState> subchannels) {
this.syncContext = syncContext;
this.ring = ring;
pickableSubchannels = new HashMap<>(subchannels.size());
- for (Map.Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) {
- Subchannel subchannel = entry.getValue();
- ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).value;
- pickableSubchannels.put(entry.getKey(), new SubchannelView(subchannel, stateInfo));
+ for (Map.Entry<Object, ChildLbState> entry : subchannels.entrySet()) {
+ RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue();
+ pickableSubchannels.put((Endpoint)entry.getKey(),
+ new SubchannelView(childLbState, childLbState.getCurrentState()));
}
}
- @Override
- public PickResult pickSubchannel(PickSubchannelArgs args) {
- Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
- if (requestHash == null) {
- return PickResult.withError(RPC_HASH_NOT_FOUND);
+ // Find the ring entry with hash next to (clockwise) the RPC's hash (binary search).
+ private int getTargetIndex(Long requestHash) {
+ if (ring.size() <= 1) {
+ return 0;
}
- // Find the ring entry with hash next to (clockwise) the RPC's hash.
int low = 0;
- int high = ring.size();
- int mid;
- while (true) {
- mid = (low + high) / 2;
- if (mid == ring.size()) {
- mid = 0;
- break;
- }
+ int high = ring.size() - 1;
+ int mid = (low + high) / 2;
+ do {
long midVal = ring.get(mid).hash;
long midValL = mid == 0 ? 0 : ring.get(mid - 1).hash;
if (requestHash <= midVal && requestHash > midValL) {
@@ -475,79 +392,61 @@ final class RingHashLoadBalancer extends LoadBalancer {
} else {
high = mid - 1;
}
- if (low > high) {
- mid = 0;
- break;
- }
+ mid = (low + high) / 2;
+ } while (mid < ring.size() && low <= high);
+ return mid;
+ }
+
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY);
+ if (requestHash == null) {
+ return PickResult.withError(RPC_HASH_NOT_FOUND);
}
- // Try finding a READY subchannel. Starting from the ring entry next to the RPC's hash.
- // If the one of the first two subchannels is not in TRANSIENT_FAILURE, return result
- // based on that subchannel. Otherwise, fail the pick unless a READY subchannel is found.
- // Meanwhile, trigger connection for the channel and status:
- // For the first subchannel that is in IDLE or TRANSIENT_FAILURE;
- // And for the second subchannel that is in IDLE or TRANSIENT_FAILURE;
- // And for each of the following subchannels that is in TRANSIENT_FAILURE or IDLE,
- // stop until we find the first subchannel that is in CONNECTING or IDLE status.
- boolean foundFirstNonFailed = false; // true if having subchannel(s) in CONNECTING or IDLE
- Subchannel firstSubchannel = null;
- Subchannel secondSubchannel = null;
+ int targetIndex = getTargetIndex(requestHash);
+
+ // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore
+ // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If
+ // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in
+ // IDLE, we initiate a connection.
for (int i = 0; i < ring.size(); i++) {
- int index = (mid + i) % ring.size();
- EquivalentAddressGroup addrKey = ring.get(index).addrKey;
- SubchannelView subchannel = pickableSubchannels.get(addrKey);
- if (subchannel.stateInfo.getState() == READY) {
- return PickResult.withSubchannel(subchannel.subchannel);
+ int index = (targetIndex + i) % ring.size();
+ SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey);
+ RingHashChildLbState childLbState = subchannelView.childLbState;
+
+ if (subchannelView.connectivityState == READY) {
+ return childLbState.getCurrentPicker().pickSubchannel(args);
}
- // RPCs can be buffered if any of the first two subchannels is pending. Otherwise, RPCs
+ // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs
// are failed unless there is a READY connection.
- if (firstSubchannel == null) {
- firstSubchannel = subchannel.subchannel;
- PickResult maybeBuffer = pickSubchannelsNonReady(subchannel);
- if (maybeBuffer != null) {
- return maybeBuffer;
- }
- } else if (subchannel.subchannel != firstSubchannel && secondSubchannel == null) {
- secondSubchannel = subchannel.subchannel;
- PickResult maybeBuffer = pickSubchannelsNonReady(subchannel);
- if (maybeBuffer != null) {
- return maybeBuffer;
- }
- } else if (subchannel.subchannel != firstSubchannel
- && subchannel.subchannel != secondSubchannel) {
- if (!foundFirstNonFailed) {
- pickSubchannelsNonReady(subchannel);
- if (subchannel.stateInfo.getState() != TRANSIENT_FAILURE) {
- foundFirstNonFailed = true;
- }
- }
+ if (subchannelView.connectivityState == CONNECTING) {
+ return PickResult.withNoResult();
}
- }
- // Fail the pick with error status of the original subchannel hit by hash.
- SubchannelView originalSubchannel = pickableSubchannels.get(ring.get(mid).addrKey);
- return PickResult.withError(originalSubchannel.stateInfo.getStatus());
- }
- @Nullable
- private PickResult pickSubchannelsNonReady(SubchannelView subchannel) {
- if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE
- || subchannel.stateInfo.getState() == IDLE ) {
- final Subchannel finalSubchannel = subchannel.subchannel;
- syncContext.execute(new Runnable() {
- @Override
- public void run() {
- finalSubchannel.requestConnection();
+ if (subchannelView.connectivityState == IDLE || childLbState.isDeactivated()) {
+ if (childLbState.isDeactivated()) {
+ childLbState.activate();
+ } else {
+ syncContext.execute(() -> childLbState.getLb().requestConnection());
}
- });
- }
- if (subchannel.stateInfo.getState() == CONNECTING
- || subchannel.stateInfo.getState() == IDLE) {
- return PickResult.withNoResult();
- } else {
- return null;
+
+ return PickResult.withNoResult(); // Indicates that this should be retried after backoff
+ }
}
+
+ // return the pick from the original subchannel hit by hash, which is probably an error
+ RingHashChildLbState originalSubchannel =
+ pickableSubchannels.get(ring.get(targetIndex).addrKey).childLbState;
+ return originalSubchannel.getCurrentPicker().pickSubchannel(args);
}
+
+ }
+
+ @Override
+ protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) {
+ throw new UnsupportedOperationException("Not used by RingHash");
}
/**
@@ -555,20 +454,20 @@ final class RingHashLoadBalancer extends LoadBalancer {
* state changes.
*/
private static final class SubchannelView {
- private final Subchannel subchannel;
- private final ConnectivityStateInfo stateInfo;
+ private final RingHashChildLbState childLbState;
+ private final ConnectivityState connectivityState;
- private SubchannelView(Subchannel subchannel, ConnectivityStateInfo stateInfo) {
- this.subchannel = subchannel;
- this.stateInfo = stateInfo;
+ private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) {
+ this.childLbState = childLbState;
+ this.connectivityState = state;
}
}
private static final class RingEntry implements Comparable<RingEntry> {
private final long hash;
- private final EquivalentAddressGroup addrKey;
+ private final Endpoint addrKey;
- private RingEntry(long hash, EquivalentAddressGroup addrKey) {
+ private RingEntry(long hash, Endpoint addrKey) {
this.hash = hash;
this.addrKey = addrKey;
}
@@ -580,17 +479,6 @@ final class RingHashLoadBalancer extends LoadBalancer {
}
/**
- * A lighter weight Reference than AtomicReference.
- */
- private static final class Ref<T> {
- T value;
-
- Ref(T value) {
- this.value = value;
- }
- }
-
- /**
* Configures the ring property. The larger the ring is (that is, the more hashes there are
* for each provided host) the better the request distribution will reflect the desired weights.
*/
@@ -614,4 +502,58 @@ final class RingHashLoadBalancer extends LoadBalancer {
.toString();
}
}
-}
+
+ static Set<EquivalentAddressGroup> getStrippedChildEags(Collection<ChildLbState> states) {
+ return states.stream()
+ .map(ChildLbState::getEag)
+ .map(RingHashLoadBalancer::stripAttrs)
+ .collect(Collectors.toSet());
+ }
+
+ @Override
+ protected Collection<ChildLbState> getChildLbStates() {
+ return super.getChildLbStates();
+ }
+
+ @Override
+ protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) {
+ return super.getChildLbStateEag(eag);
+ }
+
+ class RingHashChildLbState extends MultiChildLoadBalancer.ChildLbState {
+
+ public RingHashChildLbState(Endpoint key, ResolvedAddresses resolvedAddresses) {
+ super(key, pickFirstLbProvider, null, EMPTY_PICKER, resolvedAddresses, true);
+ }
+
+ @Override
+ protected void reactivate(LoadBalancerProvider policyProvider) {
+ if (!isDeactivated()) {
+ return;
+ }
+
+ currentConnectivityState = CONNECTING;
+ getLb().switchTo(pickFirstLbProvider);
+ markReactivated();
+ getLb().acceptResolvedAddresses(this.getResolvedAddresses()); // Time to get a subchannel
+ logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey());
+ }
+
+ public void activate() {
+ reactivate(pickFirstLbProvider);
+ }
+
+ // Need to expose this to the LB class
+ @Override
+ protected void shutdown() {
+ super.shutdown();
+ }
+
+ // Need to expose this to the LB class
+ @Override
+ protected GracefulSwitchLoadBalancer getLb() {
+ return super.getLb();
+ }
+
+ }
+} \ No newline at end of file
diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
index a71c91deb..4e77e99a3 100644
--- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
+++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java
@@ -97,7 +97,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
@Override
protected ChildLbState createChildLbState(Object key, Object policyConfig,
- SubchannelPicker initialPicker) {
+ SubchannelPicker initialPicker, ResolvedAddresses unused) {
ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig,
initialPicker);
return childLbState;
@@ -115,13 +115,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
config =
(WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig();
- Status addressAcceptanceStatus = super.acceptResolvedAddresses(resolvedAddresses);
- if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
- weightUpdateTimer.cancel();
+ AcceptResolvedAddressRetVal acceptRetVal;
+ try {
+ resolvingAddresses = true;
+ acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses);
+ if (!acceptRetVal.status.isOk()) {
+ return acceptRetVal.status;
+ }
+
+ if (weightUpdateTimer != null && weightUpdateTimer.isPending()) {
+ weightUpdateTimer.cancel();
+ }
+ updateWeightTask.run();
+
+ createAndApplyOrcaListeners();
+
+ // Must update channel picker before return so that new RPCs will not be routed to deleted
+ // clusters and resolver can remove them in service config.
+ updateOverallBalancingState();
+
+ shutdownRemoved(acceptRetVal.removedChildren);
+ } finally {
+ resolvingAddresses = false;
}
- updateWeightTask.run();
- afterAcceptAddresses();
- return addressAcceptanceStatus;
+
+ return acceptRetVal.status;
}
@Override
@@ -228,7 +246,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer {
}
}
- private void afterAcceptAddresses() {
+ private void createAndApplyOrcaListeners() {
for (ChildLbState child : getChildLbStates()) {
WeightedChildLbState wChild = (WeightedChildLbState) child;
for (WrrSubchannel weightedSubchannel : wChild.subchannels) {
diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java
index 5f5e0df4d..a2dc7936b 100644
--- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java
@@ -106,7 +106,7 @@ public class LeastRequestLoadBalancerTest {
@Captor
private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor;
private final TestHelper testHelperInstance = new TestHelper();
- private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
+ private final Helper helper = mock(Helper.class, delegatesTo(testHelperInstance));
@Mock
private ThreadSafeRandom mockRandom;
@@ -522,7 +522,6 @@ public class LeastRequestLoadBalancerTest {
loadBalancer.handleNameResolutionError(error);
loadBalancer.setResolvingAddresses(false);
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
-
LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
assertNull(pickResult.getSubchannel());
assertEquals(error, pickResult.getStatus());
diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java
index f2b38abda..337737265 100644
--- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java
@@ -22,18 +22,21 @@ import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
import static io.grpc.ConnectivityState.SHUTDOWN;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
+import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_RESET_HELPER;
+import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_VERIFY;
+import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.RESET_SUBCHANNEL_MOCKS;
+import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.STAY_IN_CONNECTING;
+import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.clearInvocations;
-import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
-import static org.mockito.Mockito.when;
import com.google.common.collect.Iterables;
import com.google.common.primitives.UnsignedInteger;
@@ -48,13 +51,15 @@ import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.LoadBalancer.ResolvedAddresses;
import io.grpc.LoadBalancer.Subchannel;
import io.grpc.LoadBalancer.SubchannelPicker;
-import io.grpc.LoadBalancer.SubchannelStateListener;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.SynchronizationContext;
import io.grpc.internal.PickSubchannelArgsImpl;
import io.grpc.testing.TestMethodDescriptors;
+import io.grpc.util.AbstractTestHelper;
+import io.grpc.util.MultiChildLoadBalancer.ChildLbState;
+import io.grpc.xds.RingHashLoadBalancer.RingHashChildLbState;
import io.grpc.xds.RingHashLoadBalancer.RingHashConfig;
import java.lang.Thread.UncaughtExceptionHandler;
import java.net.SocketAddress;
@@ -74,12 +79,9 @@ import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
-import org.mockito.Mock;
import org.mockito.Mockito;
-import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
-import org.mockito.stubbing.Answer;
/** Unit test for {@link io.grpc.LoadBalancer}. */
@RunWith(JUnit4.class)
@@ -97,49 +99,18 @@ public class RingHashLoadBalancerTest {
}
});
private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = new HashMap<>();
- private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners =
- new HashMap<>();
private final Deque<Subchannel> connectionRequestedQueue = new ArrayDeque<>();
private final XxHash64 hashFunc = XxHash64.INSTANCE;
- @Mock
- private Helper helper;
+ private final TestHelper testHelperInst = new TestHelper();
+ private final Helper helper = mock(Helper.class, delegatesTo(testHelperInst));
@Captor
private ArgumentCaptor<SubchannelPicker> pickerCaptor;
private RingHashLoadBalancer loadBalancer;
@Before
public void setUp() {
- when(helper.getAuthority()).thenReturn(AUTHORITY);
- when(helper.getSynchronizationContext()).thenReturn(syncContext);
- when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenAnswer(
- new Answer<Subchannel>() {
- @Override
- public Subchannel answer(InvocationOnMock invocation) throws Throwable {
- CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0];
- final Subchannel subchannel = mock(Subchannel.class);
- when(subchannel.getAllAddresses()).thenReturn(args.getAddresses());
- when(subchannel.getAttributes()).thenReturn(args.getAttributes());
- subchannels.put(args.getAddresses(), subchannel);
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) throws Throwable {
- subchannelStateListeners.put(
- subchannel, (SubchannelStateListener) invocation.getArguments()[0]);
- return null;
- }
- }).when(subchannel).start(any(SubchannelStateListener.class));
- doAnswer(new Answer<Void>() {
- @Override
- public Void answer(InvocationOnMock invocation) throws Throwable {
- connectionRequestedQueue.offer(subchannel);
- return null;
- }
- }).when(subchannel).requestConnection();
- return subchannel;
- }
- });
loadBalancer = new RingHashLoadBalancer(helper);
- // Skip uninterested interactions.
+ // Consume calls not relevant for tests that would otherwise fail verifyNoMoreInteractions
verify(helper).getAuthority();
verify(helper).getSynchronizationContext();
}
@@ -161,21 +132,20 @@ public class RingHashLoadBalancerTest {
ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
- Subchannel subchannel = Iterables.getOnlyElement(subchannels.values());
- verify(subchannel, never()).requestConnection();
verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ assertThat(subchannels.size()).isEqualTo(0);
// Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
assertThat(result.getSubchannel()).isNull();
+ Subchannel subchannel = Iterables.getOnlyElement(subchannels.values());
verify(subchannel).requestConnection();
- deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+ verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
+ verify(helper, times(2)).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
// Subchannel becomes ready, triggers pick again.
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
@@ -193,16 +163,19 @@ public class RingHashLoadBalancerTest {
ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- Subchannel subchannel = Iterables.getOnlyElement(subchannels.values());
- InOrder inOrder = Mockito.inOrder(helper, subchannel);
- inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
- inOrder.verify(subchannel, never()).requestConnection();
+ verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+
+ RingHashChildLbState childLbState =
+ (RingHashChildLbState) loadBalancer.getChildLbStates().iterator().next();
+ assertThat(childLbState.isDeactivated()).isTrue();
// Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
pickerCaptor.getValue().pickSubchannel(args);
+ assertThat(childLbState.isDeactivated()).isFalse();
+ assertThat(childLbState.getLb().delegateType()).isEqualTo("PickFirstLoadBalancer");
+ Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag()));
+ InOrder inOrder = Mockito.inOrder(helper, subchannel);
inOrder.verify(subchannel).requestConnection();
deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
@@ -220,12 +193,8 @@ public class RingHashLoadBalancerTest {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1);
InOrder inOrder = Mockito.inOrder(helper);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class));
- inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
// one in CONNECTING, one in IDLE
deliverSubchannelState(
@@ -263,7 +232,7 @@ public class RingHashLoadBalancerTest {
ConnectivityStateInfo.forNonError(IDLE));
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
- verifyConnection(1);
+ verifyConnection(0);
verifyNoMoreInteractions(helper);
}
@@ -278,164 +247,57 @@ public class RingHashLoadBalancerTest {
}
@Test
- public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() {
+ public void aggregateSubchannelStates_allSubchannelsInTransientFailure() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1, 1);
- InOrder inOrder = Mockito.inOrder(helper);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- inOrder.verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class));
- inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
- // one in TRANSIENT_FAILURE, three in IDLE
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(0))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("not found")));
+ List<Subchannel> subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING);
+
+ // reset inOrder to include all the childLBs now that they have been created
+ clearInvocations(helper);
+ InOrder inOrder = Mockito.inOrder(helper,
+ subChannelList.get(0), subChannelList.get(1), subChannelList.get(2), subChannelList.get(3));
+
+ // one in TRANSIENT_FAILURE, three in CONNECTING
+ deliverNotFound(subChannelList, 0);
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
- verifyConnection(1);
- // two in TRANSIENT_FAILURE, two in IDLE
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(1))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("also not found")));
+ // two in TRANSIENT_FAILURE, two in CONNECTING
+ deliverNotFound(subChannelList, 1);
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(1);
- // two in TRANSIENT_FAILURE, one in CONNECTING, one in IDLE
- // The overall state is dominated by the two in TRANSIENT_FAILURE.
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(2))),
- ConnectivityStateInfo.forNonError(CONNECTING));
+ // All 4 in TF switch to TF
+ deliverNotFound(subChannelList, 2);
+ inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(0);
-
- // three in TRANSIENT_FAILURE, one in CONNECTING
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(3))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("connection lost")));
+ deliverNotFound(subChannelList, 3);
inOrder.verify(helper).refreshNameResolution();
inOrder.verify(helper)
.updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(0);
+
+ // reset subchannel to CONNECTING - shouldn't change anything since PF hides the state change
+ deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(CONNECTING));
+ inOrder.verify(helper, never())
+ .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
+ inOrder.verify(subChannelList.get(2), never()).requestConnection();
// three in TRANSIENT_FAILURE, one in READY
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(2))),
- ConnectivityStateInfo.forNonError(READY));
+ deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(READY));
inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
- verifyConnection(0);
+ inOrder.verify(subChannelList.get(2), never()).requestConnection();
verifyNoMoreInteractions(helper);
}
@Test
- public void subchannelStayInTransientFailureUntilBecomeReady() {
- RingHashConfig config = new RingHashConfig(10, 100);
- List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
- reset(helper);
-
- // Simulate picks have taken place and subchannels have requested connection.
- for (Subchannel subchannel : subchannels.values()) {
- deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure(
- Status.UNAUTHENTICATED.withDescription("Permission denied")));
- }
- verify(helper, times(3)).refreshNameResolution();
-
- // Stays in IDLE when until there are two or more subchannels in TRANSIENT_FAILURE.
- verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
- verify(helper, times(2))
- .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(3);
-
- verifyNoMoreInteractions(helper);
- reset(helper);
- // Simulate underlying subchannel auto reconnect after backoff.
- for (Subchannel subchannel : subchannels.values()) {
- deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING));
- }
- verify(helper, times(3))
- .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(3);
- verifyNoMoreInteractions(helper);
-
- // Simulate one subchannel enters READY.
- deliverSubchannelState(
- subchannels.values().iterator().next(), ConnectivityStateInfo.forNonError(READY));
- verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
- }
-
- @Test
- public void updateConnectionIterator() {
- RingHashConfig config = new RingHashConfig(10, 100);
- List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- InOrder inOrder = Mockito.inOrder(helper);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
-
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(0))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("connection lost")));
- inOrder.verify(helper).refreshNameResolution();
- inOrder.verify(helper)
- .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
- verifyConnection(1);
-
- servers = createWeightedServerAddrs(1,1);
- addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- inOrder.verify(helper)
- .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
- verifyConnection(1);
-
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(1))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("connection lost")));
- inOrder.verify(helper).refreshNameResolution();
- inOrder.verify(helper)
- .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(1);
-
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(0))),
- ConnectivityStateInfo.forNonError(CONNECTING));
- inOrder.verify(helper)
- .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class));
- verifyConnection(1);
- }
-
- @Test
public void ignoreShutdownSubchannelStateChange() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ initializeLbSubchannels(config, servers);
loadBalancer.shutdown();
for (Subchannel sc : subchannels.values()) {
@@ -451,13 +313,8 @@ public class RingHashLoadBalancerTest {
public void deterministicPickWithHostsPartiallyRemoved() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
+ initializeLbSubchannels(config, servers);
InOrder inOrder = Mockito.inOrder(helper);
- inOrder.verify(helper, times(5)).createSubchannel(any(CreateSubchannelArgs.class));
- inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
// Bring all subchannels to READY so that next pick always succeeds.
for (Subchannel subchannel : subchannels.values()) {
@@ -466,10 +323,8 @@ public class RingHashLoadBalancerTest {
}
// Simulate rpc hash hits one ring entry exactly for server1.
- long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0");
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+ long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0");
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash);
pickerCaptor.getValue().pickSubchannel(args);
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
Subchannel subchannel = result.getSubchannel();
@@ -480,14 +335,14 @@ public class RingHashLoadBalancerTest {
Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build();
updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr));
}
- addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
+ Subchannel subchannel0_old = subchannels.get(Collections.singletonList(servers.get(0)));
+ Subchannel subchannel1_old = subchannels.get(Collections.singletonList(servers.get(1)));
+ Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .updateAddresses(Collections.singletonList(updatedServers.get(0)));
- verify(subchannels.get(Collections.singletonList(servers.get(1))))
- .updateAddresses(Collections.singletonList(updatedServers.get(1)));
+ verify(subchannel0_old).updateAddresses(Collections.singletonList(updatedServers.get(0)));
+ verify(subchannel1_old).updateAddresses(Collections.singletonList(updatedServers.get(1)));
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel())
.isSameInstanceAs(subchannel);
@@ -498,13 +353,9 @@ public class RingHashLoadBalancerTest {
public void deterministicPickWithNewHostsAdded() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1); // server0 and server1
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
+ initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER);
+
InOrder inOrder = Mockito.inOrder(helper);
- inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class));
- inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
// Bring all subchannels to READY so that next pick always succeeds.
for (Subchannel subchannel : subchannels.values()) {
@@ -513,66 +364,68 @@ public class RingHashLoadBalancerTest {
}
// Simulate rpc hash hits one ring entry exactly for server1.
- long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0");
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+ long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0");
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash);
pickerCaptor.getValue().pickSubchannel(args);
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
Subchannel subchannel = result.getSubchannel();
assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1));
servers = createWeightedServerAddrs(1, 1, 1, 1, 1); // server2, server3, server4 added
- addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
+ Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- inOrder.verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
+ assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(5);
inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel())
.isSameInstanceAs(subchannel);
- verifyNoMoreInteractions(helper);
+ inOrder.verifyNoMoreInteractions();
+ }
+
+ private Subchannel getSubChannel(EquivalentAddressGroup eag) {
+ return subchannels.get(Collections.singletonList(eag));
}
@Test
- public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() {
+ public void skipFailingHosts_pickNextNonFailingHost() {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+ Status addressesAcceptanceStatus =
+ loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE
+
+ // Create subchannel for the first address
+ ((RingHashChildLbState)loadBalancer.getChildLbStateEag(servers.get(0))).activate();
+ verifyConnection(1);
+
reset(helper);
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server2_0"
- long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0");
+ long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server0_0");
PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash);
// Bring down server0 to force trying server2.
deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(0))),
+ getSubChannel(servers.get(0)),
ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable")));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
- verifyConnection(1);
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
assertThat(result.getSubchannel()).isNull(); // buffer request
- verify(subchannels.get(Collections.singletonList(servers.get(2))))
- .requestConnection(); // kick off connection to server2
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
- .requestConnection(); // no excessive connection
+ verify(getSubChannel(servers.get(1))).requestConnection(); // kicked off connection to server2
+ assertThat(subchannels.size()).isEqualTo(2); // no excessive connection
reset(helper);
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(2))),
+ deliverSubchannelState(getSubChannel(servers.get(1)),
ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
@@ -580,14 +433,12 @@ public class RingHashLoadBalancerTest {
assertThat(result.getStatus().isOk()).isTrue();
assertThat(result.getSubchannel()).isNull(); // buffer request
- deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(2))),
- ConnectivityStateInfo.forNonError(READY));
+ deliverSubchannelState(getSubChannel(servers.get(1)), ConnectivityStateInfo.forNonError(READY));
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
- assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(2));
+ assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1));
}
private PickSubchannelArgsImpl getDefaultPickSubchannelArgs(long rpcHash) {
@@ -596,55 +447,59 @@ public class RingHashLoadBalancerTest {
CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
}
+ private PickSubchannelArgsImpl getDefaultPickSubchannelArgsForServer(int serverid) {
+ long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server" + serverid + "_0");
+ return getDefaultPickSubchannelArgs(rpcHash);
+ }
+
@Test
public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE
- reset(helper);
+
+ initializeLbSubchannels(config, servers);
+
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server2_0"
- long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0");
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash));
+ long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0");
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash);
// Bring down server0 and server2 to force trying server1.
deliverSubchannelState(
- subchannels.get(Collections.singletonList(servers.get(0))),
+ subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable")));
deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(2))),
ConnectivityStateInfo.forTransientFailure(
Status.PERMISSION_DENIED.withDescription("permission denied")));
- verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
- verifyConnection(2); // LB attempts to recover by itself
+ verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+ verifyConnection(0);
+ PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel
+ assertThat(result.getStatus().isOk()).isTrue();
+ verifyConnection(1);
- PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ deliverSubchannelState(
+ subchannels.get(Collections.singletonList(servers.get(0))),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.PERMISSION_DENIED.withDescription("permission denied again")));
+ verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+ result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC
assertThat(result.getStatus().getCode())
.isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash
assertThat(result.getStatus().getDescription()).isEqualTo("unreachable");
- verify(subchannels.get(Collections.singletonList(servers.get(1))))
- .requestConnection(); // kickoff connection to server3 (next first non-failing)
- verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection();
// Now connecting to server1.
deliverSubchannelState(
subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(CONNECTING));
- verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+
+ reset(helper);
result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC
@@ -658,9 +513,29 @@ public class RingHashLoadBalancerTest {
ConnectivityStateInfo.forNonError(READY));
verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture());
- result = pickerCaptor.getValue().pickSubchannel(args);
+ SubchannelPicker picker = pickerCaptor.getValue();
+ result = picker.pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue(); // succeed
assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); // with server1
+ assertThat(picker.pickSubchannel(getDefaultPickSubchannelArgsForServer(0))).isEqualTo(result);
+ assertThat(picker.pickSubchannel(getDefaultPickSubchannelArgsForServer(2))).isEqualTo(result);
+ }
+
+ @Test
+ public void removingAddressShutdownSubchannel() {
+ // Map each server address to exactly one ring entry.
+ RingHashConfig config = new RingHashConfig(3, 3);
+ List<EquivalentAddressGroup> svs1 = createWeightedServerAddrs(1, 1, 1);
+ List<Subchannel> subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING);
+
+ List<EquivalentAddressGroup> svs2 = createWeightedServerAddrs(1, 1);
+ InOrder inOrder = Mockito.inOrder(helper, subchannels1.get(2));
+ // send LB the missing address
+ loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(svs2).setLoadBalancingPolicyConfig(config).build());
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any());
+ inOrder.verify(subchannels1.get(2)).shutdown();
}
@Test
@@ -668,12 +543,7 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ initializeLbSubchannels(config, servers);
// Bring all subchannels to TRANSIENT_FAILURE.
for (Subchannel subchannel : subchannels.values()) {
@@ -683,23 +553,16 @@ public class RingHashLoadBalancerTest {
}
verify(helper, atLeastOnce())
.updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
- verifyConnection(3);
+ verifyConnection(0);
// Picking subchannel triggers connection. RPC hash hits server0.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgsForServer(0);
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isFalse();
assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
assertThat(result.getStatus().getDescription())
.isEqualTo("[FakeSocketAddress-server0] unreachable");
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))))
- .requestConnection();
+ verifyConnection(0); // TF has already started taking care of this, pick doesn't need to
}
@Test
@@ -707,31 +570,20 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ initializeLbSubchannels(config, servers);
+ // Go to TF does nothing, though PF will try to reconnect after backoff
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable")));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
- verifyConnection(1);
+ verifyConnection(0);
// Picking subchannel triggers connection. RPC hash hits server0.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))), never())
- .requestConnection();
+ verifyConnection(1);
}
@Test
@@ -739,12 +591,7 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ initializeLbSubchannels(config, servers);
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))),
ConnectivityStateInfo.forNonError(CONNECTING));
@@ -753,9 +600,7 @@ public class RingHashLoadBalancerTest {
verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
// Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
verify(subchannels.get(Collections.singletonList(servers.get(0))), never())
@@ -771,35 +616,30 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ List<Subchannel> subchannelList =
+ initializeLbSubchannels(config, servers, RESET_SUBCHANNEL_MOCKS);
+
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server2_0"
- deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))),
+ deliverSubchannelState(subchannelList.get(0),
ConnectivityStateInfo.forTransientFailure(
Status.UNAVAILABLE.withDescription("unreachable")));
verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
- verifyConnection(1);
+ verifyConnection(0);
- // Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
- PickResult result = pickerCaptor.getValue().pickSubchannel(args);
+ // Per GRFC A61 Picking subchannel should no longer request connections that were failing
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
+ SubchannelPicker picker1 = pickerCaptor.getValue();
+ PickResult result = picker1.pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
- .requestConnection();
+ assertThat(result.getSubchannel()).isNull();
+ verify(subchannelList.get(0), never()).requestConnection(); // In TF
+ verify(subchannelList.get(1)).requestConnection();
+ verify(subchannelList.get(2), never()).requestConnection(); // Not one of the first 2
}
@Test
@@ -807,38 +647,31 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
+
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0)));
- deliverSubchannelState(firstSubchannel,
- ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription(
- firstSubchannel.getAddresses().getAddresses() + "unreachable")));
+ deliverSubchannelUnreachable(firstSubchannel);
+ verifyConnection(0);
+
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))),
ConnectivityStateInfo.forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
- verifyConnection(1);
+ verifyConnection(0);
- // Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ // Picking subchannel when idle triggers connection.
+ deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))),
+ ConnectivityStateInfo.forNonError(IDLE));
+ verifyConnection(0);
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))), never())
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
- .requestConnection();
+ verifyConnection(1);
}
@Test
@@ -846,42 +679,26 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
+
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0)));
- deliverSubchannelState(firstSubchannel,
- ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription(
- firstSubchannel.getAddresses().getAddresses() + " unreachable")));
- deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("unreachable")));
+ deliverSubchannelUnreachable(firstSubchannel);
+ deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2))));
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
- verifyConnection(2);
+ verifyConnection(0);
// Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
- assertThat(result.getStatus().isOk()).isFalse();
- assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
- assertThat(result.getStatus().getDescription())
- .isEqualTo("[FakeSocketAddress-server0] unreachable");
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))))
- .requestConnection();
+ assertThat(result.getStatus().isOk()).isTrue();
+ verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection();
+ verifyConnection(1);
}
@Test
@@ -889,44 +706,29 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
+
// ring:
- // "[FakeSocketAddress-server1]_0"
- // "[FakeSocketAddress-server0]_0"
- // "[FakeSocketAddress-server2]_0"
+ // "FakeSocketAddress-server1_0"
+ // "FakeSocketAddress-server0_0"
+ // "FakeSocketAddress-server2_0"
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0)));
- deliverSubchannelState(firstSubchannel,
- ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription(
- firstSubchannel.getAddresses().getAddresses() + " unreachable")));
- deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))),
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription("unreachable")));
+
+ deliverSubchannelUnreachable(firstSubchannel);
+ deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2))));
deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))),
ConnectivityStateInfo.forNonError(CONNECTING));
- verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
- verifyConnection(2);
+ verify(helper, atLeastOnce())
+ .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
+ verifyConnection(0);
- // Picking subchannel triggers connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ // Picking subchannel should not trigger connection per gRFC A61.
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
- assertThat(result.getStatus().isOk()).isFalse();
- assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE);
- assertThat(result.getStatus().getDescription())
- .isEqualTo("[FakeSocketAddress-server0] unreachable");
- verify(subchannels.get(Collections.singletonList(servers.get(0))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2))))
- .requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
- .requestConnection();
+ assertThat(result.getStatus().isOk()).isTrue();
+ verifyConnection(0);
}
@Test
@@ -934,35 +736,29 @@ public class RingHashLoadBalancerTest {
// Map each server address to exactly one ring entry.
RingHashConfig config = new RingHashConfig(3, 3);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
// Bring one subchannel to TRANSIENT_FAILURE.
Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0)));
- deliverSubchannelState(firstSubchannel,
- ConnectivityStateInfo.forTransientFailure(
- Status.UNAVAILABLE.withDescription(
- firstSubchannel.getAddresses().getAddresses() + " unreachable")));
+ deliverSubchannelUnreachable(firstSubchannel);
- verify(helper).updateBalancingState(eq(CONNECTING), any());
- verifyConnection(1);
+ verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+ verifyConnection(0);
+
+ reset(helper);
deliverSubchannelState(firstSubchannel, ConnectivityStateInfo.forNonError(IDLE));
- verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
+ // Should not have called updateBalancingState on the helper again because PickFirst is
+ // shielding the higher level from the state change.
+ verify(helper, never()).updateBalancingState(any(), any());
verifyConnection(1);
- // Picking subchannel triggers connection. RPC hash hits server0.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ // Picking subchannel triggers connection on second address. RPC hash hits server0.
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isTrue();
- verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection();
- verify(subchannels.get(Collections.singletonList(servers.get(1))), never())
+ verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection();
+ verify(subchannels.get(Collections.singletonList(servers.get(2))), never())
.requestConnection();
}
@@ -971,14 +767,12 @@ public class RingHashLoadBalancerTest {
RingHashConfig config = new RingHashConfig(10000, 100000); // large ring
List<EquivalentAddressGroup> servers =
createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
+
+ initializeLbSubchannels(config, servers);
// Try value between max signed and max unsigned int
servers = createWeightedServerAddrs(Integer.MAX_VALUE + 100L, 100); // (MAX+100):100
- addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
+ Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
ResolvedAddresses.newBuilder()
.setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
assertThat(addressesAcceptanceStatus.isOk()).isTrue();
@@ -1010,12 +804,8 @@ public class RingHashLoadBalancerTest {
public void hostSelectionProportionalToWeights() {
RingHashConfig config = new RingHashConfig(10000, 100000); // large ring
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
- verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+
+ initializeLbSubchannels(config, servers);
// Bring all subchannels to READY.
Map<EquivalentAddressGroup, Integer> pickCounts = new HashMap<>();
@@ -1028,9 +818,7 @@ public class RingHashLoadBalancerTest {
for (int i = 0; i < 10000; i++) {
long hash = hashFunc.hashInt(i);
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hash));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hash);
Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel();
EquivalentAddressGroup addr = pickedSubchannel.getAddresses();
pickCounts.put(addr, pickCounts.get(addr) + 1);
@@ -1059,21 +847,19 @@ public class RingHashLoadBalancerTest {
public void nameResolutionErrorWithActiveSubchannels() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isTrue();
+
+ initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER);
verify(helper).createSubchannel(any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture());
+ verify(helper, times(2)).updateBalancingState(eq(IDLE), pickerCaptor.capture());
// Picking subchannel triggers subchannel creation and connection.
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
pickerCaptor.getValue().pickSubchannel(args);
+ verify(helper, never()).updateBalancingState(eq(READY), any(SubchannelPicker.class));
deliverSubchannelState(
Iterables.getOnlyElement(subchannels.values()), ConnectivityStateInfo.forNonError(READY));
verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class));
+ reset(helper);
loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("target not found"));
verifyNoMoreInteractions(helper);
@@ -1083,15 +869,12 @@ public class RingHashLoadBalancerTest {
public void duplicateAddresses() {
RingHashConfig config = new RingHashConfig(10, 100);
List<EquivalentAddressGroup> servers = createRepeatedServerAddrs(1, 2, 3);
- Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses(
- ResolvedAddresses.newBuilder()
- .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
- assertThat(addressesAcceptanceStatus.isOk()).isFalse();
+
+ initializeLbSubchannels(config, servers, DO_NOT_VERIFY);
+
verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
- PickSubchannelArgs args = new PickSubchannelArgsImpl(
- TestMethodDescriptors.voidMethod(), new Metadata(),
- CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid()));
+ PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid());
PickResult result = pickerCaptor.getValue().pickSubchannel(args);
assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC
assertThat(result.getStatus().getCode())
@@ -1103,8 +886,105 @@ public class RingHashLoadBalancerTest {
assertThat(description).contains("Address: FakeSocketAddress-server2, count: 3");
}
+ private List<Subchannel> initializeLbSubchannels(RingHashConfig config,
+ List<EquivalentAddressGroup> servers, InitializationFlags... initFlags) {
+
+ boolean doVerifies = true;
+ boolean resetSubchannels = false;
+ boolean returnToIdle = true;
+ boolean resetHelper = true;
+ for (InitializationFlags flag : initFlags) {
+ switch (flag) {
+ case DO_NOT_VERIFY:
+ doVerifies = false;
+ break;
+ case RESET_SUBCHANNEL_MOCKS:
+ resetSubchannels = true;
+ break;
+ case STAY_IN_CONNECTING:
+ returnToIdle = false;
+ break;
+ case DO_NOT_RESET_HELPER:
+ resetHelper = false;
+ break;
+ default:
+ throw new IllegalArgumentException("Unrecognized flag: " + flag);
+ }
+ }
+
+ Status addressesAcceptanceStatus =
+ loadBalancer.acceptResolvedAddresses(
+ ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(config).build());
+
+ if (doVerifies) {
+ assertThat(addressesAcceptanceStatus.isOk()).isTrue();
+ verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ }
+
+ if (!addressesAcceptanceStatus.isOk()) {
+ return new ArrayList<>();
+ }
+
+ // Activate them all to create the child LB and subchannel
+ for (ChildLbState childLbState : loadBalancer.getChildLbStates()) {
+ ((RingHashChildLbState)childLbState).activate();
+ }
+
+ if (doVerifies) {
+ verify(helper, times(servers.size())).createSubchannel(any(CreateSubchannelArgs.class));
+ verify(helper, times(servers.size()))
+ .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+ verifyConnection(servers.size());
+ }
+
+ if (returnToIdle) {
+ for (Subchannel subchannel : subchannels.values()) {
+ deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE));
+ }
+ if (doVerifies) {
+ verify(helper, times(2 * servers.size() - 1))
+ .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class));
+ verify(helper, times(2)).updateBalancingState(eq(IDLE), any(SubchannelPicker.class));
+ }
+ }
+
+
+ // Get a list of subchannels in the same order as servers
+ List<Subchannel> subchannelList = new ArrayList<>();
+ for (EquivalentAddressGroup server : servers) {
+ List<EquivalentAddressGroup> singletonList = Collections.singletonList(server);
+ Subchannel subchannel = subchannels.get(singletonList);
+ subchannelList.add(subchannel);
+ if (resetSubchannels) {
+ reset(subchannel);
+ }
+ }
+
+ if (resetHelper) {
+ reset(helper);
+ }
+
+ return subchannelList;
+ }
+
private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo state) {
- subchannelStateListeners.get(subchannel).onSubchannelState(state);
+ testHelperInst.deliverSubchannelState(subchannel, state);
+ }
+
+ private void deliverNotFound(List<Subchannel> subChannelList, int index) {
+ deliverSubchannelState(
+ subChannelList.get(index),
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription("also not found")));
+ }
+
+ protected void deliverSubchannelUnreachable(Subchannel subchannel) {
+ deliverSubchannelState(subchannel,
+ ConnectivityStateInfo.forTransientFailure(
+ Status.UNAVAILABLE.withDescription(
+ subchannel.getAddresses().getAddresses() + "unreachable")));
+
}
private static List<EquivalentAddressGroup> createWeightedServerAddrs(long... weights) {
@@ -1156,4 +1036,51 @@ public class RingHashLoadBalancerTest {
return "FakeSocketAddress-" + name;
}
}
+
+ private class TestHelper extends AbstractTestHelper {
+
+ @Override
+ public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() {
+ return subchannels;
+ }
+
+ @Override
+ public String getAuthority() {
+ return AUTHORITY;
+ }
+
+ @Override
+ public SynchronizationContext getSynchronizationContext() {
+ return syncContext;
+ }
+
+ private Subchannel getMockSubchannel(Subchannel realSubchannel) {
+ return realToMockSubChannelMap.get(realSubchannel);
+ }
+
+ @Override
+ protected AbstractTestHelper.TestSubchannel createRealSubchannel(CreateSubchannelArgs args) {
+ return new RingHashTestSubchannel(args);
+ }
+
+ private class RingHashTestSubchannel extends AbstractTestHelper.TestSubchannel {
+
+ RingHashTestSubchannel(CreateSubchannelArgs args) {
+ super(args);
+ }
+
+ @Override
+ public void requestConnection() {
+ connectionRequestedQueue.offer(getMockSubchannel(this));
+ }
+
+ }
+ }
+
+ enum InitializationFlags {
+ DO_NOT_VERIFY,
+ RESET_SUBCHANNEL_MOCKS,
+ STAY_IN_CONNECTING,
+ DO_NOT_RESET_HELPER
+ }
}
diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
index b6ea8bc04..dfbbf9b08 100644
--- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
+++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java
@@ -17,6 +17,7 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.ConnectivityState.CONNECTING;
import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.eq;
@@ -59,6 +60,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancer
import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker;
import java.net.SocketAddress;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
@@ -78,7 +80,9 @@ import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
+import org.mockito.InOrder;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.junit.MockitoJUnit;
import org.mockito.junit.MockitoRule;
@@ -176,7 +180,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
- .forNonError(ConnectivityState.CONNECTING));
+ .forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
@@ -477,7 +481,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.setAttributes(affinity).build()));
verify(helper, times(6)).createSubchannel(
any(CreateSubchannelArgs.class));
- verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture());
+ verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
assertThat(pickerCaptor.getValue().getClass().getName())
.isEqualTo("io.grpc.util.RoundRobinLoadBalancer$EmptyPicker");
assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1);
@@ -554,7 +558,7 @@ public class WeightedRoundRobinLoadBalancerTest {
.forNonError(ConnectivityState.READY));
Subchannel connectingSubchannel = it.next();
getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo
- .forNonError(ConnectivityState.CONNECTING));
+ .forNonError(CONNECTING));
verify(helper, times(2)).updateBalancingState(
eq(ConnectivityState.READY), pickerCaptor.capture());
assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2);
@@ -1063,6 +1067,24 @@ public class WeightedRoundRobinLoadBalancerTest {
assertThat(sequence.get()).isEqualTo(9);
}
+ @Test
+ public void removingAddressShutsdownSubchannel() {
+ syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
+ .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig)
+ .setAttributes(affinity).build()));
+ final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2)));
+
+ InOrder inOrder = Mockito.inOrder(helper, subchannel2);
+ // send LB only the first 2 addresses
+ List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1));
+ syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder()
+ .setAddresses(svs2).setLoadBalancingPolicyConfig(weightedConfig)
+ .setAttributes(affinity).build()));
+ inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any());
+ inOrder.verify(subchannel2).shutdown();
+ }
+
+
private static final class VerifyingScheduler {
private final StaticStrideScheduler delegate;
private final int max;