diff options
author | Larry Safran <lsafran@google.com> | 2023-11-09 13:46:52 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-09 13:46:52 -0800 |
commit | dfdd50bc7905aa904d19be38f7d300b98c60cf6e (patch) | |
tree | 3ba4e7fa0b84e5975015af98dddaac0c6cf48bd4 | |
parent | 0346b40e4e19e081e082288fc8af49169dd848ad (diff) | |
download | grpc-grpc-java-dfdd50bc7905aa904d19be38f7d300b98c60cf6e.tar.gz |
xds:Make Ring Hash LB a petiole policy (#10610)
* Update picker logic per A61 that it no longer pays attention to the first 2 elements, but rather takes the first ring element not in TF and uses that.
---------
Pulled in by rebase:
Eric Anderson (android: Remove unneeded proguard rule 44723b6)
Terry Wilson (stub: Deprecate StreamObservers b5434e8)
13 files changed, 902 insertions, 871 deletions
diff --git a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java index 13cbeed1d..acef79d3d 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLoadBalancer.java @@ -102,6 +102,7 @@ final class PickFirstLoadBalancer extends LoadBalancer { subchannel.shutdown(); subchannel = null; } + // NB(lukaszx0) Whether we should propagate the error unconditionally is arguable. It's fine // for time being. updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java index eab77d7cd..a07428a30 100644 --- a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java @@ -181,4 +181,8 @@ public final class GracefulSwitchLoadBalancer extends ForwardingLoadBalancer { pendingLb.shutdown(); currentLb.shutdown(); } + + public String delegateType() { + return delegate().getClass().getSimpleName(); + } } diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index a196b3859..306901250 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -26,6 +26,7 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; import io.grpc.Internal; @@ -57,11 +58,9 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { private final Map<Object, ChildLbState> childLbStates = new LinkedHashMap<>(); private final Helper helper; // Set to true if currently in the process of handling resolved addresses. - @VisibleForTesting protected boolean resolvingAddresses; - protected final PickFirstLoadBalancerProvider pickFirstLbProvider = - new PickFirstLoadBalancerProvider(); + protected final LoadBalancerProvider pickFirstLbProvider = new PickFirstLoadBalancerProvider(); protected ConnectivityState currentConnectivityState; @@ -85,6 +84,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { * Generally, the only reason to override this is to expose it to a test of a LB in a different * package. */ + protected ImmutableMap<Object, ChildLbState> getImmutableChildMap() { + return ImmutableMap.copyOf(childLbStates); + } + @VisibleForTesting protected Collection<ChildLbState> getChildLbStates() { return childLbStates.values(); @@ -93,8 +96,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { /** * Generally, the only reason to override this is to expose it to a test of a LB in a * different package. - */ - + */ protected ChildLbState getChildLbState(Object key) { if (key == null) { return null; @@ -125,7 +127,8 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { if (existingChildLbState != null) { childLbMap.put(endpoint, existingChildLbState); } else { - childLbMap.put(endpoint, createChildLbState(endpoint, null, getInitialPicker())); + childLbMap.put(endpoint, + createChildLbState(endpoint, null, getInitialPicker(), resolvedAddresses)); } } return childLbMap; @@ -135,7 +138,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { * Override to create an instance of a subclass. */ protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker) { + SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); } @@ -146,7 +149,20 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { try { resolvingAddresses = true; - return acceptResolvedAddressesInternal(resolvedAddresses); + + // process resolvedAddresses to update children + AcceptResolvedAddressRetVal acceptRetVal = + acceptResolvedAddressesInternal(resolvedAddresses); + if (!acceptRetVal.status.isOk()) { + return acceptRetVal.status; + } + + // Update the picker and our connectivity state + updateOverallBalancingState(); + + // shutdown removed children + shutdownRemoved(acceptRetVal.removedChildren); + return acceptRetVal.status; } finally { resolvingAddresses = false; } @@ -161,15 +177,18 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { */ protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, Object childConfig) { + Endpoint endpointKey; if (key instanceof EquivalentAddressGroup) { - key = new Endpoint((EquivalentAddressGroup) key); + endpointKey = new Endpoint((EquivalentAddressGroup) key); + } else { + checkArgument(key instanceof Endpoint, "key is wrong type"); + endpointKey = (Endpoint) key; } - checkArgument(key instanceof Endpoint, "key is wrong type"); // Retrieve the non-stripped version EquivalentAddressGroup eagToUse = null; for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) { - if (key.equals(new Endpoint(currEag))) { + if (endpointKey.equals(new Endpoint(currEag))) { eagToUse = currEag; break; } @@ -183,15 +202,21 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { .build(); } - private Status acceptResolvedAddressesInternal(ResolvedAddresses resolvedAddresses) { + /** + * This does the work to update the child map and calculate which children have been removed. + * You must call {@link #updateOverallBalancingState} to update the picker + * and call {@link #shutdownRemoved(List)} to shutdown the endpoints that have been removed. + */ + protected AcceptResolvedAddressRetVal acceptResolvedAddressesInternal( + ResolvedAddresses resolvedAddresses) { logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); Map<Object, ChildLbState> newChildren = createChildLbMap(resolvedAddresses); if (newChildren.isEmpty()) { Status unavailableStatus = Status.UNAVAILABLE.withDescription( - "NameResolver returned no usable address. " + resolvedAddresses); + "NameResolver returned no usable address. " + resolvedAddresses); handleNameResolutionError(unavailableStatus); - return unavailableStatus; + return new AcceptResolvedAddressRetVal(unavailableStatus, null); } // Do adds and updates @@ -204,33 +229,44 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { } else { // Reuse the existing one ChildLbState existingChildLbState = childLbStates.get(key); - if (existingChildLbState.isDeactivated()) { + if (existingChildLbState.isDeactivated() && reactivateChildOnReuse()) { existingChildLbState.reactivate(childPolicyProvider); } } - LoadBalancer childLb = childLbStates.get(key).lb; + ChildLbState childLbState = childLbStates.get(key); ResolvedAddresses childAddresses = getChildAddresses(key, resolvedAddresses, childConfig); - childLbStates.get(key).setResolvedAddresses(childAddresses); // update child state - childLb.handleResolvedAddresses(childAddresses); // update child LB + childLbStates.get(key).setResolvedAddresses(childAddresses); // update child + if (!childLbState.deactivated) { + childLbState.lb.handleResolvedAddresses(childAddresses); // update child LB + } } + List<ChildLbState> removedChildren = new ArrayList<>(); // Do removals for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { if (!newChildren.containsKey(key)) { - childLbStates.get(key).deactivate(); + ChildLbState childLbState = childLbStates.get(key); + childLbState.deactivate(); + removedChildren.add(childLbState); } } - // Must update channel picker before return so that new RPCs will not be routed to deleted - // clusters and resolver can remove them in service config. - updateOverallBalancingState(); - return Status.OK; + + return new AcceptResolvedAddressRetVal(Status.OK, removedChildren); + } + + protected void shutdownRemoved(List<ChildLbState> removedChildren) { + // Do shutdowns after updating picker to reduce the chance of failing an RPC by picking a + // subchannel that has been shutdown. + for (ChildLbState childLbState : removedChildren) { + childLbState.shutdown(); + } } @Override public void handleNameResolutionError(Status error) { if (currentConnectivityState != READY) { - updateHelperBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); } } @@ -240,12 +276,22 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { /** * If true, then when a subchannel state changes to idle, the corresponding child will - * have requestConnection called on its LB. + * have requestConnection called on its LB. Also causes the PickFirstLB to be created when + * the child is created or reused. */ protected boolean reconnectOnIdle() { return true; } + /** + * If true, then when {@link #acceptResolvedAddresses} sees a key that was already part of the + * child map which is deactivated, it will call reactivate on the child. + * If false, it will leave it deactivated. + */ + protected boolean reactivateChildOnReuse() { + return true; + } + @Override public void shutdown() { logger.log(Level.INFO, "Shutdown"); @@ -265,17 +311,13 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { childPickers.put(childLbState.key, childLbState.currentPicker); overallState = aggregateState(overallState, childLbState.currentState); } + if (overallState != null) { helper.updateBalancingState(overallState, getSubchannelPicker(childPickers)); currentConnectivityState = overallState; } } - protected final void updateHelperBalancingState(ConnectivityState newState, - SubchannelPicker newPicker) { - helper.updateBalancingState(newState, newPicker); - } - @Nullable protected static ConnectivityState aggregateState( @Nullable ConnectivityState overallState, ConnectivityState childState) { @@ -332,20 +374,31 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { private final Object key; private ResolvedAddresses resolvedAddresses; private final Object config; + private final GracefulSwitchLoadBalancer lb; - private LoadBalancerProvider policyProvider; - private ConnectivityState currentState = CONNECTING; + private final LoadBalancerProvider policyProvider; + private ConnectivityState currentState; private SubchannelPicker currentPicker; private boolean deactivated; public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, SubchannelPicker initialPicker) { + this(key, policyProvider, childConfig, initialPicker, null, false); + } + + public ChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, + SubchannelPicker initialPicker, ResolvedAddresses resolvedAddrs, boolean deactivated) { this.key = key; this.policyProvider = policyProvider; - lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper()); - lb.switchTo(policyProvider); - currentPicker = initialPicker; - config = childConfig; + this.deactivated = deactivated; + this.currentPicker = initialPicker; + this.config = childConfig; + this.lb = new GracefulSwitchLoadBalancer(new ChildLbStateHelper()); + this.currentState = deactivated ? IDLE : CONNECTING; + this.resolvedAddresses = resolvedAddrs; + if (!deactivated) { + lb.switchTo(policyProvider); + } } @Override @@ -365,6 +418,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { return config; } + protected GracefulSwitchLoadBalancer getLb() { + return lb; + } + public LoadBalancerProvider getPolicyProvider() { return policyProvider; } @@ -399,34 +456,41 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { deactivated = true; } + protected void markReactivated() { + deactivated = false; + } + protected void setResolvedAddresses(ResolvedAddresses newAddresses) { checkNotNull(newAddresses, "Missing address list for child"); resolvedAddresses = newAddresses; } + /** + * The default implementation. This not only marks the lb policy as not active, it also removes + * this child from the map of children maintained by the petiole policy. + * + * <p>Note that this does not explicitly shutdown this child. That will generally be done by + * acceptResolvedAddresses on the LB, but can also be handled by an override such as is done + * in <a href=" https://github.com/grpc/grpc-java/blob/master/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java">ClusterManagerLoadBalancer</a>. + * + * <p>If you plan to reactivate, you will probably want to override this to not call + * childLbStates.remove() and handle that cleanup another way. + */ protected void deactivate() { if (deactivated) { return; } - shutdown(); - childLbStates.remove(key); + childLbStates.remove(key); // This means it can't be reactivated again deactivated = true; logger.log(Level.FINE, "Child balancer {0} deactivated", key); } + /** + * This base implementation does nothing but reset the flag. If you really want to both + * deactivate and reactivate you should override them both. + */ protected void reactivate(LoadBalancerProvider policyProvider) { - if (!this.policyProvider.getPolicyName().equals(policyProvider.getPolicyName())) { - Object[] objects = { - key, this.policyProvider.getPolicyName(),policyProvider.getPolicyName()}; - logger.log(Level.FINE, "Child balancer {0} switching policy from {1} to {2}", objects); - lb.switchTo(policyProvider); - this.policyProvider = policyProvider; - } else { - logger.log(Level.FINE, "Child balancer {0} reactivated", key); - lb.acceptResolvedAddresses(resolvedAddresses); - } - deactivated = false; } @@ -443,6 +507,10 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { * <p>The ChildLbState updates happen during updateBalancingState. Otherwise, it is doing * simple forwarding. */ + protected ResolvedAddresses getResolvedAddresses() { + return resolvedAddresses; + } + private final class ChildLbStateHelper extends ForwardingLoadBalancerHelper { @Override @@ -482,7 +550,7 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { final String[] addrs; final int hashCode; - Endpoint(EquivalentAddressGroup eag) { + public Endpoint(EquivalentAddressGroup eag) { checkNotNull(eag, "eag"); addrs = new String[eag.getAddresses().size()]; @@ -525,4 +593,14 @@ public abstract class MultiChildLoadBalancer extends LoadBalancer { return Arrays.toString(addrs); } } + + protected static class AcceptResolvedAddressRetVal { + public final Status status; + public final List<ChildLbState> removedChildren; + + public AcceptResolvedAddressRetVal(Status status, List<ChildLbState> removedChildren) { + this.status = status; + this.removedChildren = removedChildren; + } + } } diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 3bd83ffe1..f4db72849 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -16,6 +16,7 @@ package io.grpc.util; +import static com.google.common.base.Preconditions.checkArgument; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; @@ -25,9 +26,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import com.google.common.base.Preconditions; -import io.grpc.Attributes; import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.Internal; import io.grpc.LoadBalancer; @@ -48,10 +47,6 @@ import javax.annotation.Nonnull; */ @Internal public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { - @VisibleForTesting - static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO = - Attributes.Key.create("state-info"); - private final Random random; protected RoundRobinPicker currentPicker = new EmptyPicker(EMPTY_OK); @@ -132,7 +127,7 @@ public class RoundRobinLoadBalancer extends MultiChildLoadBalancer { private volatile int index; public ReadyPicker(List<SubchannelPicker> list, int startIndex) { - Preconditions.checkArgument(!list.isEmpty(), "empty list"); + checkArgument(!list.isEmpty(), "empty list"); this.subchannelPickers = list; this.index = startIndex - 1; } diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 7513e0da4..465bea907 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -74,6 +74,7 @@ import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -353,6 +354,19 @@ public class RoundRobinLoadBalancerTest { } @Test + public void removingAddressShutsdownSubchannel() { + acceptAddresses(servers, affinity); + final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2))); + + InOrder inOrder = Mockito.inOrder(mockHelper, subchannel2); + // send LB only the first 2 addresses + List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1)); + acceptAddresses(svs2, affinity); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any()); + inOrder.verify(subchannel2).shutdown(); + } + + @Test public void pickerRoundRobin() throws Exception { Subchannel subchannel = mock(Subchannel.class); Subchannel subchannel1 = mock(Subchannel.class); diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java index 2afb13387..2b17558c9 100644 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -52,6 +52,7 @@ import java.util.Map; public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { private final Map<Subchannel, Subchannel> mockToRealSubChannelMap = new HashMap<>(); + protected final Map<Subchannel, Subchannel> realToMockSubChannelMap = new HashMap<>(); private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners = Maps.newLinkedHashMap(); @@ -99,15 +100,20 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { public Subchannel createSubchannel(CreateSubchannelArgs args) { Subchannel subchannel = getSubchannelMap().get(args.getAddresses()); if (subchannel == null) { - TestSubchannel delegate = new TestSubchannel(args); + TestSubchannel delegate = createRealSubchannel(args); subchannel = mock(Subchannel.class, delegatesTo(delegate)); getSubchannelMap().put(args.getAddresses(), subchannel); getMockToRealSubChannelMap().put(subchannel, delegate); + realToMockSubChannelMap.put(delegate, subchannel); } return subchannel; } + protected TestSubchannel createRealSubchannel(CreateSubchannelArgs args) { + return new TestSubchannel(args); + } + @Override public void refreshNameResolution() { // no-op @@ -122,7 +128,7 @@ public abstract class AbstractTestHelper extends ForwardingLoadBalancerHelper { return "Test Helper"; } - private class TestSubchannel extends ForwardingSubchannel { + protected class TestSubchannel extends ForwardingSubchannel { CreateSubchannelArgs args; Channel channel; diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index 62fab0d12..a1dd479af 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -93,6 +93,31 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { return newChildPolicies; } + /** + * This is like the parent except that it doesn't shutdown the removed children since we want that + * to be done by the timer. + */ + @Override + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { + try { + resolvingAddresses = true; + + // process resolvedAddresses to update children + AcceptResolvedAddressRetVal acceptRetVal = + acceptResolvedAddressesInternal(resolvedAddresses); + if (!acceptRetVal.status.isOk()) { + return acceptRetVal.status; + } + + // Update the picker + updateOverallBalancingState(); + + return acceptRetVal.status; + } finally { + resolvingAddresses = false; + } + } + @Override protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) { return new SubchannelPicker() { diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index fe52238e4..127acd9b3 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -145,13 +145,13 @@ final class LeastRequestLoadBalancer extends MultiChildLoadBalancer { @Override protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker) { + SubchannelPicker initialPicker, ResolvedAddresses unused) { return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker); } private void updateBalancingState(ConnectivityState state, LeastRequestPicker picker) { if (state != currentConnectivityState || !picker.isEquivalentTo(currentPicker)) { - super.updateHelperBalancingState(state, picker); + getHelper().updateBalancingState(state, picker); currentConnectivityState = state; currentPicker = picker; } diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 901c79c06..54461385f 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -27,27 +27,28 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multiset; -import com.google.common.collect.Sets; import com.google.common.primitives.UnsignedInteger; import io.grpc.Attributes; import io.grpc.ConnectivityState; -import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; +import io.grpc.LoadBalancerProvider; import io.grpc.Status; import io.grpc.SynchronizationContext; +import io.grpc.util.GracefulSwitchLoadBalancer; +import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.XdsLogger.XdsLogLevel; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Random; import java.util.Set; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -60,9 +61,7 @@ import javax.annotation.Nullable; * number of times proportional to its weight. With the ring partitioned appropriately, the * addition or removal of one host from a set of N hosts will affect only 1/N requests. */ -final class RingHashLoadBalancer extends LoadBalancer { - private static final Attributes.Key<Ref<ConnectivityStateInfo>> STATE_INFO = - Attributes.Key.create("state-info"); +final class RingHashLoadBalancer extends MultiChildLoadBalancer { private static final Status RPC_HASH_NOT_FOUND = Status.INTERNAL.withDescription("RPC hash not found. Probably a bug because xds resolver" + " config selector always generates a hash."); @@ -70,16 +69,10 @@ final class RingHashLoadBalancer extends LoadBalancer { private final XdsLogger logger; private final SynchronizationContext syncContext; - private final Map<EquivalentAddressGroup, Subchannel> subchannels = new HashMap<>(); - private final Helper helper; - private List<RingEntry> ring; - private ConnectivityState currentState; - private Iterator<Subchannel> connectionAttemptIterator = subchannels.values().iterator(); - private final Random random = new Random(); RingHashLoadBalancer(Helper helper) { - this.helper = checkNotNull(helper, "helper"); + super(helper); syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); logger = XdsLogger.withLogId(InternalLogId.allocate("ring_hash_lb", helper.getAuthority())); logger.log(XdsLogLevel.INFO, "Created"); @@ -94,81 +87,157 @@ final class RingHashLoadBalancer extends LoadBalancer { return addressValidityStatus; } - Map<EquivalentAddressGroup, EquivalentAddressGroup> latestAddrs = stripAttrs(addrList); - Set<EquivalentAddressGroup> removedAddrs = - Sets.newHashSet(Sets.difference(subchannels.keySet(), latestAddrs.keySet())); - - RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>(); - long totalWeight = 0L; - for (EquivalentAddressGroup eag : addrList) { - Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); - // Support two ways of server weighing: either multiple instances of the same address - // or each address contains a per-address weight attribute. If a weight is not provided, - // each occurrence of the address will be counted a weight value of one. - if (weight == null) { - weight = 1L; - } - totalWeight += weight; - EquivalentAddressGroup addrKey = stripAttrs(eag); - if (serverWeights.containsKey(addrKey)) { - serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); - } else { - serverWeights.put(addrKey, weight); + AcceptResolvedAddressRetVal acceptRetVal; + try { + resolvingAddresses = true; + // Update the child list by creating-adding, updating addresses, and removing + acceptRetVal = super.acceptResolvedAddressesInternal(resolvedAddresses); + if (!acceptRetVal.status.isOk()) { + addressValidityStatus = Status.UNAVAILABLE.withDescription( + "Ring hash lb error: EDS resolution was successful, but was not accepted by base class" + + " (" + acceptRetVal.status + ")"); + handleNameResolutionError(addressValidityStatus); + return addressValidityStatus; } - Subchannel existingSubchannel = subchannels.get(addrKey); - if (existingSubchannel != null) { - existingSubchannel.updateAddresses(Collections.singletonList(eag)); - continue; + // Now do the ringhash specific logic with weights and building the ring + RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); + if (config == null) { + throw new IllegalArgumentException("Missing RingHash configuration"); } - Attributes attr = Attributes.newBuilder().set( - STATE_INFO, new Ref<>(ConnectivityStateInfo.forNonError(IDLE))).build(); - final Subchannel subchannel = helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(eag).setAttributes(attr).build()); - subchannel.start(new SubchannelStateListener() { - @Override - public void onSubchannelState(ConnectivityStateInfo newState) { - processSubchannelState(subchannel, newState); + Map<EquivalentAddressGroup, Long> serverWeights = new HashMap<>(); + long totalWeight = 0L; + for (EquivalentAddressGroup eag : addrList) { + Long weight = eag.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT); + // Support two ways of server weighing: either multiple instances of the same address + // or each address contains a per-address weight attribute. If a weight is not provided, + // each occurrence of the address will be counted a weight value of one. + if (weight == null) { + weight = 1L; + } + totalWeight += weight; + EquivalentAddressGroup addrKey = stripAttrs(eag); + if (serverWeights.containsKey(addrKey)) { + serverWeights.put(addrKey, serverWeights.get(addrKey) + weight); + } else { + serverWeights.put(addrKey, weight); } - }); - subchannels.put(addrKey, subchannel); + } + // Calculate scale + long minWeight = Collections.min(serverWeights.values()); + double normalizedMinWeight = (double) minWeight / totalWeight; + // Scale up the number of hashes per host such that the least-weighted host gets a whole + // number of hashes on the the ring. Other hosts might not end up with whole numbers, and + // that's fine (the ring-building algorithm can handle this). This preserves the original + // implementation's behavior: when weights aren't provided, all hosts should get an equal + // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled + // back down to fit. + double scale = Math.min( + Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, + (double) config.maxRingSize); + + // Build the ring + ring = buildRing(serverWeights, totalWeight, scale); + + // Must update channel picker before return so that new RPCs will not be routed to deleted + // clusters and resolver can remove them in service config. + updateOverallBalancingState(); + + shutdownRemoved(acceptRetVal.removedChildren); + } finally { + this.resolvingAddresses = false; } - long minWeight = Collections.min(serverWeights.values()); - double normalizedMinWeight = (double) minWeight / totalWeight; - // Scale up the number of hashes per host such that the least-weighted host gets a whole - // number of hashes on the the ring. Other hosts might not end up with whole numbers, and - // that's fine (the ring-building algorithm can handle this). This preserves the original - // implementation's behavior: when weights aren't provided, all hosts should get an equal - // number of hashes. In the case where this number exceeds the max_ring_size, it's scaled - // back down to fit. - double scale = Math.min( - Math.ceil(normalizedMinWeight * config.minRingSize) / normalizedMinWeight, - (double) config.maxRingSize); - ring = buildRing(serverWeights, totalWeight, scale); - - // Shut down subchannels for delisted addresses. - List<Subchannel> removedSubchannels = new ArrayList<>(); - for (EquivalentAddressGroup addr : removedAddrs) { - removedSubchannels.add(subchannels.remove(addr)); + + return Status.OK; + } + + /** + * Updates the overall balancing state by aggregating the connectivity states of all subchannels. + * + * <p>Aggregation rules (in order of dominance): + * <ol> + * <li>If there is at least one subchannel in READY state, overall state is READY</li> + * <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is + * TRANSIENT_FAILURE (to allow timely failover to another policy)</li> + * <li>If there is at least one subchannel in CONNECTING state, overall state is + * CONNECTING</li> + * <li> If there is one subchannel in TRANSIENT_FAILURE state and there is + * more than one subchannel, report CONNECTING </li> + * <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li> + * <li>Otherwise, overall state is TRANSIENT_FAILURE</li> + * </ol> + */ + @Override + protected void updateOverallBalancingState() { + checkState(!getChildLbStates().isEmpty(), "no subchannel has been created"); + if (this.currentConnectivityState == SHUTDOWN) { + // Ignore changes that happen after shutdown is called + logger.log(XdsLogLevel.DEBUG, "UpdateOverallBalancingState called after shutdown"); + return; } - // If we need to proactively start connecting, iterate through all the subchannels, starting - // at a random position. - // Alternatively, we should better start at the same position. - connectionAttemptIterator = subchannels.values().iterator(); - int randomAdvance = random.nextInt(subchannels.size()); - while (randomAdvance-- > 0) { - connectionAttemptIterator.next(); + + // Calculate the current overall state to report + int numIdle = 0; + int numReady = 0; + int numConnecting = 0; + int numTF = 0; + + forloop: + for (ChildLbState childLbState : getChildLbStates()) { + ConnectivityState state = childLbState.getCurrentState(); + switch (state) { + case READY: + numReady++; + break forloop; + case CONNECTING: + numConnecting++; + break; + case IDLE: + numIdle++; + break; + case TRANSIENT_FAILURE: + numTF++; + break; + default: + // ignore it + } } - // Update the picker before shutting down the subchannels, to reduce the chance of race - // between picking a subchannel and shutting it down. - updateBalancingState(); - for (Subchannel subchann : removedSubchannels) { - shutdownSubchannel(subchann); + ConnectivityState overallState; + if (numReady > 0) { + overallState = READY; + } else if (numTF >= 2) { + overallState = TRANSIENT_FAILURE; + } else if (numConnecting > 0) { + overallState = CONNECTING; + } else if (numTF == 1 && getChildLbStates().size() > 1) { + overallState = CONNECTING; + } else if (numIdle > 0) { + overallState = IDLE; + } else { + overallState = TRANSIENT_FAILURE; } - return Status.OK; + RingHashPicker picker = new RingHashPicker(syncContext, ring, getImmutableChildMap()); + getHelper().updateBalancingState(overallState, picker); + this.currentConnectivityState = overallState; + } + + @Override + protected boolean reconnectOnIdle() { + return false; + } + + @Override + protected boolean reactivateChildOnReuse() { + return false; + } + + @Override + protected ChildLbState createChildLbState(Object key, Object policyConfig, + SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { + return new RingHashChildLbState((Endpoint)key, + getChildAddresses(key, resolvedAddresses, null)); } private Status validateAddrList(List<EquivalentAddressGroup> addrList) { @@ -197,7 +266,7 @@ final class RingHashLoadBalancer extends LoadBalancer { if (weight < 0) { Status unavailableStatus = Status.UNAVAILABLE.withDescription( - String.format("Ring hash lb error: EDS resolution was successful, but returned a " + String.format("Ring hash lb error: EDS resolution was successful, but returned a " + "negative weight for %s.", stripAttrs(eag))); handleNameResolutionError(unavailableStatus); return unavailableStatus; @@ -252,10 +321,10 @@ final class RingHashLoadBalancer extends LoadBalancer { double currentHashes = 0.0; double targetHashes = 0.0; for (Map.Entry<EquivalentAddressGroup, Long> entry : serverWeights.entrySet()) { - EquivalentAddressGroup addrKey = entry.getKey(); + Endpoint endpoint = new Endpoint(entry.getKey()); double normalizedWeight = (double) entry.getValue() / totalWeight; - // TODO(chengyuanzhang): is using the list of socket address correct? - StringBuilder sb = new StringBuilder(addrKey.getAddresses().toString()); + // Per GRFC A61 use the first address for the hash + StringBuilder sb = new StringBuilder(entry.getKey().getAddresses().get(0).toString()); sb.append('_'); int lengthWithoutCounter = sb.length(); targetHashes += scale * normalizedWeight; @@ -263,7 +332,7 @@ final class RingHashLoadBalancer extends LoadBalancer { while (currentHashes < targetHashes) { sb.append(i); long hash = hashFunc.hashAsciiString(sb.toString()); - ring.add(new RingEntry(hash, addrKey)); + ring.add(new RingEntry(hash, endpoint)); i++; currentHashes++; sb.setLength(lengthWithoutCounter); @@ -273,159 +342,14 @@ final class RingHashLoadBalancer extends LoadBalancer { return Collections.unmodifiableList(ring); } - @Override - public void handleNameResolutionError(Status error) { - if (currentState != READY) { - helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); - } - } - - @Override - public void shutdown() { - logger.log(XdsLogLevel.INFO, "Shutdown"); - for (Subchannel subchannel : subchannels.values()) { - shutdownSubchannel(subchannel); + @SuppressWarnings("ReferenceEquality") + public static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { + if (eag.getAttributes() == Attributes.EMPTY) { + return eag; } - subchannels.clear(); - } - - /** - * Updates the overall balancing state by aggregating the connectivity states of all subchannels. - * - * <p>Aggregation rules (in order of dominance): - * <ol> - * <li>If there is at least one subchannel in READY state, overall state is READY</li> - * <li>If there are <em>2 or more</em> subchannels in TRANSIENT_FAILURE, overall state is - * TRANSIENT_FAILURE</li> - * <li>If there is at least one subchannel in CONNECTING state, overall state is - * CONNECTING</li> - * <li> If there is one subchannel in TRANSIENT_FAILURE state and there is - * more than one subchannel, report CONNECTING </li> - * <li>If there is at least one subchannel in IDLE state, overall state is IDLE</li> - * <li>Otherwise, overall state is TRANSIENT_FAILURE</li> - * </ol> - */ - private void updateBalancingState() { - checkState(!subchannels.isEmpty(), "no subchannel has been created"); - boolean startConnectionAttempt = false; - int numIdle = 0; - int numReady = 0; - int numConnecting = 0; - int numTransientFailure = 0; - for (Subchannel subchannel : subchannels.values()) { - ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState(); - if (state == READY) { - numReady++; - break; - } else if (state == TRANSIENT_FAILURE) { - numTransientFailure++; - } else if (state == CONNECTING ) { - numConnecting++; - } else if (state == IDLE) { - numIdle++; - } - } - ConnectivityState overallState; - if (numReady > 0) { - overallState = READY; - } else if (numTransientFailure >= 2) { - overallState = TRANSIENT_FAILURE; - startConnectionAttempt = (numConnecting == 0); - } else if (numConnecting > 0) { - overallState = CONNECTING; - } else if (numTransientFailure == 1 && subchannels.size() > 1) { - overallState = CONNECTING; - startConnectionAttempt = true; - } else if (numIdle > 0) { - overallState = IDLE; - } else { - overallState = TRANSIENT_FAILURE; - startConnectionAttempt = true; - } - RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels); - // TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates - helper.updateBalancingState(overallState, picker); - currentState = overallState; - // While the ring_hash policy is reporting TRANSIENT_FAILURE, it will - // not be getting any pick requests from the priority policy. - // However, because the ring_hash policy does not attempt to - // reconnect to subchannels unless it is getting pick requests, - // it will need special handling to ensure that it will eventually - // recover from TRANSIENT_FAILURE state once the problem is resolved. - // Specifically, it will make sure that it is attempting to connect to - // at least one subchannel at any given time. After a given subchannel - // fails a connection attempt, it will move on to the next subchannel - // in the ring. It will keep doing this until one of the subchannels - // successfully connects, at which point it will report READY and stop - // proactively trying to connect. The policy will remain in - // TRANSIENT_FAILURE until at least one subchannel becomes connected, - // even if subchannels are in state CONNECTING during that time. - // - // Note that we do the same thing when the policy is in state - // CONNECTING, just to ensure that we don't remain in CONNECTING state - // indefinitely if there are no new picks coming in. - if (startConnectionAttempt) { - if (!connectionAttemptIterator.hasNext()) { - connectionAttemptIterator = subchannels.values().iterator(); - } - connectionAttemptIterator.next().requestConnection(); - } - } - - private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { - return; - } - if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { - helper.refreshNameResolution(); - } - updateConnectivityState(subchannel, stateInfo); - updateBalancingState(); - } - - private void updateConnectivityState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - Ref<ConnectivityStateInfo> subchannelStateRef = getSubchannelStateInfoRef(subchannel); - ConnectivityState previousConnectivityState = subchannelStateRef.value.getState(); - // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected. - // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in - // TRANSIENT_FAILURE until it becomes READY. - if (previousConnectivityState == TRANSIENT_FAILURE) { - if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { - return; - } - } - subchannelStateRef.value = stateInfo; - } - - private static void shutdownSubchannel(Subchannel subchannel) { - subchannel.shutdown(); - getSubchannelStateInfoRef(subchannel).value = ConnectivityStateInfo.forNonError(SHUTDOWN); - } - - /** - * Converts list of {@link EquivalentAddressGroup} to {@link EquivalentAddressGroup} set and - * remove all attributes. The values are the original EAGs. - */ - private static Map<EquivalentAddressGroup, EquivalentAddressGroup> stripAttrs( - List<EquivalentAddressGroup> groupList) { - Map<EquivalentAddressGroup, EquivalentAddressGroup> addrs = - new HashMap<>(groupList.size() * 2); - for (EquivalentAddressGroup group : groupList) { - addrs.put(stripAttrs(group), group); - } - return addrs; - } - - private static EquivalentAddressGroup stripAttrs(EquivalentAddressGroup eag) { return new EquivalentAddressGroup(eag.getAddresses()); } - private static Ref<ConnectivityStateInfo> getSubchannelStateInfoRef( - Subchannel subchannel) { - return checkNotNull(subchannel.getAttributes().get(STATE_INFO), "STATE_INFO"); - } - private static final class RingHashPicker extends SubchannelPicker { private final SynchronizationContext syncContext; private final List<RingEntry> ring; @@ -433,38 +357,31 @@ final class RingHashLoadBalancer extends LoadBalancer { // freeze picker's view of subchannel's connectivity state. // TODO(chengyuanzhang): can be more performance-friendly with // IdentityHashMap<Subchannel, ConnectivityStateInfo> and RingEntry contains Subchannel. - private final Map<EquivalentAddressGroup, SubchannelView> pickableSubchannels; // read-only + private final Map<Endpoint, SubchannelView> pickableSubchannels; // read-only private RingHashPicker( SynchronizationContext syncContext, List<RingEntry> ring, - Map<EquivalentAddressGroup, Subchannel> subchannels) { + ImmutableMap<Object, ChildLbState> subchannels) { this.syncContext = syncContext; this.ring = ring; pickableSubchannels = new HashMap<>(subchannels.size()); - for (Map.Entry<EquivalentAddressGroup, Subchannel> entry : subchannels.entrySet()) { - Subchannel subchannel = entry.getValue(); - ConnectivityStateInfo stateInfo = subchannel.getAttributes().get(STATE_INFO).value; - pickableSubchannels.put(entry.getKey(), new SubchannelView(subchannel, stateInfo)); + for (Map.Entry<Object, ChildLbState> entry : subchannels.entrySet()) { + RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue(); + pickableSubchannels.put((Endpoint)entry.getKey(), + new SubchannelView(childLbState, childLbState.getCurrentState())); } } - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); - if (requestHash == null) { - return PickResult.withError(RPC_HASH_NOT_FOUND); + // Find the ring entry with hash next to (clockwise) the RPC's hash (binary search). + private int getTargetIndex(Long requestHash) { + if (ring.size() <= 1) { + return 0; } - // Find the ring entry with hash next to (clockwise) the RPC's hash. int low = 0; - int high = ring.size(); - int mid; - while (true) { - mid = (low + high) / 2; - if (mid == ring.size()) { - mid = 0; - break; - } + int high = ring.size() - 1; + int mid = (low + high) / 2; + do { long midVal = ring.get(mid).hash; long midValL = mid == 0 ? 0 : ring.get(mid - 1).hash; if (requestHash <= midVal && requestHash > midValL) { @@ -475,79 +392,61 @@ final class RingHashLoadBalancer extends LoadBalancer { } else { high = mid - 1; } - if (low > high) { - mid = 0; - break; - } + mid = (low + high) / 2; + } while (mid < ring.size() && low <= high); + return mid; + } + + @Override + public PickResult pickSubchannel(PickSubchannelArgs args) { + Long requestHash = args.getCallOptions().getOption(XdsNameResolver.RPC_HASH_KEY); + if (requestHash == null) { + return PickResult.withError(RPC_HASH_NOT_FOUND); } - // Try finding a READY subchannel. Starting from the ring entry next to the RPC's hash. - // If the one of the first two subchannels is not in TRANSIENT_FAILURE, return result - // based on that subchannel. Otherwise, fail the pick unless a READY subchannel is found. - // Meanwhile, trigger connection for the channel and status: - // For the first subchannel that is in IDLE or TRANSIENT_FAILURE; - // And for the second subchannel that is in IDLE or TRANSIENT_FAILURE; - // And for each of the following subchannels that is in TRANSIENT_FAILURE or IDLE, - // stop until we find the first subchannel that is in CONNECTING or IDLE status. - boolean foundFirstNonFailed = false; // true if having subchannel(s) in CONNECTING or IDLE - Subchannel firstSubchannel = null; - Subchannel secondSubchannel = null; + int targetIndex = getTargetIndex(requestHash); + + // Per gRFC A61, because of sticky-TF with PickFirst's auto reconnect on TF, we ignore + // all TF subchannels and find the first ring entry in READY, CONNECTING or IDLE. If + // CONNECTING or IDLE we return a pick with no results. Additionally, if that entry is in + // IDLE, we initiate a connection. for (int i = 0; i < ring.size(); i++) { - int index = (mid + i) % ring.size(); - EquivalentAddressGroup addrKey = ring.get(index).addrKey; - SubchannelView subchannel = pickableSubchannels.get(addrKey); - if (subchannel.stateInfo.getState() == READY) { - return PickResult.withSubchannel(subchannel.subchannel); + int index = (targetIndex + i) % ring.size(); + SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); + RingHashChildLbState childLbState = subchannelView.childLbState; + + if (subchannelView.connectivityState == READY) { + return childLbState.getCurrentPicker().pickSubchannel(args); } - // RPCs can be buffered if any of the first two subchannels is pending. Otherwise, RPCs + // RPCs can be buffered if the next subchannel is pending (per A62). Otherwise, RPCs // are failed unless there is a READY connection. - if (firstSubchannel == null) { - firstSubchannel = subchannel.subchannel; - PickResult maybeBuffer = pickSubchannelsNonReady(subchannel); - if (maybeBuffer != null) { - return maybeBuffer; - } - } else if (subchannel.subchannel != firstSubchannel && secondSubchannel == null) { - secondSubchannel = subchannel.subchannel; - PickResult maybeBuffer = pickSubchannelsNonReady(subchannel); - if (maybeBuffer != null) { - return maybeBuffer; - } - } else if (subchannel.subchannel != firstSubchannel - && subchannel.subchannel != secondSubchannel) { - if (!foundFirstNonFailed) { - pickSubchannelsNonReady(subchannel); - if (subchannel.stateInfo.getState() != TRANSIENT_FAILURE) { - foundFirstNonFailed = true; - } - } + if (subchannelView.connectivityState == CONNECTING) { + return PickResult.withNoResult(); } - } - // Fail the pick with error status of the original subchannel hit by hash. - SubchannelView originalSubchannel = pickableSubchannels.get(ring.get(mid).addrKey); - return PickResult.withError(originalSubchannel.stateInfo.getStatus()); - } - @Nullable - private PickResult pickSubchannelsNonReady(SubchannelView subchannel) { - if (subchannel.stateInfo.getState() == TRANSIENT_FAILURE - || subchannel.stateInfo.getState() == IDLE ) { - final Subchannel finalSubchannel = subchannel.subchannel; - syncContext.execute(new Runnable() { - @Override - public void run() { - finalSubchannel.requestConnection(); + if (subchannelView.connectivityState == IDLE || childLbState.isDeactivated()) { + if (childLbState.isDeactivated()) { + childLbState.activate(); + } else { + syncContext.execute(() -> childLbState.getLb().requestConnection()); } - }); - } - if (subchannel.stateInfo.getState() == CONNECTING - || subchannel.stateInfo.getState() == IDLE) { - return PickResult.withNoResult(); - } else { - return null; + + return PickResult.withNoResult(); // Indicates that this should be retried after backoff + } } + + // return the pick from the original subchannel hit by hash, which is probably an error + RingHashChildLbState originalSubchannel = + pickableSubchannels.get(ring.get(targetIndex).addrKey).childLbState; + return originalSubchannel.getCurrentPicker().pickSubchannel(args); } + + } + + @Override + protected SubchannelPicker getSubchannelPicker(Map<Object, SubchannelPicker> childPickers) { + throw new UnsupportedOperationException("Not used by RingHash"); } /** @@ -555,20 +454,20 @@ final class RingHashLoadBalancer extends LoadBalancer { * state changes. */ private static final class SubchannelView { - private final Subchannel subchannel; - private final ConnectivityStateInfo stateInfo; + private final RingHashChildLbState childLbState; + private final ConnectivityState connectivityState; - private SubchannelView(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - this.subchannel = subchannel; - this.stateInfo = stateInfo; + private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) { + this.childLbState = childLbState; + this.connectivityState = state; } } private static final class RingEntry implements Comparable<RingEntry> { private final long hash; - private final EquivalentAddressGroup addrKey; + private final Endpoint addrKey; - private RingEntry(long hash, EquivalentAddressGroup addrKey) { + private RingEntry(long hash, Endpoint addrKey) { this.hash = hash; this.addrKey = addrKey; } @@ -580,17 +479,6 @@ final class RingHashLoadBalancer extends LoadBalancer { } /** - * A lighter weight Reference than AtomicReference. - */ - private static final class Ref<T> { - T value; - - Ref(T value) { - this.value = value; - } - } - - /** * Configures the ring property. The larger the ring is (that is, the more hashes there are * for each provided host) the better the request distribution will reflect the desired weights. */ @@ -614,4 +502,58 @@ final class RingHashLoadBalancer extends LoadBalancer { .toString(); } } -} + + static Set<EquivalentAddressGroup> getStrippedChildEags(Collection<ChildLbState> states) { + return states.stream() + .map(ChildLbState::getEag) + .map(RingHashLoadBalancer::stripAttrs) + .collect(Collectors.toSet()); + } + + @Override + protected Collection<ChildLbState> getChildLbStates() { + return super.getChildLbStates(); + } + + @Override + protected ChildLbState getChildLbStateEag(EquivalentAddressGroup eag) { + return super.getChildLbStateEag(eag); + } + + class RingHashChildLbState extends MultiChildLoadBalancer.ChildLbState { + + public RingHashChildLbState(Endpoint key, ResolvedAddresses resolvedAddresses) { + super(key, pickFirstLbProvider, null, EMPTY_PICKER, resolvedAddresses, true); + } + + @Override + protected void reactivate(LoadBalancerProvider policyProvider) { + if (!isDeactivated()) { + return; + } + + currentConnectivityState = CONNECTING; + getLb().switchTo(pickFirstLbProvider); + markReactivated(); + getLb().acceptResolvedAddresses(this.getResolvedAddresses()); // Time to get a subchannel + logger.log(XdsLogLevel.DEBUG, "Child balancer {0} reactivated", getKey()); + } + + public void activate() { + reactivate(pickFirstLbProvider); + } + + // Need to expose this to the LB class + @Override + protected void shutdown() { + super.shutdown(); + } + + // Need to expose this to the LB class + @Override + protected GracefulSwitchLoadBalancer getLb() { + return super.getLb(); + } + + } +}
\ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index a71c91deb..4e77e99a3 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -97,7 +97,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { @Override protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker) { + SubchannelPicker initialPicker, ResolvedAddresses unused) { ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); return childLbState; @@ -115,13 +115,31 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } config = (WeightedRoundRobinLoadBalancerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Status addressAcceptanceStatus = super.acceptResolvedAddresses(resolvedAddresses); - if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { - weightUpdateTimer.cancel(); + AcceptResolvedAddressRetVal acceptRetVal; + try { + resolvingAddresses = true; + acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); + if (!acceptRetVal.status.isOk()) { + return acceptRetVal.status; + } + + if (weightUpdateTimer != null && weightUpdateTimer.isPending()) { + weightUpdateTimer.cancel(); + } + updateWeightTask.run(); + + createAndApplyOrcaListeners(); + + // Must update channel picker before return so that new RPCs will not be routed to deleted + // clusters and resolver can remove them in service config. + updateOverallBalancingState(); + + shutdownRemoved(acceptRetVal.removedChildren); + } finally { + resolvingAddresses = false; } - updateWeightTask.run(); - afterAcceptAddresses(); - return addressAcceptanceStatus; + + return acceptRetVal.status; } @Override @@ -228,7 +246,7 @@ final class WeightedRoundRobinLoadBalancer extends RoundRobinLoadBalancer { } } - private void afterAcceptAddresses() { + private void createAndApplyOrcaListeners() { for (ChildLbState child : getChildLbStates()) { WeightedChildLbState wChild = (WeightedChildLbState) child; for (WrrSubchannel weightedSubchannel : wChild.subchannels) { diff --git a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java index 5f5e0df4d..a2dc7936b 100644 --- a/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/LeastRequestLoadBalancerTest.java @@ -106,7 +106,7 @@ public class LeastRequestLoadBalancerTest { @Captor private ArgumentCaptor<CreateSubchannelArgs> createArgsCaptor; private final TestHelper testHelperInstance = new TestHelper(); - private Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); + private final Helper helper = mock(Helper.class, delegatesTo(testHelperInstance)); @Mock private ThreadSafeRandom mockRandom; @@ -522,7 +522,6 @@ public class LeastRequestLoadBalancerTest { loadBalancer.handleNameResolutionError(error); loadBalancer.setResolvingAddresses(false); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs); assertNull(pickResult.getSubchannel()); assertEquals(error, pickResult.getStatus()); diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index f2b38abda..337737265 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -22,18 +22,21 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; +import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_RESET_HELPER; +import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.DO_NOT_VERIFY; +import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.RESET_SUBCHANNEL_MOCKS; +import static io.grpc.xds.RingHashLoadBalancerTest.InitializationFlags.STAY_IN_CONNECTING; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.clearInvocations; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; -import static org.mockito.Mockito.when; import com.google.common.collect.Iterables; import com.google.common.primitives.UnsignedInteger; @@ -48,13 +51,15 @@ import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; -import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.Metadata; import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.SynchronizationContext; import io.grpc.internal.PickSubchannelArgsImpl; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.util.AbstractTestHelper; +import io.grpc.util.MultiChildLoadBalancer.ChildLbState; +import io.grpc.xds.RingHashLoadBalancer.RingHashChildLbState; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; @@ -74,12 +79,9 @@ import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.InOrder; -import org.mockito.Mock; import org.mockito.Mockito; -import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import org.mockito.stubbing.Answer; /** Unit test for {@link io.grpc.LoadBalancer}. */ @RunWith(JUnit4.class) @@ -97,49 +99,18 @@ public class RingHashLoadBalancerTest { } }); private final Map<List<EquivalentAddressGroup>, Subchannel> subchannels = new HashMap<>(); - private final Map<Subchannel, SubchannelStateListener> subchannelStateListeners = - new HashMap<>(); private final Deque<Subchannel> connectionRequestedQueue = new ArrayDeque<>(); private final XxHash64 hashFunc = XxHash64.INSTANCE; - @Mock - private Helper helper; + private final TestHelper testHelperInst = new TestHelper(); + private final Helper helper = mock(Helper.class, delegatesTo(testHelperInst)); @Captor private ArgumentCaptor<SubchannelPicker> pickerCaptor; private RingHashLoadBalancer loadBalancer; @Before public void setUp() { - when(helper.getAuthority()).thenReturn(AUTHORITY); - when(helper.getSynchronizationContext()).thenReturn(syncContext); - when(helper.createSubchannel(any(CreateSubchannelArgs.class))).thenAnswer( - new Answer<Subchannel>() { - @Override - public Subchannel answer(InvocationOnMock invocation) throws Throwable { - CreateSubchannelArgs args = (CreateSubchannelArgs) invocation.getArguments()[0]; - final Subchannel subchannel = mock(Subchannel.class); - when(subchannel.getAllAddresses()).thenReturn(args.getAddresses()); - when(subchannel.getAttributes()).thenReturn(args.getAttributes()); - subchannels.put(args.getAddresses(), subchannel); - doAnswer(new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - subchannelStateListeners.put( - subchannel, (SubchannelStateListener) invocation.getArguments()[0]); - return null; - } - }).when(subchannel).start(any(SubchannelStateListener.class)); - doAnswer(new Answer<Void>() { - @Override - public Void answer(InvocationOnMock invocation) throws Throwable { - connectionRequestedQueue.offer(subchannel); - return null; - } - }).when(subchannel).requestConnection(); - return subchannel; - } - }); loadBalancer = new RingHashLoadBalancer(helper); - // Skip uninterested interactions. + // Consume calls not relevant for tests that would otherwise fail verifyNoMoreInteractions verify(helper).getAuthority(); verify(helper).getSynchronizationContext(); } @@ -161,21 +132,20 @@ public class RingHashLoadBalancerTest { ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); - verify(subchannel, never()).requestConnection(); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + assertThat(subchannels.size()).isEqualTo(0); // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); + Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); verify(subchannel).requestConnection(); - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); + verify(helper, times(2)).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); // Subchannel becomes ready, triggers pick again. deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); @@ -193,16 +163,19 @@ public class RingHashLoadBalancerTest { ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - Subchannel subchannel = Iterables.getOnlyElement(subchannels.values()); - InOrder inOrder = Mockito.inOrder(helper, subchannel); - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - inOrder.verify(subchannel, never()).requestConnection(); + verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + + RingHashChildLbState childLbState = + (RingHashChildLbState) loadBalancer.getChildLbStates().iterator().next(); + assertThat(childLbState.isDeactivated()).isTrue(); // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); pickerCaptor.getValue().pickSubchannel(args); + assertThat(childLbState.isDeactivated()).isFalse(); + assertThat(childLbState.getLb().delegateType()).isEqualTo("PickFirstLoadBalancer"); + Subchannel subchannel = subchannels.get(Collections.singletonList(childLbState.getEag())); + InOrder inOrder = Mockito.inOrder(helper, subchannel); inOrder.verify(subchannel).requestConnection(); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); @@ -220,12 +193,8 @@ public class RingHashLoadBalancerTest { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1); InOrder inOrder = Mockito.inOrder(helper); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); // one in CONNECTING, one in IDLE deliverSubchannelState( @@ -263,7 +232,7 @@ public class RingHashLoadBalancerTest { ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verifyConnection(1); + verifyConnection(0); verifyNoMoreInteractions(helper); } @@ -278,164 +247,57 @@ public class RingHashLoadBalancerTest { } @Test - public void aggregateSubchannelStates_twoOrMoreSubchannelsInTransientFailure() { + public void aggregateSubchannelStates_allSubchannelsInTransientFailure() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1, 1); - InOrder inOrder = Mockito.inOrder(helper); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); - // one in TRANSIENT_FAILURE, three in IDLE - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("not found"))); + List<Subchannel> subChannelList = initializeLbSubchannels(config, servers, STAY_IN_CONNECTING); + + // reset inOrder to include all the childLBs now that they have been created + clearInvocations(helper); + InOrder inOrder = Mockito.inOrder(helper, + subChannelList.get(0), subChannelList.get(1), subChannelList.get(2), subChannelList.get(3)); + + // one in TRANSIENT_FAILURE, three in CONNECTING + deliverNotFound(subChannelList, 0); inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verifyConnection(1); - // two in TRANSIENT_FAILURE, two in IDLE - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("also not found"))); + // two in TRANSIENT_FAILURE, two in CONNECTING + deliverNotFound(subChannelList, 1); inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(1); - // two in TRANSIENT_FAILURE, one in CONNECTING, one in IDLE - // The overall state is dominated by the two in TRANSIENT_FAILURE. - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forNonError(CONNECTING)); + // All 4 in TF switch to TF + deliverNotFound(subChannelList, 2); + inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(0); - - // three in TRANSIENT_FAILURE, one in CONNECTING - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(3))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("connection lost"))); + deliverNotFound(subChannelList, 3); inOrder.verify(helper).refreshNameResolution(); inOrder.verify(helper) .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(0); + + // reset subchannel to CONNECTING - shouldn't change anything since PF hides the state change + deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(helper, never()) + .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); + inOrder.verify(subChannelList.get(2), never()).requestConnection(); // three in TRANSIENT_FAILURE, one in READY - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(subChannelList.get(2), ConnectivityStateInfo.forNonError(READY)); inOrder.verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); - verifyConnection(0); + inOrder.verify(subChannelList.get(2), never()).requestConnection(); verifyNoMoreInteractions(helper); } @Test - public void subchannelStayInTransientFailureUntilBecomeReady() { - RingHashConfig config = new RingHashConfig(10, 100); - List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); - reset(helper); - - // Simulate picks have taken place and subchannels have requested connection. - for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forTransientFailure( - Status.UNAUTHENTICATED.withDescription("Permission denied"))); - } - verify(helper, times(3)).refreshNameResolution(); - - // Stays in IDLE when until there are two or more subchannels in TRANSIENT_FAILURE. - verify(helper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verify(helper, times(2)) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(3); - - verifyNoMoreInteractions(helper); - reset(helper); - // Simulate underlying subchannel auto reconnect after backoff. - for (Subchannel subchannel : subchannels.values()) { - deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(CONNECTING)); - } - verify(helper, times(3)) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(3); - verifyNoMoreInteractions(helper); - - // Simulate one subchannel enters READY. - deliverSubchannelState( - subchannels.values().iterator().next(), ConnectivityStateInfo.forNonError(READY)); - verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); - } - - @Test - public void updateConnectionIterator() { - RingHashConfig config = new RingHashConfig(10, 100); - List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - InOrder inOrder = Mockito.inOrder(helper); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); - - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("connection lost"))); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper) - .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verifyConnection(1); - - servers = createWeightedServerAddrs(1,1); - addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper) - .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); - verifyConnection(1); - - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(1))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("connection lost"))); - inOrder.verify(helper).refreshNameResolution(); - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(1); - - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), - ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(helper) - .updateBalancingState(eq(TRANSIENT_FAILURE), any(SubchannelPicker.class)); - verifyConnection(1); - } - - @Test public void ignoreShutdownSubchannelStateChange() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + initializeLbSubchannels(config, servers); loadBalancer.shutdown(); for (Subchannel sc : subchannels.values()) { @@ -451,13 +313,8 @@ public class RingHashLoadBalancerTest { public void deterministicPickWithHostsPartiallyRemoved() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + initializeLbSubchannels(config, servers); InOrder inOrder = Mockito.inOrder(helper); - inOrder.verify(helper, times(5)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // Bring all subchannels to READY so that next pick always succeeds. for (Subchannel subchannel : subchannels.values()) { @@ -466,10 +323,8 @@ public class RingHashLoadBalancerTest { } // Simulate rpc hash hits one ring entry exactly for server1. - long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0"); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0"); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash); pickerCaptor.getValue().pickSubchannel(args); PickResult result = pickerCaptor.getValue().pickSubchannel(args); Subchannel subchannel = result.getSubchannel(); @@ -480,14 +335,14 @@ public class RingHashLoadBalancerTest { Attributes attr = addr.getAttributes().toBuilder().set(CUSTOM_KEY, "custom value").build(); updatedServers.add(new EquivalentAddressGroup(addr.getAddresses(), attr)); } - addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( + Subchannel subchannel0_old = subchannels.get(Collections.singletonList(servers.get(0))); + Subchannel subchannel1_old = subchannels.get(Collections.singletonList(servers.get(1))); + Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(updatedServers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .updateAddresses(Collections.singletonList(updatedServers.get(0))); - verify(subchannels.get(Collections.singletonList(servers.get(1)))) - .updateAddresses(Collections.singletonList(updatedServers.get(1))); + verify(subchannel0_old).updateAddresses(Collections.singletonList(updatedServers.get(0))); + verify(subchannel1_old).updateAddresses(Collections.singletonList(updatedServers.get(1))); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel()) .isSameInstanceAs(subchannel); @@ -498,13 +353,9 @@ public class RingHashLoadBalancerTest { public void deterministicPickWithNewHostsAdded() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1); // server0 and server1 - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); + InOrder inOrder = Mockito.inOrder(helper); - inOrder.verify(helper, times(2)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); // Bring all subchannels to READY so that next pick always succeeds. for (Subchannel subchannel : subchannels.values()) { @@ -513,66 +364,68 @@ public class RingHashLoadBalancerTest { } // Simulate rpc hash hits one ring entry exactly for server1. - long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server1]_0"); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0"); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash); pickerCaptor.getValue().pickSubchannel(args); PickResult result = pickerCaptor.getValue().pickSubchannel(args); Subchannel subchannel = result.getSubchannel(); assertThat(subchannel.getAddresses()).isEqualTo(servers.get(1)); servers = createWeightedServerAddrs(1, 1, 1, 1, 1); // server2, server3, server4 added - addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( + Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); + assertThat(loadBalancer.getChildLbStates().size()).isEqualTo(5); inOrder.verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().pickSubchannel(args).getSubchannel()) .isSameInstanceAs(subchannel); - verifyNoMoreInteractions(helper); + inOrder.verifyNoMoreInteractions(); + } + + private Subchannel getSubChannel(EquivalentAddressGroup eag) { + return subchannels.get(Collections.singletonList(eag)); } @Test - public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { + public void skipFailingHosts_pickNextNonFailingHost() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + Status addressesAcceptanceStatus = + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE + + // Create subchannel for the first address + ((RingHashChildLbState)loadBalancer.getChildLbStateEag(servers.get(0))).activate(); + verifyConnection(1); + reset(helper); // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server2_0" - long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0"); + long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server0_0"); PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash); // Bring down server0 to force trying server2. deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), + getSubChannel(servers.get(0)), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(1); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); // buffer request - verify(subchannels.get(Collections.singletonList(servers.get(2)))) - .requestConnection(); // kick off connection to server2 - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); // no excessive connection + verify(getSubChannel(servers.get(1))).requestConnection(); // kicked off connection to server2 + assertThat(subchannels.size()).isEqualTo(2); // no excessive connection reset(helper); - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(2))), + deliverSubchannelState(getSubChannel(servers.get(1)), ConnectivityStateInfo.forNonError(CONNECTING)); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); @@ -580,14 +433,12 @@ public class RingHashLoadBalancerTest { assertThat(result.getStatus().isOk()).isTrue(); assertThat(result.getSubchannel()).isNull(); // buffer request - deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forNonError(READY)); + deliverSubchannelState(getSubChannel(servers.get(1)), ConnectivityStateInfo.forNonError(READY)); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(2)); + assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); } private PickSubchannelArgsImpl getDefaultPickSubchannelArgs(long rpcHash) { @@ -596,55 +447,59 @@ public class RingHashLoadBalancerTest { CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); } + private PickSubchannelArgsImpl getDefaultPickSubchannelArgsForServer(int serverid) { + long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server" + serverid + "_0"); + return getDefaultPickSubchannelArgs(rpcHash); + } + @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE - reset(helper); + + initializeLbSubchannels(config, servers); + // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server2_0" - long rpcHash = hashFunc.hashAsciiString("[FakeSocketAddress-server0]_0"); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); + long rpcHash = hashFunc.hashAsciiString("FakeSocketAddress-server1_0"); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(rpcHash); // Bring down server0 and server2 to force trying server1. deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(0))), + subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(2))), ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied"))); - verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(2); // LB attempts to recover by itself + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(0); + PickResult result = pickerCaptor.getValue().pickSubchannel(args); // activate last subchannel + assertThat(result.getStatus().isOk()).isTrue(); + verifyConnection(1); - PickResult result = pickerCaptor.getValue().pickSubchannel(args); + deliverSubchannelState( + subchannels.get(Collections.singletonList(servers.get(0))), + ConnectivityStateInfo.forTransientFailure( + Status.PERMISSION_DENIED.withDescription("permission denied again"))); + verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC assertThat(result.getStatus().getCode()) .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(1)))) - .requestConnection(); // kickoff connection to server3 (next first non-failing) - verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection(); // Now connecting to server1. deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(CONNECTING)); - verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + + reset(helper); result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC @@ -658,9 +513,29 @@ public class RingHashLoadBalancerTest { ConnectivityStateInfo.forNonError(READY)); verify(helper).updateBalancingState(eq(READY), pickerCaptor.capture()); - result = pickerCaptor.getValue().pickSubchannel(args); + SubchannelPicker picker = pickerCaptor.getValue(); + result = picker.pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); // succeed assertThat(result.getSubchannel().getAddresses()).isEqualTo(servers.get(1)); // with server1 + assertThat(picker.pickSubchannel(getDefaultPickSubchannelArgsForServer(0))).isEqualTo(result); + assertThat(picker.pickSubchannel(getDefaultPickSubchannelArgsForServer(2))).isEqualTo(result); + } + + @Test + public void removingAddressShutdownSubchannel() { + // Map each server address to exactly one ring entry. + RingHashConfig config = new RingHashConfig(3, 3); + List<EquivalentAddressGroup> svs1 = createWeightedServerAddrs(1, 1, 1); + List<Subchannel> subchannels1 = initializeLbSubchannels(config, svs1, STAY_IN_CONNECTING); + + List<EquivalentAddressGroup> svs2 = createWeightedServerAddrs(1, 1); + InOrder inOrder = Mockito.inOrder(helper, subchannels1.get(2)); + // send LB the missing address + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(svs2).setLoadBalancingPolicyConfig(config).build()); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any()); + inOrder.verify(subchannels1.get(2)).shutdown(); } @Test @@ -668,12 +543,7 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + initializeLbSubchannels(config, servers); // Bring all subchannels to TRANSIENT_FAILURE. for (Subchannel subchannel : subchannels.values()) { @@ -683,23 +553,16 @@ public class RingHashLoadBalancerTest { } verify(helper, atLeastOnce()) .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(3); + verifyConnection(0); // Picking subchannel triggers connection. RPC hash hits server0. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgsForServer(0); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(result.getStatus().getDescription()) .isEqualTo("[FakeSocketAddress-server0] unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))) - .requestConnection(); + verifyConnection(0); // TF has already started taking care of this, pick doesn't need to } @Test @@ -707,31 +570,20 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + initializeLbSubchannels(config, servers); + // Go to TF does nothing, though PF will try to reconnect after backoff deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(1); + verifyConnection(0); // Picking subchannel triggers connection. RPC hash hits server0. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) - .requestConnection(); + verifyConnection(1); } @Test @@ -739,12 +591,7 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + initializeLbSubchannels(config, servers); deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), ConnectivityStateInfo.forNonError(CONNECTING)); @@ -753,9 +600,7 @@ public class RingHashLoadBalancerTest { verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); verify(subchannels.get(Collections.singletonList(servers.get(0))), never()) @@ -771,35 +616,30 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + List<Subchannel> subchannelList = + initializeLbSubchannels(config, servers, RESET_SUBCHANNEL_MOCKS); + // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server2_0" - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(0))), + deliverSubchannelState(subchannelList.get(0), ConnectivityStateInfo.forTransientFailure( Status.UNAVAILABLE.withDescription("unreachable"))); verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(1); + verifyConnection(0); - // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); - PickResult result = pickerCaptor.getValue().pickSubchannel(args); + // Per GRFC A61 Picking subchannel should no longer request connections that were failing + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); + SubchannelPicker picker1 = pickerCaptor.getValue(); + PickResult result = picker1.pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); + assertThat(result.getSubchannel()).isNull(); + verify(subchannelList.get(0), never()).requestConnection(); // In TF + verify(subchannelList.get(1)).requestConnection(); + verify(subchannelList.get(2), never()).requestConnection(); // Not one of the first 2 } @Test @@ -807,38 +647,31 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); + // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server2_0" Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); - deliverSubchannelState(firstSubchannel, - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( - firstSubchannel.getAddresses().getAddresses() + "unreachable"))); + deliverSubchannelUnreachable(firstSubchannel); + verifyConnection(0); + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), ConnectivityStateInfo.forNonError(CONNECTING)); verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - verifyConnection(1); + verifyConnection(0); - // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + // Picking subchannel when idle triggers connection. + deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), + ConnectivityStateInfo.forNonError(IDLE)); + verifyConnection(0); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); + verifyConnection(1); } @Test @@ -846,42 +679,26 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); + // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server2_0" Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); - deliverSubchannelState(firstSubchannel, - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( - firstSubchannel.getAddresses().getAddresses() + " unreachable"))); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("unreachable"))); + deliverSubchannelUnreachable(firstSubchannel); + deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(2); + verifyConnection(0); // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); - assertThat(result.getStatus().isOk()).isFalse(); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(result.getStatus().getDescription()) - .isEqualTo("[FakeSocketAddress-server0] unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1)))) - .requestConnection(); + assertThat(result.getStatus().isOk()).isTrue(); + verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); + verifyConnection(1); } @Test @@ -889,44 +706,29 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); + // ring: - // "[FakeSocketAddress-server1]_0" - // "[FakeSocketAddress-server0]_0" - // "[FakeSocketAddress-server2]_0" + // "FakeSocketAddress-server1_0" + // "FakeSocketAddress-server0_0" + // "FakeSocketAddress-server2_0" Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); - deliverSubchannelState(firstSubchannel, - ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE.withDescription( - firstSubchannel.getAddresses().getAddresses() + " unreachable"))); - deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(2))), - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription("unreachable"))); + + deliverSubchannelUnreachable(firstSubchannel); + deliverSubchannelUnreachable(subchannels.get(Collections.singletonList(servers.get(2)))); deliverSubchannelState(subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(CONNECTING)); - verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - verifyConnection(2); + verify(helper, atLeastOnce()) + .updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verifyConnection(0); - // Picking subchannel triggers connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + // Picking subchannel should not trigger connection per gRFC A61. + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); - assertThat(result.getStatus().isOk()).isFalse(); - assertThat(result.getStatus().getCode()).isEqualTo(Code.UNAVAILABLE); - assertThat(result.getStatus().getDescription()) - .isEqualTo("[FakeSocketAddress-server0] unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(0)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))) - .requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); + assertThat(result.getStatus().isOk()).isTrue(); + verifyConnection(0); } @Test @@ -934,35 +736,29 @@ public class RingHashLoadBalancerTest { // Map each server address to exactly one ring entry. RingHashConfig config = new RingHashConfig(3, 3); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 1, 1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); // Bring one subchannel to TRANSIENT_FAILURE. Subchannel firstSubchannel = subchannels.get(Collections.singletonList(servers.get(0))); - deliverSubchannelState(firstSubchannel, - ConnectivityStateInfo.forTransientFailure( - Status.UNAVAILABLE.withDescription( - firstSubchannel.getAddresses().getAddresses() + " unreachable"))); + deliverSubchannelUnreachable(firstSubchannel); - verify(helper).updateBalancingState(eq(CONNECTING), any()); - verifyConnection(1); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + verifyConnection(0); + + reset(helper); deliverSubchannelState(firstSubchannel, ConnectivityStateInfo.forNonError(IDLE)); - verify(helper, times(2)).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + // Should not have called updateBalancingState on the helper again because PickFirst is + // shielding the higher level from the state change. + verify(helper, never()).updateBalancingState(any(), any()); verifyConnection(1); - // Picking subchannel triggers connection. RPC hash hits server0. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + // Picking subchannel triggers connection on second address. RPC hash hits server0. + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isTrue(); - verify(subchannels.get(Collections.singletonList(servers.get(0)))).requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(2)))).requestConnection(); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + verify(subchannels.get(Collections.singletonList(servers.get(1)))).requestConnection(); + verify(subchannels.get(Collections.singletonList(servers.get(2))), never()) .requestConnection(); } @@ -971,14 +767,12 @@ public class RingHashLoadBalancerTest { RingHashConfig config = new RingHashConfig(10000, 100000); // large ring List<EquivalentAddressGroup> servers = createWeightedServerAddrs(Integer.MAX_VALUE, 10, 100); // MAX:10:100 - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + + initializeLbSubchannels(config, servers); // Try value between max signed and max unsigned int servers = createWeightedServerAddrs(Integer.MAX_VALUE + 100L, 100); // (MAX+100):100 - addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( + Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); @@ -1010,12 +804,8 @@ public class RingHashLoadBalancerTest { public void hostSelectionProportionalToWeights() { RingHashConfig config = new RingHashConfig(10000, 100000); // large ring List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1, 10, 100); // 1:10:100 - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + + initializeLbSubchannels(config, servers); // Bring all subchannels to READY. Map<EquivalentAddressGroup, Integer> pickCounts = new HashMap<>(); @@ -1028,9 +818,7 @@ public class RingHashLoadBalancerTest { for (int i = 0; i < 10000; i++) { long hash = hashFunc.hashInt(i); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hash)); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hash); Subchannel pickedSubchannel = picker.pickSubchannel(args).getSubchannel(); EquivalentAddressGroup addr = pickedSubchannel.getAddresses(); pickCounts.put(addr, pickCounts.get(addr) + 1); @@ -1059,21 +847,19 @@ public class RingHashLoadBalancerTest { public void nameResolutionErrorWithActiveSubchannels() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createWeightedServerAddrs(1); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + + initializeLbSubchannels(config, servers, DO_NOT_VERIFY, DO_NOT_RESET_HELPER); verify(helper).createSubchannel(any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + verify(helper, times(2)).updateBalancingState(eq(IDLE), pickerCaptor.capture()); // Picking subchannel triggers subchannel creation and connection. - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); pickerCaptor.getValue().pickSubchannel(args); + verify(helper, never()).updateBalancingState(eq(READY), any(SubchannelPicker.class)); deliverSubchannelState( Iterables.getOnlyElement(subchannels.values()), ConnectivityStateInfo.forNonError(READY)); verify(helper).updateBalancingState(eq(READY), any(SubchannelPicker.class)); + reset(helper); loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("target not found")); verifyNoMoreInteractions(helper); @@ -1083,15 +869,12 @@ public class RingHashLoadBalancerTest { public void duplicateAddresses() { RingHashConfig config = new RingHashConfig(10, 100); List<EquivalentAddressGroup> servers = createRepeatedServerAddrs(1, 2, 3); - Status addressesAcceptanceStatus = loadBalancer.acceptResolvedAddresses( - ResolvedAddresses.newBuilder() - .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - assertThat(addressesAcceptanceStatus.isOk()).isFalse(); + + initializeLbSubchannels(config, servers, DO_NOT_VERIFY); + verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); - PickSubchannelArgs args = new PickSubchannelArgsImpl( - TestMethodDescriptors.voidMethod(), new Metadata(), - CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, hashFunc.hashVoid())); + PickSubchannelArgs args = getDefaultPickSubchannelArgs(hashFunc.hashVoid()); PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC assertThat(result.getStatus().getCode()) @@ -1103,8 +886,105 @@ public class RingHashLoadBalancerTest { assertThat(description).contains("Address: FakeSocketAddress-server2, count: 3"); } + private List<Subchannel> initializeLbSubchannels(RingHashConfig config, + List<EquivalentAddressGroup> servers, InitializationFlags... initFlags) { + + boolean doVerifies = true; + boolean resetSubchannels = false; + boolean returnToIdle = true; + boolean resetHelper = true; + for (InitializationFlags flag : initFlags) { + switch (flag) { + case DO_NOT_VERIFY: + doVerifies = false; + break; + case RESET_SUBCHANNEL_MOCKS: + resetSubchannels = true; + break; + case STAY_IN_CONNECTING: + returnToIdle = false; + break; + case DO_NOT_RESET_HELPER: + resetHelper = false; + break; + default: + throw new IllegalArgumentException("Unrecognized flag: " + flag); + } + } + + Status addressesAcceptanceStatus = + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); + + if (doVerifies) { + assertThat(addressesAcceptanceStatus.isOk()).isTrue(); + verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + } + + if (!addressesAcceptanceStatus.isOk()) { + return new ArrayList<>(); + } + + // Activate them all to create the child LB and subchannel + for (ChildLbState childLbState : loadBalancer.getChildLbStates()) { + ((RingHashChildLbState)childLbState).activate(); + } + + if (doVerifies) { + verify(helper, times(servers.size())).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(servers.size())) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verifyConnection(servers.size()); + } + + if (returnToIdle) { + for (Subchannel subchannel : subchannels.values()) { + deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); + } + if (doVerifies) { + verify(helper, times(2 * servers.size() - 1)) + .updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + verify(helper, times(2)).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); + } + } + + + // Get a list of subchannels in the same order as servers + List<Subchannel> subchannelList = new ArrayList<>(); + for (EquivalentAddressGroup server : servers) { + List<EquivalentAddressGroup> singletonList = Collections.singletonList(server); + Subchannel subchannel = subchannels.get(singletonList); + subchannelList.add(subchannel); + if (resetSubchannels) { + reset(subchannel); + } + } + + if (resetHelper) { + reset(helper); + } + + return subchannelList; + } + private void deliverSubchannelState(Subchannel subchannel, ConnectivityStateInfo state) { - subchannelStateListeners.get(subchannel).onSubchannelState(state); + testHelperInst.deliverSubchannelState(subchannel, state); + } + + private void deliverNotFound(List<Subchannel> subChannelList, int index) { + deliverSubchannelState( + subChannelList.get(index), + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription("also not found"))); + } + + protected void deliverSubchannelUnreachable(Subchannel subchannel) { + deliverSubchannelState(subchannel, + ConnectivityStateInfo.forTransientFailure( + Status.UNAVAILABLE.withDescription( + subchannel.getAddresses().getAddresses() + "unreachable"))); + } private static List<EquivalentAddressGroup> createWeightedServerAddrs(long... weights) { @@ -1156,4 +1036,51 @@ public class RingHashLoadBalancerTest { return "FakeSocketAddress-" + name; } } + + private class TestHelper extends AbstractTestHelper { + + @Override + public Map<List<EquivalentAddressGroup>, Subchannel> getSubchannelMap() { + return subchannels; + } + + @Override + public String getAuthority() { + return AUTHORITY; + } + + @Override + public SynchronizationContext getSynchronizationContext() { + return syncContext; + } + + private Subchannel getMockSubchannel(Subchannel realSubchannel) { + return realToMockSubChannelMap.get(realSubchannel); + } + + @Override + protected AbstractTestHelper.TestSubchannel createRealSubchannel(CreateSubchannelArgs args) { + return new RingHashTestSubchannel(args); + } + + private class RingHashTestSubchannel extends AbstractTestHelper.TestSubchannel { + + RingHashTestSubchannel(CreateSubchannelArgs args) { + super(args); + } + + @Override + public void requestConnection() { + connectionRequestedQueue.offer(getMockSubchannel(this)); + } + + } + } + + enum InitializationFlags { + DO_NOT_VERIFY, + RESET_SUBCHANNEL_MOCKS, + STAY_IN_CONNECTING, + DO_NOT_RESET_HELPER + } } diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index b6ea8bc04..dfbbf9b08 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.ConnectivityState.CONNECTING; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.Mockito.any; import static org.mockito.Mockito.eq; @@ -59,6 +60,7 @@ import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinLoadBalancer import io.grpc.xds.WeightedRoundRobinLoadBalancer.WeightedRoundRobinPicker; import java.net.SocketAddress; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -78,7 +80,9 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; import org.mockito.Captor; +import org.mockito.InOrder; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -176,7 +180,7 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.CONNECTING)); + .forNonError(CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); @@ -477,7 +481,7 @@ public class WeightedRoundRobinLoadBalancerTest { .setAttributes(affinity).build())); verify(helper, times(6)).createSubchannel( any(CreateSubchannelArgs.class)); - verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); + verify(helper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertThat(pickerCaptor.getValue().getClass().getName()) .isEqualTo("io.grpc.util.RoundRobinLoadBalancer$EmptyPicker"); assertThat(fakeClock.forwardTime(11, TimeUnit.SECONDS)).isEqualTo(1); @@ -554,7 +558,7 @@ public class WeightedRoundRobinLoadBalancerTest { .forNonError(ConnectivityState.READY)); Subchannel connectingSubchannel = it.next(); getSubchannelStateListener(connectingSubchannel).onSubchannelState(ConnectivityStateInfo - .forNonError(ConnectivityState.CONNECTING)); + .forNonError(CONNECTING)); verify(helper, times(2)).updateBalancingState( eq(ConnectivityState.READY), pickerCaptor.capture()); assertThat(pickerCaptor.getAllValues().size()).isEqualTo(2); @@ -1063,6 +1067,24 @@ public class WeightedRoundRobinLoadBalancerTest { assertThat(sequence.get()).isEqualTo(9); } + @Test + public void removingAddressShutsdownSubchannel() { + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(servers).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + final Subchannel subchannel2 = subchannels.get(Collections.singletonList(servers.get(2))); + + InOrder inOrder = Mockito.inOrder(helper, subchannel2); + // send LB only the first 2 addresses + List<EquivalentAddressGroup> svs2 = Arrays.asList(servers.get(0), servers.get(1)); + syncContext.execute(() -> wrr.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(svs2).setLoadBalancingPolicyConfig(weightedConfig) + .setAttributes(affinity).build())); + inOrder.verify(helper).updateBalancingState(eq(CONNECTING), any()); + inOrder.verify(subchannel2).shutdown(); + } + + private static final class VerifyingScheduler { private final StaticStrideScheduler delegate; private final int max; |