aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorManu Sridharan <msridhar@gmail.com>2022-01-21 12:22:33 -0800
committerGitHub <noreply@github.com>2022-01-21 12:22:33 -0800
commit4ec519c4e324dfc1e468271ececc3c1b992b7a4f (patch)
tree69653eda398387a69eabef15b585dde9141e4952
parent26cfc02fdfdf37a67f6486fbc9267c3e43fdba0b (diff)
downloadnullaway-4ec519c4e324dfc1e468271ececc3c1b992b7a4f.tar.gz
Reason about iterating over a map's key set using an enhanced for loop (#554)
Fixes #535 Uses the strategy outlined in this comment: https://github.com/uber/NullAway/issues/535#issuecomment-1006811236
-rw-r--r--nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java80
-rw-r--r--nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java109
-rw-r--r--nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java20
-rw-r--r--nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java134
4 files changed, 339 insertions, 4 deletions
diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java
index 7fc723f..b427b05 100644
--- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java
+++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPath.java
@@ -468,6 +468,37 @@ public final class AccessPath implements MapKey {
return result;
}
+ /**
+ * Creates an access path representing a Map get call, where the key is obtained by calling {@code
+ * next()} on some {@code Iterator}. Used to support reasoning about iteration over a map's key
+ * set using an enhanced-for loop.
+ *
+ * @param mapNode Node representing the map
+ * @param iterVar local variable holding the iterator
+ * @param apContext access path context
+ * @return access path representing the get call, or {@code null} if the map node cannot be
+ * represented with an access path
+ */
+ @Nullable
+ public static AccessPath mapWithIteratorContentsKey(
+ Node mapNode, LocalVariableNode iterVar, AccessPathContext apContext) {
+ List<AccessPathElement> elems = new ArrayList<>();
+ Root root = populateElementsRec(mapNode, elems, apContext);
+ if (root != null) {
+ return new AccessPath(
+ root, elems, new IteratorContentsKey((VariableElement) iterVar.getElement()));
+ }
+ return null;
+ }
+
+ /**
+ * Creates an access path identical to {@code accessPath} (which must represent a map get), but
+ * replacing its map {@code get()} argument with {@code mapKey}
+ */
+ public static AccessPath replaceMapKey(AccessPath accessPath, MapKey mapKey) {
+ return new AccessPath(accessPath.getRoot(), accessPath.getElements(), mapKey);
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -506,6 +537,11 @@ public final class AccessPath implements MapKey {
return elements;
}
+ @Nullable
+ public MapKey getMapGetArg() {
+ return mapGetArg;
+ }
+
@Override
public String toString() {
return "AccessPath{" + "root=" + root + ", elements=" + elements + '}';
@@ -616,7 +652,7 @@ public final class AccessPath implements MapKey {
private static final class StringMapKey implements MapKey {
- private String key;
+ private final String key;
public StringMapKey(String key) {
this.key = key;
@@ -638,7 +674,7 @@ public final class AccessPath implements MapKey {
private static final class NumericMapKey implements MapKey {
- private long key;
+ private final long key;
public NumericMapKey(long key) {
this.key = key;
@@ -659,6 +695,46 @@ public final class AccessPath implements MapKey {
}
/**
+ * Represents all possible values that could be returned by calling {@code next()} on an {@code
+ * Iterator} variable
+ */
+ public static final class IteratorContentsKey implements MapKey {
+
+ /**
+ * Element for the local variable holding the {@code Iterator}. We only support locals for now,
+ * as this class is designed specifically for reasoning about iterating over map keys using an
+ * enhanced-for loop over a {@code keySet()}, and for such cases the iterator is always stored
+ * locally
+ */
+ private final VariableElement iteratorVarElement;
+
+ IteratorContentsKey(VariableElement iteratorVarElement) {
+ this.iteratorVarElement = iteratorVarElement;
+ }
+
+ public VariableElement getIteratorVarElement() {
+ return iteratorVarElement;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ IteratorContentsKey that = (IteratorContentsKey) o;
+ return iteratorVarElement.equals(that.iteratorVarElement);
+ }
+
+ @Override
+ public int hashCode() {
+ return iteratorVarElement.hashCode();
+ }
+ }
+
+ /**
* Represents a per-javac instance of an AccessPath context options.
*
* <p>This includes, for example, data on known structurally immutable types.
diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java
index b4184ac..9e111fb 100644
--- a/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java
+++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/AccessPathNullnessPropagation.java
@@ -38,6 +38,7 @@ import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import javax.annotation.CheckReturnValue;
+import javax.annotation.Nullable;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.VariableElement;
import org.checkerframework.nullaway.dataflow.analysis.ConditionalTransferResult;
@@ -502,12 +503,15 @@ public class AccessPathNullnessPropagation
public TransferResult<Nullness, NullnessStore> visitAssignment(
AssignmentNode node, TransferInput<Nullness, NullnessStore> input) {
ReadableUpdates updates = new ReadableUpdates();
- Nullness value = values(input).valueOfSubNode(node.getExpression());
+ Node rhs = node.getExpression();
+ Nullness value = values(input).valueOfSubNode(rhs);
Node target = node.getTarget();
if (target instanceof LocalVariableNode
&& !ASTHelpers.getType(target.getTree()).isPrimitive()) {
- updates.set((LocalVariableNode) target, value);
+ LocalVariableNode localVariableNode = (LocalVariableNode) target;
+ updates.set(localVariableNode, value);
+ handleEnhancedForOverKeySet(localVariableNode, rhs, input, updates);
}
if (target instanceof ArrayAccessNode) {
@@ -530,6 +534,107 @@ public class AccessPathNullnessPropagation
return updateRegularStore(value, input, updates);
}
+ /**
+ * Propagates access paths to track iteration over a map's key set using an enhanced-for loop,
+ * i.e., code of the form {@code for (Object k: m.keySet()) ...}. For such code, we track access
+ * paths to enable reasoning that within the body of the loop, {@code m.get(k)} is non-null.
+ *
+ * <p>There are two relevant types of assignments in the Checker Framework CFG for such tracking:
+ *
+ * <ol>
+ * <li>{@code iter#numX = m.keySet().iterator()}, for getting the iterator over a key set for an
+ * enhanced-for loop. After such assignments, we track an access path indicating that {@code
+ * m.get(contentsOf(iter#numX)} is non-null.
+ * <li>{@code k = iter#numX.next()}, which gets the next key in the key set when {@code
+ * iter#numX} was assigned as in case 1. After such assignments, we track the desired {@code
+ * m.get(k)} access path.
+ * </ol>
+ */
+ private void handleEnhancedForOverKeySet(
+ LocalVariableNode lhs,
+ Node rhs,
+ TransferInput<Nullness, NullnessStore> input,
+ ReadableUpdates updates) {
+ if (isEnhancedForIteratorVariable(lhs)) {
+ // Based on the structure of Checker Framework CFGs, rhs must be a call of the form
+ // e.iterator(). We check if e is a call to keySet() on a Map, and if so, propagate
+ // NONNULL for an access path for e.get(iteratorContents(lhs))
+ MethodInvocationNode rhsInv = (MethodInvocationNode) rhs;
+ Node mapNode = getMapNodeForKeySetIteratorCall(rhsInv);
+ if (mapNode != null) {
+ AccessPath mapWithIteratorContentsKey =
+ AccessPath.mapWithIteratorContentsKey(mapNode, lhs, apContext);
+ if (mapWithIteratorContentsKey != null) {
+ // put sanity check here to minimize perf impact
+ if (!isCallToMethod(rhsInv, "java.util.Set", "iterator")) {
+ throw new RuntimeException(
+ "expected call to iterator(), instead saw " + rhsInv.getTarget().getMethod());
+ }
+ updates.set(mapWithIteratorContentsKey, NONNULL);
+ }
+ }
+ } else if (rhs instanceof MethodInvocationNode) {
+ // Check for an assignment lhs = iter#numX.next(). From the structure of Checker Framework
+ // CFGs, we know that if iter#numX is the receiver of a call on the rhs of an assignment, it
+ // must be a call to next().
+ MethodInvocationNode methodInv = (MethodInvocationNode) rhs;
+ Node receiver = methodInv.getTarget().getReceiver();
+ if (receiver instanceof LocalVariableNode
+ && isEnhancedForIteratorVariable((LocalVariableNode) receiver)) {
+ // See if we are tracking an access path e.get(iteratorContents(receiver)). If so, since
+ // lhs is being assigned from the iterator contents, propagate NONNULL for an access path
+ // e.get(lhs)
+ AccessPath mapGetPath =
+ input
+ .getRegularStore()
+ .getMapGetIteratorContentsAccessPath((LocalVariableNode) receiver);
+ if (mapGetPath != null) {
+ // put sanity check here to minimize perf impact
+ if (!isCallToMethod(methodInv, "java.util.Iterator", "next")) {
+ throw new RuntimeException(
+ "expected call to iterator(), instead saw " + methodInv.getTarget().getMethod());
+ }
+ updates.set(AccessPath.replaceMapKey(mapGetPath, AccessPath.fromLocal(lhs)), NONNULL);
+ }
+ }
+ }
+ }
+
+ /**
+ * {@code invocationNode} must represent a call of the form {@code e.iterator()}. If {@code e} is
+ * of the form {@code e'.keySet()}, returns the {@code Node} for {@code e'}. Otherwise, returns
+ * {@code null}.
+ */
+ @Nullable
+ private Node getMapNodeForKeySetIteratorCall(MethodInvocationNode invocationNode) {
+ Node receiver = invocationNode.getTarget().getReceiver();
+ if (receiver instanceof MethodInvocationNode) {
+ MethodInvocationNode baseInvocation = (MethodInvocationNode) receiver;
+ // Check for a call to java.util.Map.keySet()
+ if (isCallToMethod(baseInvocation, "java.util.Map", "keySet")) {
+ // receiver represents the map
+ return baseInvocation.getTarget().getReceiver();
+ }
+ }
+ return null;
+ }
+
+ private boolean isCallToMethod(
+ MethodInvocationNode invocationNode, String containingClassName, String methodName) {
+ Symbol.MethodSymbol symbol = ASTHelpers.getSymbol(invocationNode.getTree());
+ return symbol != null
+ && symbol.getSimpleName().contentEquals(methodName)
+ && symbol.owner.getQualifiedName().contentEquals(containingClassName);
+ }
+
+ /**
+ * Is {@code varNode} a temporary variable representing the {@code Iterator} for an enhanced for
+ * loop? Matched based on the naming scheme used by Checker dataflow.
+ */
+ private boolean isEnhancedForIteratorVariable(LocalVariableNode varNode) {
+ return varNode.getName().startsWith("iter#num");
+ }
+
private TransferResult<Nullness, NullnessStore> updateRegularStore(
Nullness value, TransferInput<Nullness, NullnessStore> input, ReadableUpdates updates) {
ResultingStore newStore = updateStore(input.getRegularStore(), updates);
diff --git a/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java b/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java
index 3c254a1..9068f89 100644
--- a/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java
+++ b/nullaway/src/main/java/com/uber/nullaway/dataflow/NullnessStore.java
@@ -22,12 +22,14 @@ import static com.google.common.collect.Sets.intersection;
import com.google.common.collect.ImmutableMap;
import com.sun.tools.javac.code.Types;
import com.uber.nullaway.Nullness;
+import com.uber.nullaway.dataflow.AccessPath.IteratorContentsKey;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
import javax.lang.model.element.Element;
import org.checkerframework.nullaway.dataflow.analysis.Store;
import org.checkerframework.nullaway.dataflow.cfg.node.FieldAccessNode;
@@ -123,6 +125,24 @@ public class NullnessStore implements Store<NullnessStore> {
}
/**
+ * If this store maps an access path {@code p} whose map-get argument is an {@link
+ * IteratorContentsKey} whose variable is {@code iteratorVar}, returns {@code p}. Otherwise,
+ * returns {@code null}.
+ */
+ @Nullable
+ public AccessPath getMapGetIteratorContentsAccessPath(LocalVariableNode iteratorVar) {
+ for (AccessPath accessPath : contents.keySet()) {
+ MapKey mapGetArg = accessPath.getMapGetArg();
+ if (mapGetArg instanceof IteratorContentsKey) {
+ IteratorContentsKey iteratorContentsKey = (IteratorContentsKey) mapGetArg;
+ if (iteratorContentsKey.getIteratorVarElement().equals(iteratorVar.getElement())) {
+ return accessPath;
+ }
+ }
+ }
+ return null;
+ }
+ /**
* Gets the {@link Nullness} value of an access path.
*
* @param accessPath The access path.
diff --git a/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java b/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java
new file mode 100644
index 0000000..8ad60e9
--- /dev/null
+++ b/nullaway/src/test/java/com/uber/nullaway/NullAwayKeySetIteratorTests.java
@@ -0,0 +1,134 @@
+package com.uber.nullaway;
+
+import org.junit.Test;
+
+public class NullAwayKeySetIteratorTests extends NullAwayTestsBase {
+
+ @Test
+ public void mapKeySetIteratorBasic() {
+ defaultCompilationHelper
+ .addSourceLines(
+ "Test.java",
+ "package com.uber;",
+ "import java.util.Map;",
+ "public class Test {",
+ " public void keySetStuff(Map<Object, Object> m) {",
+ " for (Object k: m.keySet()) {",
+ " m.get(k).toString();",
+ " }",
+ " }",
+ "}")
+ .doTest();
+ }
+
+ @Test
+ public void mapKeySetIteratorShadowing() {
+ defaultCompilationHelper
+ .addSourceLines(
+ "Test.java",
+ "package com.uber;",
+ "import java.util.Map;",
+ "public class Test {",
+ " private final Object k = new Object();",
+ " public void keySetStuff(Map<Object, Object> m) {",
+ " for (Object k: m.keySet()) {",
+ " m.get(k).toString();",
+ " }",
+ " // BUG: Diagnostic contains: dereferenced expression m.get(k) is @Nullable",
+ " m.get(k).toString();",
+ " }",
+ "}")
+ .doTest();
+ }
+
+ @Test
+ public void mapKeySetIteratorDeeperAccessPath() {
+ defaultCompilationHelper
+ .addSourceLines(
+ "Test.java",
+ "package com.uber;",
+ "import java.util.Map;",
+ "import java.util.HashMap;",
+ "public class Test {",
+ " static class MapWrapper {",
+ " Map<Object,Object> mf = new HashMap<>();",
+ " public Map<Object,Object> getMF() { return mf; }",
+ " public Map<Object,Object> getMF2(int i) { return mf; }",
+ " }",
+ " public void keySetStuff(MapWrapper mw, int j) {",
+ " for (Object k: mw.mf.keySet()) {",
+ " mw.mf.get(k).toString();",
+ " }",
+ " for (Object k: mw.getMF().keySet()) {",
+ " mw.getMF().get(k).toString();",
+ " }",
+ " for (Object k: mw.getMF2(10).keySet()) {",
+ " mw.getMF2(10).get(k).toString();",
+ " }",
+ " for (Object k: mw.getMF2(j).keySet()) {",
+ " // Report error since we cannot represent mw.getMF2(j).get(k) with an access path",
+ " // BUG: Diagnostic contains: dereferenced expression mw.getMF2(j).get(k) is @Nullable",
+ " mw.getMF2(j).get(k).toString();",
+ " }",
+ " }",
+ "}")
+ .doTest();
+ }
+
+ @Test
+ public void unhandledCases() {
+ defaultCompilationHelper
+ .addSourceLines(
+ "Test.java",
+ "package com.uber;",
+ "import java.util.Iterator;",
+ "import java.util.Map;",
+ "public class Test {",
+ " public void keySetStuff(Map<Object, Object> m) {",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m.get(m.keySet().iterator().next()).toString();",
+ " Iterator<Object> iter = m.keySet().iterator();",
+ " while (iter.hasNext()) {",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m.get(iter.next()).toString();",
+ " }",
+ " }",
+ "}")
+ .doTest();
+ }
+
+ @Test
+ public void nestedLoops() {
+ defaultCompilationHelper
+ .addSourceLines(
+ "Test.java",
+ "package com.uber;",
+ "import java.util.Map;",
+ "public class Test {",
+ " public void keySetStuff(Map<Object, Object> m, Map<Object, Object> m2) {",
+ " for (Object k: m.keySet()) {",
+ " for (Object k2: m2.keySet()) {",
+ " m.get(k).toString();",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m.get(k2).toString();",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m2.get(k).toString();",
+ " m2.get(k2).toString();",
+ " }",
+ " }",
+ " // nested loop over the same map",
+ " for (Object k: m.keySet()) {",
+ " for (Object k2: m.keySet()) {",
+ " m.get(k).toString();",
+ " m.get(k2).toString();",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m2.get(k).toString();",
+ " // BUG: Diagnostic contains: dereferenced expression",
+ " m2.get(k2).toString();",
+ " }",
+ " }",
+ " }",
+ "}")
+ .doTest();
+ }
+}