summaryrefslogtreecommitdiff
path: root/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java
diff options
context:
space:
mode:
Diffstat (limited to 'python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java')
-rw-r--r--python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java77
1 files changed, 57 insertions, 20 deletions
diff --git a/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java b/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java
index bc53ffdeceaf..f3c688bc88ae 100644
--- a/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java
+++ b/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java
@@ -18,6 +18,7 @@ package com.jetbrains.python.psi.types;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
+import com.jetbrains.python.psi.Callable;
import com.jetbrains.python.psi.PyTypedElement;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -41,12 +42,19 @@ public class TypeEvalContext {
@Nullable private final PsiFile myOrigin;
private final Map<PyTypedElement, PyType> myEvaluated = new HashMap<PyTypedElement, PyType>();
+ private final Map<Callable, PyType> myEvaluatedReturn = new HashMap<Callable, PyType>();
private final ThreadLocal<Set<PyTypedElement>> myEvaluating = new ThreadLocal<Set<PyTypedElement>>() {
@Override
protected Set<PyTypedElement> initialValue() {
return new HashSet<PyTypedElement>();
}
};
+ private final ThreadLocal<Set<Callable>> myEvaluatingReturn = new ThreadLocal<Set<Callable>>() {
+ @Override
+ protected Set<Callable> initialValue() {
+ return new HashSet<Callable>();
+ }
+ };
private TypeEvalContext(boolean allowDataFlow, boolean allowStubToAST, @Nullable PsiFile origin) {
myAllowDataFlow = allowDataFlow;
@@ -114,64 +122,93 @@ public class TypeEvalContext {
}
return this;
}
-
+
public void trace(String message, Object... args) {
if (myTrace != null) {
myTrace.add(myTraceIndent + String.format(message, args));
}
}
-
+
public void traceIndent() {
if (myTrace != null) {
myTraceIndent += " ";
}
}
-
+
public void traceUnindent() {
if (myTrace != null && myTraceIndent.length() >= 2) {
myTraceIndent = myTraceIndent.substring(0, myTraceIndent.length()-2);
}
}
-
+
public String printTrace() {
return StringUtil.join(myTrace, "\n");
}
-
+
public boolean tracing() {
return myTrace != null;
}
@Nullable
- public PyType getType(@NotNull PyTypedElement element) {
- synchronized (myEvaluated) {
- if (myEvaluated.containsKey(element)) {
- final PyType pyType = myEvaluated.get(element);
- if (pyType != null) {
- pyType.assertValid(element.toString());
- }
- return pyType;
- }
- }
+ public PyType getType(@NotNull final PyTypedElement element) {
final Set<PyTypedElement> evaluating = myEvaluating.get();
if (evaluating.contains(element)) {
return null;
}
evaluating.add(element);
try {
- PyType result = element.getType(this, Key.INSTANCE);
- if (result != null) {
- result.assertValid(element.toString());
+ synchronized (myEvaluated) {
+ if (myEvaluated.containsKey(element)) {
+ final PyType type = myEvaluated.get(element);
+ assertValid(type, element);
+ return type;
+ }
}
+ final PyType type = element.getType(this, Key.INSTANCE);
+ assertValid(type, element);
synchronized (myEvaluated) {
- myEvaluated.put(element, result);
+ myEvaluated.put(element, type);
}
- return result;
+ return type;
}
finally {
evaluating.remove(element);
}
}
+ @Nullable
+ public PyType getReturnType(@NotNull final Callable callable) {
+ final Set<Callable> evaluating = myEvaluatingReturn.get();
+ if (evaluating.contains(callable)) {
+ return null;
+ }
+ evaluating.add(callable);
+ try {
+ synchronized (myEvaluatedReturn) {
+ if (myEvaluatedReturn.containsKey(callable)) {
+ final PyType type = myEvaluatedReturn.get(callable);
+ assertValid(type, callable);
+ return type;
+ }
+ }
+ final PyType type = callable.getReturnType(this, Key.INSTANCE);
+ assertValid(type, callable);
+ synchronized (myEvaluatedReturn) {
+ myEvaluatedReturn.put(callable, type);
+ }
+ return type;
+ }
+ finally {
+ evaluating.remove(callable);
+ }
+ }
+
+ private static void assertValid(@Nullable PyType result, @NotNull PyTypedElement element) {
+ if (result != null) {
+ result.assertValid(element.toString());
+ }
+ }
+
public boolean maySwitchToAST(@NotNull PsiElement element) {
return myAllowStubToAST || myOrigin == element.getContainingFile();
}