diff options
Diffstat (limited to 'python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java')
-rw-r--r-- | python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java | 127 |
1 files changed, 66 insertions, 61 deletions
diff --git a/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java b/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java index 2ffc8e6cba03..acc225b67f22 100644 --- a/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java +++ b/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java @@ -15,7 +15,6 @@ */ package com.jetbrains.python.psi.impl; -import com.google.common.collect.Maps; import com.intellij.lang.ASTNode; import com.intellij.openapi.extensions.Extensions; import com.intellij.openapi.util.Pair; @@ -176,42 +175,74 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp @Nullable @Override - public PyType getReturnType(@NotNull TypeEvalContext context, @Nullable PyQualifiedExpression callSite) { - PyType type = getGenericReturnType(context, callSite); - if (callSite == null) { - return type; + public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { + for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { + final PyType returnType = typeProvider.getReturnType(this, context); + if (returnType != null) { + returnType.assertValid(typeProvider.toString()); + return returnType; + } } - final PyTypeChecker.AnalyzeCallResults results = PyTypeChecker.analyzeCallSite(callSite, context); - if (PyTypeChecker.hasGenerics(type, context)) { - if (results != null) { - final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(this, results.getReceiver(), results.getArguments(), - context); - type = substitutions != null ? PyTypeChecker.substitute(type, substitutions, context) : null; + if (context.maySwitchToAST(this) && LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) { + PyAnnotation anno = getAnnotation(); + if (anno != null) { + PyClass pyClass = anno.resolveToClass(); + if (pyClass != null) { + return new PyClassTypeImpl(pyClass, false); + } } - else { - type = null; + } + final PyType docStringType = getReturnTypeFromDocString(); + if (docStringType != null) { + docStringType.assertValid("from docstring"); + return docStringType; + } + if (context.allowReturnTypes(this)) { + final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context); + if (yieldTypeRef != null) { + return yieldTypeRef.get(); } + return getReturnStatementType(context); } - if (results != null) { - type = replaceSelf(type, results.getReceiver(), context); + return null; + } + + @Nullable + @Override + public PyType getCallType(@NotNull TypeEvalContext context, @NotNull PyQualifiedExpression callSite) { + PyType type = null; + for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { + type = typeProvider.getCallType(this, callSite, context); + if (type != null) { + type.assertValid(typeProvider.toString()); + break; + } } - if (results != null && isDynamicallyEvaluated(results.getArguments().values(), context)) { - return PyUnionType.createWeakType(type); + if (type == null) { + type = context.getReturnType(this); + } + final PyTypeChecker.AnalyzeCallResults results = PyTypeChecker.analyzeCallSite(callSite, context); + if (results != null) { + return analyzeCallType(type, results.getReceiver(), results.getArguments(), context); } return type; } @Nullable - /** - * Suits when there is no call site(e.g. implicit __iter__ call in statement for) - */ - public PyType getReturnTypeWithoutCallSite(@NotNull TypeEvalContext context, - @Nullable PyExpression receiver) { - PyType type = getGenericReturnType(context, null); + @Override + public PyType getCallType(@Nullable PyExpression receiver, + @NotNull Map<PyExpression, PyNamedParameter> parameters, + @NotNull TypeEvalContext context) { + return analyzeCallType(context.getReturnType(this), receiver, parameters, context); + } + + @Nullable + private PyType analyzeCallType(@Nullable PyType type, + @Nullable PyExpression receiver, + @NotNull Map<PyExpression, PyNamedParameter> parameters, + @NotNull TypeEvalContext context) { if (PyTypeChecker.hasGenerics(type, context)) { - final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(this, receiver, - Maps.<PyExpression, PyNamedParameter>newHashMap(), - context); + final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, parameters, context); if (substitutions != null) { type = PyTypeChecker.substitute(type, substitutions, context); } @@ -219,7 +250,13 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp type = null; } } - return replaceSelf(type, receiver, context); + if (receiver != null) { + type = replaceSelf(type, receiver, context); + } + if (type != null && isDynamicallyEvaluated(parameters.values(), context)) { + type = PyUnionType.createWeakType(type); + } + return type; } @Nullable @@ -250,39 +287,6 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp } @Nullable - private PyType getGenericReturnType(@NotNull TypeEvalContext typeEvalContext, @Nullable PyQualifiedExpression callSite) { - if (typeEvalContext.maySwitchToAST(this) && LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) { - PyAnnotation anno = getAnnotation(); - if (anno != null) { - PyClass pyClass = anno.resolveToClass(); - if (pyClass != null) { - return new PyClassTypeImpl(pyClass, false); - } - } - } - for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) { - final PyType returnType = typeProvider.getReturnType(this, callSite, typeEvalContext); - if (returnType != null) { - returnType.assertValid(typeProvider.toString()); - return returnType; - } - } - final PyType docStringType = getReturnTypeFromDocString(); - if (docStringType != null) { - docStringType.assertValid("from docstring"); - return docStringType; - } - if (typeEvalContext.allowReturnTypes(this)) { - final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(typeEvalContext); - if (yieldTypeRef != null) { - return yieldTypeRef.get(); - } - return getReturnStatementType(typeEvalContext); - } - return null; - } - - @Nullable private Ref<? extends PyType> getYieldStatementType(@NotNull final TypeEvalContext context) { Ref<PyType> elementType = null; final PyBuiltinCache cache = PyBuiltinCache.getInstance(this); @@ -386,8 +390,9 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp return type; } } + final boolean hasCustomDecorators = PyUtil.hasCustomDecorators(this) && !PyUtil.isDecoratedAsAbstract(this) && getProperty() == null; final PyFunctionType type = new PyFunctionType(this); - if (getDecoratorList() != null) { + if (hasCustomDecorators) { return PyUnionType.createWeakType(type); } return type; |