/* * Copyright 2000-2014 JetBrains s.r.o. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.jetbrains.python.codeInsight; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.intellij.openapi.project.Project; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiPolyVariantReference; import com.intellij.psi.util.QualifiedName; import com.jetbrains.python.PyNames; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.impl.PyExpressionCodeFragmentImpl; import com.jetbrains.python.psi.impl.PyPsiUtils; import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.util.ArrayList; import java.util.Collections; import java.util.List; /** * @author vlan */ public class PyTypingTypeProvider extends PyTypeProviderBase { private static ImmutableMap BUILTIN_COLLECTIONS = ImmutableMap.builder() .put("typing.List", "list") .put("typing.Dict", "dict") .put("typing.Set", PyNames.SET) .put("typing.Tuple", PyNames.TUPLE) .build(); private static ImmutableSet GENERIC_CLASSES = ImmutableSet.builder() .add("typing.Generic") .add("typing.AbstractGeneric") .add("typing.Protocol") .build(); public PyType getParameterType(@NotNull PyNamedParameter param, @NotNull PyFunction func, @NotNull TypeEvalContext context) { final PyAnnotation annotation = param.getAnnotation(); if (annotation != null) { // XXX: Requires switching from stub to AST final PyExpression value = annotation.getValue(); if (value != null) { return getTypingType(value, context); } } return null; } @Nullable @Override public PyType getReturnType(@NotNull Callable callable, @NotNull TypeEvalContext context) { if (callable instanceof PyFunction) { final PyFunction function = (PyFunction)callable; final PyAnnotation annotation = function.getAnnotation(); if (annotation != null) { // XXX: Requires switching from stub to AST final PyExpression value = annotation.getValue(); if (value != null) { return getTypingType(value, context); } } final PyType constructorType = getGenericConstructorType(function, context); if (constructorType != null) { return constructorType; } } return null; } @Nullable private static PyType getGenericConstructorType(@NotNull PyFunction function, @NotNull TypeEvalContext context) { if (PyUtil.isInit(function)) { final PyClass cls = function.getContainingClass(); if (cls != null) { final List genericTypes = collectGenericTypes(cls, context); final PyType elementType; if (genericTypes.size() == 1) { elementType = genericTypes.get(0); } else if (genericTypes.size() > 1) { elementType = PyTupleType.create(cls, genericTypes.toArray(new PyType[genericTypes.size()])); } else { elementType = null; } if (elementType != null) { return new PyCollectionTypeImpl(cls, false, elementType); } } } return null; } @NotNull private static List collectGenericTypes(@NotNull PyClass cls, @NotNull TypeEvalContext context) { boolean isGeneric = false; for (PyClass ancestor : cls.getAncestorClasses(context)) { if (GENERIC_CLASSES.contains(ancestor.getQualifiedName())) { isGeneric = true; break; } } if (isGeneric) { final ArrayList results = new ArrayList(); // XXX: Requires switching from stub to AST for (PyExpression expr : cls.getSuperClassExpressions()) { if (expr instanceof PySubscriptionExpression) { final PyExpression indexExpr = ((PySubscriptionExpression)expr).getIndexExpression(); if (indexExpr != null) { final PyGenericType genericType = getGenericType(indexExpr, context); if (genericType != null) { results.add(genericType); } } } } return results; } return Collections.emptyList(); } @Nullable private static PyType getTypingType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { final PyType unionType = getUnionType(expression, context); if (unionType != null) { return unionType; } final PyType parameterizedType = getParameterizedType(expression, context); if (parameterizedType != null) { return parameterizedType; } final PyType builtinCollection = getBuiltinCollection(expression, context); if (builtinCollection != null) { return builtinCollection; } final PyType genericType = getGenericType(expression, context); if (genericType != null) { return genericType; } final PyType functionType = getFunctionType(expression, context); if (functionType != null) { return functionType; } final PyType stringBasedType = getStringBasedType(expression, context); if (stringBasedType != null) { return stringBasedType; } return null; } @Nullable private static PyType getStringBasedType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { if (expression instanceof PyStringLiteralExpression) { // XXX: Requires switching from stub to AST final String contents = ((PyStringLiteralExpression)expression).getStringValue(); final Project project = expression.getProject(); final PyExpressionCodeFragmentImpl codeFragment = new PyExpressionCodeFragmentImpl(project, "dummy.py", contents, false); codeFragment.setContext(expression.getContainingFile()); final PsiElement element = codeFragment.getFirstChild(); if (element instanceof PyExpressionStatement) { final PyExpression dummyExpr = ((PyExpressionStatement)element).getExpression(); return getType(dummyExpr, context); } } return null; } @Nullable private static PyType getFunctionType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { if (expression instanceof PySubscriptionExpression) { final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; final PyExpression operand = subscriptionExpr.getOperand(); final String operandName = resolveToQualifiedName(operand, context); if ("typing.Function".equals(operandName)) { final PyExpression indexExpr = subscriptionExpr.getIndexExpression(); if (indexExpr instanceof PyTupleExpression) { final PyTupleExpression tupleExpr = (PyTupleExpression)indexExpr; final PyExpression[] elements = tupleExpr.getElements(); if (elements.length == 2) { final PyExpression parametersExpr = elements[0]; if (parametersExpr instanceof PyListLiteralExpression) { final List parameters = new ArrayList(); final PyListLiteralExpression listExpr = (PyListLiteralExpression)parametersExpr; for (PyExpression argExpr : listExpr.getElements()) { parameters.add(new PyCallableParameterImpl(null, getType(argExpr, context))); } final PyExpression returnTypeExpr = elements[1]; final PyType returnType = getType(returnTypeExpr, context); return new PyCallableTypeImpl(parameters, returnType); } } } } } return null; } @Nullable private static PyType getUnionType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { if (expression instanceof PySubscriptionExpression) { final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; final PyExpression operand = subscriptionExpr.getOperand(); final String operandName = resolveToQualifiedName(operand, context); if ("typing.Union".equals(operandName)) { return PyUnionType.union(getIndexTypes(subscriptionExpr, context)); } } return null; } @Nullable private static PyGenericType getGenericType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { final PsiElement resolved = resolve(expression, context); if (resolved instanceof PyTargetExpression) { final PyTargetExpression targetExpr = (PyTargetExpression)resolved; final QualifiedName calleeName = targetExpr.getCalleeName(); if (calleeName != null && "typevar".equals(calleeName.toString())) { // XXX: Requires switching from stub to AST final PyExpression assigned = targetExpr.findAssignedValue(); if (assigned instanceof PyCallExpression) { final PyCallExpression assignedCall = (PyCallExpression)assigned; final PyExpression callee = assignedCall.getCallee(); if (callee != null) { final String calleeQName = resolveToQualifiedName(callee, context); if ("typing.typevar".equals(calleeQName)) { final PyExpression[] arguments = assignedCall.getArguments(); if (arguments.length > 0) { final PyExpression firstArgument = arguments[0]; if (firstArgument instanceof PyStringLiteralExpression) { final String name = ((PyStringLiteralExpression)firstArgument).getStringValue(); if (name != null) { return new PyGenericType(name, getGenericTypeBound(arguments, context)); } } } } } } } } return null; } @Nullable private static PyType getGenericTypeBound(@NotNull PyExpression[] typeVarArguments, @NotNull TypeEvalContext context) { final List types = new ArrayList(); if (typeVarArguments.length > 1) { final PyExpression secondArgument = typeVarArguments[1]; if (secondArgument instanceof PyKeywordArgument) { final PyKeywordArgument valuesArgument = (PyKeywordArgument)secondArgument; final PyExpression valueExpr = PyPsiUtils.flattenParens(valuesArgument.getValueExpression()); if (valueExpr instanceof PyTupleExpression) { final PyTupleExpression tupleExpr = (PyTupleExpression)valueExpr; for (PyExpression expr : tupleExpr.getElements()) { types.add(getType(expr, context)); } } } } return PyUnionType.union(types); } @NotNull private static List getIndexTypes(@NotNull PySubscriptionExpression expression, @NotNull TypeEvalContext context) { final List types = new ArrayList(); final PyExpression indexExpr = expression.getIndexExpression(); if (indexExpr instanceof PyTupleExpression) { final PyTupleExpression tupleExpr = (PyTupleExpression)indexExpr; for (PyExpression expr : tupleExpr.getElements()) { types.add(getType(expr, context)); } } return types; } @Nullable private static PyType getParameterizedType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { if (expression instanceof PySubscriptionExpression) { final PySubscriptionExpression subscriptionExpr = (PySubscriptionExpression)expression; final PyExpression operand = subscriptionExpr.getOperand(); final PyExpression indexExpr = subscriptionExpr.getIndexExpression(); final PyType operandType = getType(operand, context); if (operandType instanceof PyClassType) { final PyClass cls = ((PyClassType)operandType).getPyClass(); if (PyNames.TUPLE.equals(cls.getQualifiedName())) { final List indexTypes = getIndexTypes(subscriptionExpr, context); return PyTupleType.create(expression, indexTypes.toArray(new PyType[indexTypes.size()])); } else if (indexExpr != null) { final PyType indexType = context.getType(indexExpr); return new PyCollectionTypeImpl(cls, false, indexType); } } } return null; } @Nullable private static PyType getBuiltinCollection(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { final String collectionName = resolveToQualifiedName(expression, context); final String builtinName = BUILTIN_COLLECTIONS.get(collectionName); return builtinName != null ? PyTypeParser.getTypeByName(expression, builtinName) : null; } @Nullable private static PyType getType(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { // It is possible to replace PyAnnotation.getType() with this implementation final PyType typingType = getTypingType(expression, context); if (typingType != null) { return typingType; } final PyType type = context.getType(expression); if (type instanceof PyClassLikeType) { final PyClassLikeType classType = (PyClassLikeType)type; if (classType.isDefinition()) { return classType.toInstance(); } } else if (type instanceof PyNoneType) { return type; } return null; } @Nullable private static PsiElement resolve(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { if (expression instanceof PyReferenceOwner) { final PyReferenceOwner referenceOwner = (PyReferenceOwner)expression; final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context); final PsiPolyVariantReference reference = referenceOwner.getReference(resolveContext); final PsiElement element = reference.resolve(); if (element instanceof PyFunction) { final PyFunction function = (PyFunction)element; if (PyUtil.isInit(function)) { final PyClass cls = function.getContainingClass(); if (cls != null) { return cls; } } } return element; } return null; } @Nullable private static String resolveToQualifiedName(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { final PsiElement element = resolve(expression, context); if (element instanceof PyQualifiedNameOwner) { final PyQualifiedNameOwner qualifiedNameOwner = (PyQualifiedNameOwner)element; return qualifiedNameOwner.getQualifiedName(); } return null; } }