diff options
Diffstat (limited to 'python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java')
-rw-r--r-- | python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java | 77 |
1 files changed, 57 insertions, 20 deletions
diff --git a/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java b/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java index bc53ffdeceaf..f3c688bc88ae 100644 --- a/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java +++ b/python/psi-api/src/com/jetbrains/python/psi/types/TypeEvalContext.java @@ -18,6 +18,7 @@ package com.jetbrains.python.psi.types; import com.intellij.openapi.util.text.StringUtil; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiFile; +import com.jetbrains.python.psi.Callable; import com.jetbrains.python.psi.PyTypedElement; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -41,12 +42,19 @@ public class TypeEvalContext { @Nullable private final PsiFile myOrigin; private final Map<PyTypedElement, PyType> myEvaluated = new HashMap<PyTypedElement, PyType>(); + private final Map<Callable, PyType> myEvaluatedReturn = new HashMap<Callable, PyType>(); private final ThreadLocal<Set<PyTypedElement>> myEvaluating = new ThreadLocal<Set<PyTypedElement>>() { @Override protected Set<PyTypedElement> initialValue() { return new HashSet<PyTypedElement>(); } }; + private final ThreadLocal<Set<Callable>> myEvaluatingReturn = new ThreadLocal<Set<Callable>>() { + @Override + protected Set<Callable> initialValue() { + return new HashSet<Callable>(); + } + }; private TypeEvalContext(boolean allowDataFlow, boolean allowStubToAST, @Nullable PsiFile origin) { myAllowDataFlow = allowDataFlow; @@ -114,64 +122,93 @@ public class TypeEvalContext { } return this; } - + public void trace(String message, Object... args) { if (myTrace != null) { myTrace.add(myTraceIndent + String.format(message, args)); } } - + public void traceIndent() { if (myTrace != null) { myTraceIndent += " "; } } - + public void traceUnindent() { if (myTrace != null && myTraceIndent.length() >= 2) { myTraceIndent = myTraceIndent.substring(0, myTraceIndent.length()-2); } } - + public String printTrace() { return StringUtil.join(myTrace, "\n"); } - + public boolean tracing() { return myTrace != null; } @Nullable - public PyType getType(@NotNull PyTypedElement element) { - synchronized (myEvaluated) { - if (myEvaluated.containsKey(element)) { - final PyType pyType = myEvaluated.get(element); - if (pyType != null) { - pyType.assertValid(element.toString()); - } - return pyType; - } - } + public PyType getType(@NotNull final PyTypedElement element) { final Set<PyTypedElement> evaluating = myEvaluating.get(); if (evaluating.contains(element)) { return null; } evaluating.add(element); try { - PyType result = element.getType(this, Key.INSTANCE); - if (result != null) { - result.assertValid(element.toString()); + synchronized (myEvaluated) { + if (myEvaluated.containsKey(element)) { + final PyType type = myEvaluated.get(element); + assertValid(type, element); + return type; + } } + final PyType type = element.getType(this, Key.INSTANCE); + assertValid(type, element); synchronized (myEvaluated) { - myEvaluated.put(element, result); + myEvaluated.put(element, type); } - return result; + return type; } finally { evaluating.remove(element); } } + @Nullable + public PyType getReturnType(@NotNull final Callable callable) { + final Set<Callable> evaluating = myEvaluatingReturn.get(); + if (evaluating.contains(callable)) { + return null; + } + evaluating.add(callable); + try { + synchronized (myEvaluatedReturn) { + if (myEvaluatedReturn.containsKey(callable)) { + final PyType type = myEvaluatedReturn.get(callable); + assertValid(type, callable); + return type; + } + } + final PyType type = callable.getReturnType(this, Key.INSTANCE); + assertValid(type, callable); + synchronized (myEvaluatedReturn) { + myEvaluatedReturn.put(callable, type); + } + return type; + } + finally { + evaluating.remove(callable); + } + } + + private static void assertValid(@Nullable PyType result, @NotNull PyTypedElement element) { + if (result != null) { + result.assertValid(element.toString()); + } + } + public boolean maySwitchToAST(@NotNull PsiElement element) { return myAllowStubToAST || myOrigin == element.getContainingFile(); } |