summaryrefslogtreecommitdiff
path: root/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java')
-rw-r--r--python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java127
1 files changed, 66 insertions, 61 deletions
diff --git a/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java b/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java
index 2ffc8e6cba03..acc225b67f22 100644
--- a/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java
+++ b/python/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java
@@ -15,7 +15,6 @@
*/
package com.jetbrains.python.psi.impl;
-import com.google.common.collect.Maps;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.extensions.Extensions;
import com.intellij.openapi.util.Pair;
@@ -176,42 +175,74 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp
@Nullable
@Override
- public PyType getReturnType(@NotNull TypeEvalContext context, @Nullable PyQualifiedExpression callSite) {
- PyType type = getGenericReturnType(context, callSite);
- if (callSite == null) {
- return type;
+ public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
+ for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
+ final PyType returnType = typeProvider.getReturnType(this, context);
+ if (returnType != null) {
+ returnType.assertValid(typeProvider.toString());
+ return returnType;
+ }
}
- final PyTypeChecker.AnalyzeCallResults results = PyTypeChecker.analyzeCallSite(callSite, context);
- if (PyTypeChecker.hasGenerics(type, context)) {
- if (results != null) {
- final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(this, results.getReceiver(), results.getArguments(),
- context);
- type = substitutions != null ? PyTypeChecker.substitute(type, substitutions, context) : null;
+ if (context.maySwitchToAST(this) && LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) {
+ PyAnnotation anno = getAnnotation();
+ if (anno != null) {
+ PyClass pyClass = anno.resolveToClass();
+ if (pyClass != null) {
+ return new PyClassTypeImpl(pyClass, false);
+ }
}
- else {
- type = null;
+ }
+ final PyType docStringType = getReturnTypeFromDocString();
+ if (docStringType != null) {
+ docStringType.assertValid("from docstring");
+ return docStringType;
+ }
+ if (context.allowReturnTypes(this)) {
+ final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context);
+ if (yieldTypeRef != null) {
+ return yieldTypeRef.get();
}
+ return getReturnStatementType(context);
}
- if (results != null) {
- type = replaceSelf(type, results.getReceiver(), context);
+ return null;
+ }
+
+ @Nullable
+ @Override
+ public PyType getCallType(@NotNull TypeEvalContext context, @NotNull PyQualifiedExpression callSite) {
+ PyType type = null;
+ for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
+ type = typeProvider.getCallType(this, callSite, context);
+ if (type != null) {
+ type.assertValid(typeProvider.toString());
+ break;
+ }
}
- if (results != null && isDynamicallyEvaluated(results.getArguments().values(), context)) {
- return PyUnionType.createWeakType(type);
+ if (type == null) {
+ type = context.getReturnType(this);
+ }
+ final PyTypeChecker.AnalyzeCallResults results = PyTypeChecker.analyzeCallSite(callSite, context);
+ if (results != null) {
+ return analyzeCallType(type, results.getReceiver(), results.getArguments(), context);
}
return type;
}
@Nullable
- /**
- * Suits when there is no call site(e.g. implicit __iter__ call in statement for)
- */
- public PyType getReturnTypeWithoutCallSite(@NotNull TypeEvalContext context,
- @Nullable PyExpression receiver) {
- PyType type = getGenericReturnType(context, null);
+ @Override
+ public PyType getCallType(@Nullable PyExpression receiver,
+ @NotNull Map<PyExpression, PyNamedParameter> parameters,
+ @NotNull TypeEvalContext context) {
+ return analyzeCallType(context.getReturnType(this), receiver, parameters, context);
+ }
+
+ @Nullable
+ private PyType analyzeCallType(@Nullable PyType type,
+ @Nullable PyExpression receiver,
+ @NotNull Map<PyExpression, PyNamedParameter> parameters,
+ @NotNull TypeEvalContext context) {
if (PyTypeChecker.hasGenerics(type, context)) {
- final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(this, receiver,
- Maps.<PyExpression, PyNamedParameter>newHashMap(),
- context);
+ final Map<PyGenericType, PyType> substitutions = PyTypeChecker.unifyGenericCall(receiver, parameters, context);
if (substitutions != null) {
type = PyTypeChecker.substitute(type, substitutions, context);
}
@@ -219,7 +250,13 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp
type = null;
}
}
- return replaceSelf(type, receiver, context);
+ if (receiver != null) {
+ type = replaceSelf(type, receiver, context);
+ }
+ if (type != null && isDynamicallyEvaluated(parameters.values(), context)) {
+ type = PyUnionType.createWeakType(type);
+ }
+ return type;
}
@Nullable
@@ -250,39 +287,6 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp
}
@Nullable
- private PyType getGenericReturnType(@NotNull TypeEvalContext typeEvalContext, @Nullable PyQualifiedExpression callSite) {
- if (typeEvalContext.maySwitchToAST(this) && LanguageLevel.forElement(this).isAtLeast(LanguageLevel.PYTHON30)) {
- PyAnnotation anno = getAnnotation();
- if (anno != null) {
- PyClass pyClass = anno.resolveToClass();
- if (pyClass != null) {
- return new PyClassTypeImpl(pyClass, false);
- }
- }
- }
- for (PyTypeProvider typeProvider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
- final PyType returnType = typeProvider.getReturnType(this, callSite, typeEvalContext);
- if (returnType != null) {
- returnType.assertValid(typeProvider.toString());
- return returnType;
- }
- }
- final PyType docStringType = getReturnTypeFromDocString();
- if (docStringType != null) {
- docStringType.assertValid("from docstring");
- return docStringType;
- }
- if (typeEvalContext.allowReturnTypes(this)) {
- final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(typeEvalContext);
- if (yieldTypeRef != null) {
- return yieldTypeRef.get();
- }
- return getReturnStatementType(typeEvalContext);
- }
- return null;
- }
-
- @Nullable
private Ref<? extends PyType> getYieldStatementType(@NotNull final TypeEvalContext context) {
Ref<PyType> elementType = null;
final PyBuiltinCache cache = PyBuiltinCache.getInstance(this);
@@ -386,8 +390,9 @@ public class PyFunctionImpl extends PyPresentableElementImpl<PyFunctionStub> imp
return type;
}
}
+ final boolean hasCustomDecorators = PyUtil.hasCustomDecorators(this) && !PyUtil.isDecoratedAsAbstract(this) && getProperty() == null;
final PyFunctionType type = new PyFunctionType(this);
- if (getDecoratorList() != null) {
+ if (hasCustomDecorators) {
return PyUnionType.createWeakType(type);
}
return type;