diff options
author | Manu Sridharan <msridhar@gmail.com> | 2022-01-21 12:22:33 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-21 12:22:33 -0800 |
commit | 4ec519c4e324dfc1e468271ececc3c1b992b7a4f (patch) | |
tree | 69653eda398387a69eabef15b585dde9141e4952 | |
parent | 26cfc02fdfdf37a67f6486fbc9267c3e43fdba0b (diff) | |
download | nullaway-4ec519c4e324dfc1e468271ececc3c1b992b7a4f.tar.gz |
Reason about iterating over a map's key set using an enhanced for loop (#554)
Fixes #535
Uses the strategy outlined in this comment: https://github.com/uber/NullAway/issues/535#issuecomment-1006811236
4 files changed, 339 insertions, 4 deletions
diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java index 7fc723f..b427b05 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java @@ -468,6 +468,37 @@ public final class AccessPath implements MapKey { return result; } + /** + * Creates an access path representing a Map get call, where the key is obtained by calling {@code + * next()} on some {@code Iterator}. Used to support reasoning about iteration over a map's key + * set using an enhanced-for loop. + * + * @param mapNode Node representing the map + * @param iterVar local variable holding the iterator + * @param apContext access path context + * @return access path representing the get call, or {@code null} if the map node cannot be + * represented with an access path + */ + @Nullable + public static AccessPath mapWithIteratorContentsKey( + Node mapNode, LocalVariableNode iterVar, AccessPathContext apContext) { + List<AccessPathElement> elems = new ArrayList<>(); + Root root = populateElementsRec(mapNode, elems, apContext); + if (root != null) { + return new AccessPath( + root, elems, new IteratorContentsKey((VariableElement) iterVar.getElement())); + } + return null; + } + + /** + * Creates an access path identical to {@code accessPath} (which must represent a map get), but + * replacing its map {@code get()} argument with {@code mapKey} + */ + public static AccessPath replaceMapKey(AccessPath accessPath, MapKey mapKey) { + return new AccessPath(accessPath.getRoot(), accessPath.getElements(), mapKey); + } + @Override public boolean equals(Object o) { if (this == o) { @@ -506,6 +537,11 @@ public final class AccessPath implements MapKey { return elements; } + @Nullable + public MapKey getMapGetArg() { + return mapGetArg; + } + @Override public String toString() { return "AccessPath{" + "root=" + root + ", elements=" + elements + '}'; @@ -616,7 +652,7 @@ public final class AccessPath implements MapKey { private static final class StringMapKey implements MapKey { - private String key; + private final String key; public StringMapKey(String key) { this.key = key; @@ -638,7 +674,7 @@ public final class AccessPath implements MapKey { private static final class NumericMapKey implements MapKey { - private long key; + private final long key; public NumericMapKey(long key) { this.key = key; @@ -659,6 +695,46 @@ public final class AccessPath implements MapKey { } /** + * Represents all possible values that could be returned by calling {@code next()} on an {@code + * Iterator} variable + */ + public static final class IteratorContentsKey implements MapKey { + + /** + * Element for the local variable holding the {@code Iterator}. We only support locals for now, + * as this class is designed specifically for reasoning about iterating over map keys using an + * enhanced-for loop over a {@code keySet()}, and for such cases the iterator is always stored + * locally + */ + private final VariableElement iteratorVarElement; + + IteratorContentsKey(VariableElement iteratorVarElement) { + this.iteratorVarElement = iteratorVarElement; + } + + public VariableElement getIteratorVarElement() { + return iteratorVarElement; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + IteratorContentsKey that = (IteratorContentsKey) o; + return iteratorVarElement.equals(that.iteratorVarElement); + } + + @Override + public int hashCode() { + return iteratorVarElement.hashCode(); + } + } + + /** * Represents a per-javac instance of an AccessPath context options. * * <p>This includes, for example, data on known structurally immutable types. diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java index b4184ac..9e111fb 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Map; import java.util.function.Predicate; import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; import javax.lang.model.element.ElementKind; import javax.lang.model.element.VariableElement; import org.checkerframework.nullaway.dataflow.analysis.ConditionalTransferResult; @@ -502,12 +503,15 @@ public class AccessPathNullnessPropagation public TransferResult<Nullness, NullnessStore> visitAssignment( AssignmentNode node, TransferInput<Nullness, NullnessStore> input) { ReadableUpdates updates = new ReadableUpdates(); - Nullness value = values(input).valueOfSubNode(node.getExpression()); + Node rhs = node.getExpression(); + Nullness value = values(input).valueOfSubNode(rhs); Node target = node.getTarget(); if (target instanceof LocalVariableNode && !ASTHelpers.getType(target.getTree()).isPrimitive()) { - updates.set((LocalVariableNode) target, value); + LocalVariableNode localVariableNode = (LocalVariableNode) target; + updates.set(localVariableNode, value); + handleEnhancedForOverKeySet(localVariableNode, rhs, input, updates); } if (target instanceof ArrayAccessNode) { @@ -530,6 +534,107 @@ public class AccessPathNullnessPropagation return updateRegularStore(value, input, updates); } + /** + * Propagates access paths to track iteration over a map's key set using an enhanced-for loop, + * i.e., code of the form {@code for (Object k: m.keySet()) ...}. For such code, we track access + * paths to enable reasoning that within the body of the loop, {@code m.get(k)} is non-null. + * + * <p>There are two relevant types of assignments in the Checker Framework CFG for such tracking: + * + * <ol> + * <li>{@code iter#numX = m.keySet().iterator()}, for getting the iterator over a key set for an + * enhanced-for loop. After such assignments, we track an access path indicating that {@code + * m.get(contentsOf(iter#numX)} is non-null. + * <li>{@code k = iter#numX.next()}, which gets the next key in the key set when {@code + * iter#numX} was assigned as in case 1. After such assignments, we track the desired {@code + * m.get(k)} access path. + * </ol> + */ + private void handleEnhancedForOverKeySet( + LocalVariableNode lhs, + Node rhs, + TransferInput<Nullness, NullnessStore> input, + ReadableUpdates updates) { + if (isEnhancedForIteratorVariable(lhs)) { + // Based on the structure of Checker Framework CFGs, rhs must be a call of the form + // e.iterator(). We check if e is a call to keySet() on a Map, and if so, propagate + // NONNULL for an access path for e.get(iteratorContents(lhs)) + MethodInvocationNode rhsInv = (MethodInvocationNode) rhs; + Node mapNode = getMapNodeForKeySetIteratorCall(rhsInv); + if (mapNode != null) { + AccessPath mapWithIteratorContentsKey = + AccessPath.mapWithIteratorContentsKey(mapNode, lhs, apContext); + if (mapWithIteratorContentsKey != null) { + // put sanity check here to minimize perf impact + if (!isCallToMethod(rhsInv, "java.util.Set", "iterator")) { + throw new RuntimeException( + "expected call to iterator(), instead saw " + rhsInv.getTarget().getMethod()); + } + updates.set(mapWithIteratorContentsKey, NONNULL); + } + } + } else if (rhs instanceof MethodInvocationNode) { + // Check for an assignment lhs = iter#numX.next(). From the structure of Checker Framework + // CFGs, we know that if iter#numX is the receiver of a call on the rhs of an assignment, it + // must be a call to next(). + MethodInvocationNode methodInv = (MethodInvocationNode) rhs; + Node receiver = methodInv.getTarget().getReceiver(); + if (receiver instanceof LocalVariableNode + && isEnhancedForIteratorVariable((LocalVariableNode) receiver)) { + // See if we are tracking an access path e.get(iteratorContents(receiver)). If so, since + // lhs is being assigned from the iterator contents, propagate NONNULL for an access path + // e.get(lhs) + AccessPath mapGetPath = + input + .getRegularStore() + .getMapGetIteratorContentsAccessPath((LocalVariableNode) receiver); + if (mapGetPath != null) { + // put sanity check here to minimize perf impact + if (!isCallToMethod(methodInv, "java.util.Iterator", "next")) { + throw new RuntimeException( + "expected call to iterator(), instead saw " + methodInv.getTarget().getMethod()); + } + updates.set(AccessPath.replaceMapKey(mapGetPath, AccessPath.fromLocal(lhs)), NONNULL); + } + } + } + } + + /** + * {@code invocationNode} must represent a call of the form {@code e.iterator()}. If {@code e} is + * of the form {@code e'.keySet()}, returns the {@code Node} for {@code e'}. Otherwise, returns + * {@code null}. + */ + @Nullable + private Node getMapNodeForKeySetIteratorCall(MethodInvocationNode invocationNode) { + Node receiver = invocationNode.getTarget().getReceiver(); + if (receiver instanceof MethodInvocationNode) { + MethodInvocationNode baseInvocation = (MethodInvocationNode) receiver; + // Check for a call to java.util.Map.keySet() + if (isCallToMethod(baseInvocation, "java.util.Map", "keySet")) { + // receiver represents the map + return baseInvocation.getTarget().getReceiver(); + } + } + return null; + } + + private boolean isCallToMethod( + MethodInvocationNode invocationNode, String containingClassName, String methodName) { + Symbol.MethodSymbol symbol = ASTHelpers.getSymbol(invocationNode.getTree()); + return symbol != null + && symbol.getSimpleName().contentEquals(methodName) + && symbol.owner.getQualifiedName().contentEquals(containingClassName); + } + + /** + * Is {@code varNode} a temporary variable representing the {@code Iterator} for an enhanced for + * loop? Matched based on the naming scheme used by Checker dataflow. + */ + private boolean isEnhancedForIteratorVariable(LocalVariableNode varNode) { + return varNode.getName().startsWith("iter#num"); + } + private TransferResult<Nullness, NullnessStore> updateRegularStore( Nullness value, TransferInput<Nullness, NullnessStore> input, ReadableUpdates updates) { ResultingStore newStore = updateStore(input.getRegularStore(), updates); diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java index 3c254a1..9068f89 100644 --- a/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java +++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java @@ -22,12 +22,14 @@ import static com.google.common.collect.Sets.intersection; import com.google.common.collect.ImmutableMap; import com.sun.tools.javac.code.Types; import com.uber.nullaway.Nullness; +import com.uber.nullaway.dataflow.AccessPath.IteratorContentsKey; import java.util.HashMap; import java.util.LinkedHashSet; import java.util.Map; import java.util.Set; import java.util.function.Predicate; import java.util.stream.Collectors; +import javax.annotation.Nullable; import javax.lang.model.element.Element; import org.checkerframework.nullaway.dataflow.analysis.Store; import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode; @@ -123,6 +125,24 @@ public class NullnessStore implements Store<NullnessStore> { } /** + * If this store maps an access path {@code p} whose map-get argument is an {@link + * IteratorContentsKey} whose variable is {@code iteratorVar}, returns {@code p}. Otherwise, + * returns {@code null}. + */ + @Nullable + public AccessPath getMapGetIteratorContentsAccessPath(LocalVariableNode iteratorVar) { + for (AccessPath accessPath : contents.keySet()) { + MapKey mapGetArg = accessPath.getMapGetArg(); + if (mapGetArg instanceof IteratorContentsKey) { + IteratorContentsKey iteratorContentsKey = (IteratorContentsKey) mapGetArg; + if (iteratorContentsKey.getIteratorVarElement().equals(iteratorVar.getElement())) { + return accessPath; + } + } + } + return null; + } + /** * Gets the {@link Nullness} value of an access path. * * @param accessPath The access path. diff --git a/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java b/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java new file mode 100644 index 0000000..8ad60e9 --- /dev/null +++ b/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java @@ -0,0 +1,134 @@ +package com.uber.nullaway; + +import org.junit.Test; + +public class NullAwayKeySetIteratorTests extends NullAwayTestsBase { + + @Test + public void mapKeySetIteratorBasic() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Map;", + "public class Test {", + " public void keySetStuff(Map<Object, Object> m) {", + " for (Object k: m.keySet()) {", + " m.get(k).toString();", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void mapKeySetIteratorShadowing() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Map;", + "public class Test {", + " private final Object k = new Object();", + " public void keySetStuff(Map<Object, Object> m) {", + " for (Object k: m.keySet()) {", + " m.get(k).toString();", + " }", + " // BUG: Diagnostic contains: dereferenced expression m.get(k) is @Nullable", + " m.get(k).toString();", + " }", + "}") + .doTest(); + } + + @Test + public void mapKeySetIteratorDeeperAccessPath() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Map;", + "import java.util.HashMap;", + "public class Test {", + " static class MapWrapper {", + " Map<Object,Object> mf = new HashMap<>();", + " public Map<Object,Object> getMF() { return mf; }", + " public Map<Object,Object> getMF2(int i) { return mf; }", + " }", + " public void keySetStuff(MapWrapper mw, int j) {", + " for (Object k: mw.mf.keySet()) {", + " mw.mf.get(k).toString();", + " }", + " for (Object k: mw.getMF().keySet()) {", + " mw.getMF().get(k).toString();", + " }", + " for (Object k: mw.getMF2(10).keySet()) {", + " mw.getMF2(10).get(k).toString();", + " }", + " for (Object k: mw.getMF2(j).keySet()) {", + " // Report error since we cannot represent mw.getMF2(j).get(k) with an access path", + " // BUG: Diagnostic contains: dereferenced expression mw.getMF2(j).get(k) is @Nullable", + " mw.getMF2(j).get(k).toString();", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void unhandledCases() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Iterator;", + "import java.util.Map;", + "public class Test {", + " public void keySetStuff(Map<Object, Object> m) {", + " // BUG: Diagnostic contains: dereferenced expression", + " m.get(m.keySet().iterator().next()).toString();", + " Iterator<Object> iter = m.keySet().iterator();", + " while (iter.hasNext()) {", + " // BUG: Diagnostic contains: dereferenced expression", + " m.get(iter.next()).toString();", + " }", + " }", + "}") + .doTest(); + } + + @Test + public void nestedLoops() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.Map;", + "public class Test {", + " public void keySetStuff(Map<Object, Object> m, Map<Object, Object> m2) {", + " for (Object k: m.keySet()) {", + " for (Object k2: m2.keySet()) {", + " m.get(k).toString();", + " // BUG: Diagnostic contains: dereferenced expression", + " m.get(k2).toString();", + " // BUG: Diagnostic contains: dereferenced expression", + " m2.get(k).toString();", + " m2.get(k2).toString();", + " }", + " }", + " // nested loop over the same map", + " for (Object k: m.keySet()) {", + " for (Object k2: m.keySet()) {", + " m.get(k).toString();", + " m.get(k2).toString();", + " // BUG: Diagnostic contains: dereferenced expression", + " m2.get(k).toString();", + " // BUG: Diagnostic contains: dereferenced expression", + " m2.get(k2).toString();", + " }", + " }", + " }", + "}") + .doTest(); + } +} |