diff options
Diffstat (limited to 'python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java')
-rw-r--r-- | python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java | 375 |
1 files changed, 375 insertions, 0 deletions
diff --git a/python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java b/python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java new file mode 100644 index 000000000000..f03befac6c14 --- /dev/null +++ b/python/src/com/jetbrains/python/codeInsight/PyTypingTypeProvider.java @@ -0,0 +1,375 @@ +/* + * Copyright 2000-2014 JetBrains s.r.o. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.jetbrains.python.codeInsight; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.intellij.openapi.project.Project; +import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiPolyVariantReference; +import com.intellij.psi.util.QualifiedName; +import com.jetbrains.python.PyNames; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.impl.PyExpressionCodeFragmentImpl; +import com.jetbrains.python.psi.impl.PyPsiUtils; +import com.jetbrains.python.psi.resolve.PyResolveContext; +import com.jetbrains.python.psi.types.*; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * @author vlan + */ +public class PyTypingTypeProvider extends PyTypeProviderBase { + private static ImmutableMap<String, String> BUILTIN_COLLECTIONS = ImmutableMap.<String, String>builder() + .put("typing.List", "list") + .put("typing.Dict", "dict") + .put("typing.Set", PyNames.SET) + .put("typing.Tuple", PyNames.TUPLE) + .build(); + + private static ImmutableSet<String> GENERIC_CLASSES = ImmutableSet.<String>builder() + .add("typing.Generic") + .add("typing.AbstractGeneric") + .add("typing.Protocol") + .build(); + + public PyType getParameterType(@NotNull PyNamedParameter param, @NotNull PyFunction func, @NotNull TypeEvalContext context) { + final PyAnnotation annotation = param.getAnnotation(); + if (annotation != null) { + // XXX: Requires switching from stub to AST + final PyExpression value = annotation.getValue(); + if (value != null) { + return getTypingType(value, context); + } + } + return null; + } + + @Nullable + @Override + public PyType getReturnType(@NotNull Callable callable, @NotNull TypeEvalContext context) { + if (callable instanceof PyFunction) { + final PyFunction function = (PyFunction)callable; + final PyAnnotation annotation = function.getAnnotation(); + if (annotation != null) { + // XXX: Requires switching from stub to AST + final PyExpression value = annotation.getValue(); + if (value != null) { + return getTypingType(value, context); + } + } + final PyType constructorType = getGenericConstructorType(function, context); + if (constructorType != null) { + return constructorType; + } + } + return null; + } + + @Nullable + private static PyType getGenericConstructorType(@NotNull PyFunction function, @NotNull TypeEvalContext context) { + if (PyUtil.isInit(function)) { + final PyClass cls = function.getContainingClass(); + if (cls != null) { + final List<PyGenericType> genericTypes = collectGenericTypes(cls, context); + + final PyType elementType; + if (genericTypes.size() == 1) { + elementType = genericTypes.get(0); + } + else if (genericTypes.size() > 1) { + elementType = PyTupleType.create(cls, genericTypes.toArray(new PyType[genericTypes.size()])); + } + else { + elementType = null; + } + + if (elementType != null) { + return new PyCollectionTypeImpl(cls, false, elementType); + } + } + } + return null; + } + + @NotNull + private static List<PyGenericType> collectGenericTypes(@NotNull PyClass cls, @NotNull TypeEvalContext context) { + boolean isGeneric = false; + for (PyClass ancestor : cls.getAncestorClasses(context)) { + if (GENERIC_CLASSES.contains(ancestor.getQualifiedName())) { + isGeneric = true; + break; + } + } + if (isGeneric) { + final ArrayList<PyGenericType> results = new ArrayList<PyGenericType>(); + // XXX: Requires switching from stub to AST + for (PyExpression expr : cls.getSuperClassExpressions()) { + if (expr instanceof PySubscriptionExpression) { + final PyExpression indexExpr = ((PySubscriptionExpression)expr).getIndexExpression(); + if (indexExpr != null) { + final PyGenericType genericType = getGenericType(indexExpr, context); + if (genericType != null) { + results.add(genericType); + } + } + } + } + return results; + } + return Collections.emptyList(); + } + + @Nullable + private static PyType getTypingType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + final PyType unionType = getUnionType(expression, context); + if (unionType != null) { + return unionType; + } + final PyType parameterizedType = getParameterizedType(expression, context); + if (parameterizedType != null) { + return parameterizedType; + } + final PyType builtinCollection = getBuiltinCollection(expression, context); + if (builtinCollection != null) { + return builtinCollection; + } + final PyType genericType = getGenericType(expression, context); + if (genericType != null) { + return genericType; + } + final PyType functionType = getFunctionType(expression, context); + if (functionType != null) { + return functionType; + } + final PyType stringBasedType = getStringBasedType(expression, context); + if (stringBasedType != null) { + return stringBasedType; + } + return null; + } + + @Nullable + private static PyType getStringBasedType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + if (expression instanceof PyStringLiteralExpression) { + // XXX: Requires switching from stub to AST + final String contents = ((PyStringLiteralExpression)expression).getStringValue(); + final Project project = expression.getProject(); + final PyExpressionCodeFragmentImpl codeFragment = new PyExpressionCodeFragmentImpl(project, "dummy.py", contents, false); + codeFragment.setContext(expression.getContainingFile()); + final PsiElement element = codeFragment.getFirstChild(); + if (element instanceof PyExpressionStatement) { + final PyExpression dummyExpr = ((PyExpressionStatement)element).getExpression(); + return getType(dummyExpr, context); + } + } + return null; + } + + @Nullable + private static PyType getFunctionType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + if (expression instanceof PySubscriptionExpression) { + final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; + final PyExpression operand = subscriptionExpr.getOperand(); + final String operandName = resolveToQualifiedName(operand, context); + if ("typing.Function".equals(operandName)) { + final PyExpression indexExpr = subscriptionExpr.getIndexExpression(); + if (indexExpr instanceof PyTupleExpression) { + final PyTupleExpression tupleExpr = (PyTupleExpression)indexExpr; + final PyExpression[] elements = tupleExpr.getElements(); + if (elements.length == 2) { + final PyExpression parametersExpr = elements[0]; + if (parametersExpr instanceof PyListLiteralExpression) { + final List<PyCallableParameter> parameters = new ArrayList<PyCallableParameter>(); + final PyListLiteralExpression listExpr = (PyListLiteralExpression)parametersExpr; + for (PyExpression argExpr : listExpr.getElements()) { + parameters.add(new PyCallableParameterImpl(null, getType(argExpr, context))); + } + final PyExpression returnTypeExpr = elements[1]; + final PyType returnType = getType(returnTypeExpr, context); + return new PyCallableTypeImpl(parameters, returnType); + } + } + } + } + } + return null; + } + + @Nullable + private static PyType getUnionType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + if (expression instanceof PySubscriptionExpression) { + final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; + final PyExpression operand = subscriptionExpr.getOperand(); + final String operandName = resolveToQualifiedName(operand, context); + if ("typing.Union".equals(operandName)) { + return PyUnionType.union(getIndexTypes(subscriptionExpr, context)); + } + } + return null; + } + + @Nullable + private static PyGenericType getGenericType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + final PsiElement resolved = resolve(expression, context); + if (resolved instanceof PyTargetExpression) { + final PyTargetExpression targetExpr = (PyTargetExpression)resolved; + final QualifiedName calleeName = targetExpr.getCalleeName(); + if (calleeName != null && "typevar".equals(calleeName.toString())) { + // XXX: Requires switching from stub to AST + final PyExpression assigned = targetExpr.findAssignedValue(); + if (assigned instanceof PyCallExpression) { + final PyCallExpression assignedCall = (PyCallExpression)assigned; + final PyExpression callee = assignedCall.getCallee(); + if (callee != null) { + final String calleeQName = resolveToQualifiedName(callee, context); + if ("typing.typevar".equals(calleeQName)) { + final PyExpression[] arguments = assignedCall.getArguments(); + if (arguments.length > 0) { + final PyExpression firstArgument = arguments[0]; + if (firstArgument instanceof PyStringLiteralExpression) { + final String name = ((PyStringLiteralExpression)firstArgument).getStringValue(); + if (name != null) { + return new PyGenericType(name, getGenericTypeBound(arguments, context)); + } + } + } + } + + } + } + } + } + return null; + } + + @Nullable + private static PyType getGenericTypeBound(@NotNull PyExpression[] typeVarArguments, @NotNull TypeEvalContext context) { + final List<PyType> types = new ArrayList<PyType>(); + if (typeVarArguments.length > 1) { + final PyExpression secondArgument = typeVarArguments[1]; + if (secondArgument instanceof PyKeywordArgument) { + final PyKeywordArgument valuesArgument = (PyKeywordArgument)secondArgument; + final PyExpression valueExpr = PyPsiUtils.flattenParens(valuesArgument.getValueExpression()); + if (valueExpr instanceof PyTupleExpression) { + final PyTupleExpression tupleExpr = (PyTupleExpression)valueExpr; + for (PyExpression expr : tupleExpr.getElements()) { + types.add(getType(expr, context)); + } + } + } + } + return PyUnionType.union(types); + } + + @NotNull + private static List<PyType> getIndexTypes(@NotNull PySubscriptionExpression expression, @NotNull TypeEvalContext context) { + final List<PyType> types = new ArrayList<PyType>(); + final PyExpression indexExpr = expression.getIndexExpression(); + if (indexExpr instanceof PyTupleExpression) { + final PyTupleExpression tupleExpr = (PyTupleExpression)indexExpr; + for (PyExpression expr : tupleExpr.getElements()) { + types.add(getType(expr, context)); + } + } + return types; + } + + @Nullable + private static PyType getParameterizedType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + if (expression instanceof PySubscriptionExpression) { + final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; + final PyExpression operand = subscriptionExpr.getOperand(); + final PyExpression indexExpr = subscriptionExpr.getIndexExpression(); + final PyType operandType = getType(operand, context); + if (operandType instanceof PyClassType) { + final PyClass cls = ((PyClassType)operandType).getPyClass(); + if (PyNames.TUPLE.equals(cls.getQualifiedName())) { + final List<PyType> indexTypes = getIndexTypes(subscriptionExpr, context); + return PyTupleType.create(expression, indexTypes.toArray(new PyType[indexTypes.size()])); + } + else if (indexExpr != null) { + final PyType indexType = context.getType(indexExpr); + return new PyCollectionTypeImpl(cls, false, indexType); + } + } + } + return null; + } + + @Nullable + private static PyType getBuiltinCollection(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + final String collectionName = resolveToQualifiedName(expression, context); + final String builtinName = BUILTIN_COLLECTIONS.get(collectionName); + return builtinName != null ? PyTypeParser.getTypeByName(expression, builtinName) : null; + } + + @Nullable + private static PyType getType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + // It is possible to replace PyAnnotation.getType() with this implementation + final PyType typingType = getTypingType(expression, context); + if (typingType != null) { + return typingType; + } + final PyType type = context.getType(expression); + if (type instanceof PyClassLikeType) { + final PyClassLikeType classType = (PyClassLikeType)type; + if (classType.isDefinition()) { + return classType.toInstance(); + } + } + else if (type instanceof PyNoneType) { + return type; + } + return null; + } + + @Nullable + private static PsiElement resolve(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + if (expression instanceof PyReferenceOwner) { + final PyReferenceOwner referenceOwner = (PyReferenceOwner)expression; + final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); + final PsiPolyVariantReference reference = referenceOwner.getReference(resolveContext); + final PsiElement element = reference.resolve(); + if (element instanceof PyFunction) { + final PyFunction function = (PyFunction)element; + if (PyUtil.isInit(function)) { + final PyClass cls = function.getContainingClass(); + if (cls != null) { + return cls; + } + } + } + return element; + } + return null; + } + + @Nullable + private static String resolveToQualifiedName(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { + final PsiElement element = resolve(expression, context); + if (element instanceof PyQualifiedNameOwner) { + final PyQualifiedNameOwner qualifiedNameOwner = (PyQualifiedNameOwner)element; + return qualifiedNameOwner.getQualifiedName(); + } + return null; + } +} |