summaryrefslogtreecommitdiff
path: root/python/src/com/jetbrains/python/codeInsight/completion/PyDictKeyNamesCompletionContributor.java
blob: d7b1598f39a6cf100fd08c75ef926b1c392684d9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
/*
 * Copyright 2000-2013 JetBrains s.r.o.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.jetbrains.python.codeInsight.completion;

import com.intellij.codeInsight.completion.*;
import com.intellij.codeInsight.lookup.LookupElement;
import com.intellij.codeInsight.lookup.LookupElementBuilder;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.editor.Document;
import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiReference;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.PlatformIcons;
import com.intellij.util.ProcessingContext;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;

import static com.intellij.patterns.PlatformPatterns.psiElement;

/**
 * User: catherine
 *
 * Completes known keys for dictionaries inside subscription expressions ({@code d[<caret>]}).
 * <p>
 * Keys are collected from:
 * <ul>
 *   <li>the dict literal assigned to the subscripted name ({@code d = {'a': 1}}),</li>
 *   <li>keyword arguments of a builtin {@code dict(...)} call assigned to it ({@code d = dict(a=1)}),</li>
 *   <li>subscription assignments elsewhere in the file ({@code d['b'] = b}).</li>
 * </ul>
 */
public class PyDictKeyNamesCompletionContributor extends CompletionContributor {
  public PyDictKeyNamesCompletionContributor() {
    extend(
      CompletionType.BASIC,
      psiElement().inside(PySubscriptionExpression.class),
      new CompletionProvider<CompletionParameters>() {
        @Override
        protected void addCompletions(
          @NotNull final CompletionParameters parameters, final ProcessingContext context, @NotNull final CompletionResultSet result) {
          final PsiElement original = parameters.getOriginalPosition();
          if (original == null) return;

          final PySubscriptionExpression subscription = PsiTreeUtil.getParentOfType(original, PySubscriptionExpression.class);
          if (subscription == null) return;
          final PsiElement operand = subscription.getOperand();
          if (operand == null) return;
          final PsiReference reference = operand.getReference();
          if (reference == null) return;
          final PsiElement resolvedElement = reference.resolve();
          if (!(resolvedElement instanceof PyTargetExpression)) return;

          final CompletionResultSet dictCompletion = createResult(original, result, parameters.getOffset());

          // The value assigned to the target is looked up among the target's following siblings.
          boolean definitionFound = false;
          final PyDictLiteralExpression dict = PsiTreeUtil.getNextSiblingOfType(resolvedElement, PyDictLiteralExpression.class);
          if (dict != null) {
            addDictLiteralKeys(dict, dictCompletion);
            definitionFound = true;
          }
          final PyCallExpression dictConstructor = PsiTreeUtil.getNextSiblingOfType(resolvedElement, PyCallExpression.class);
          if (dictConstructor != null) {
            addDictConstructorKeys(dictConstructor, dictCompletion);
            definitionFound = true;
          }
          // Collect keys from `operand[key] = ...` assignments only once, so they are not reported twice.
          if (definitionFound) {
            addAdditionalKeys(parameters.getOriginalFile(), operand, dictCompletion);
          }
        }
      }
    );
  }

  /**
   * Creates the completion result set, narrowing its prefix matcher when the caret is inside or right after a
   * string literal, or right after a numeric literal, so that the already-typed key text is used as the prefix.
   *
   * @param original is the element at the caret in the original file
   * @param result   is the initial completion result
   * @param offset   is the caret offset at which completion was invoked
   * @return {@code result} with an adjusted prefix matcher, or {@code result} itself if no literal precedes the caret
   */
  private static CompletionResultSet createResult(@NotNull final PsiElement original, @NotNull final CompletionResultSet result, final int offset) {
    final PyStringLiteralExpression prevElement = PsiTreeUtil.getPrevSiblingOfType(original, PyStringLiteralExpression.class);
    // NOTE(review): this branch used to also require the string's node type to differ from LBRACKET. A
    // PyStringLiteralExpression node can never have that token type, so the check was always true and was removed.
    if (prevElement != null) {
      return result.withPrefixMatcher(findPrefix(prevElement, offset));
    }
    final PsiElement parentElement = original.getParent();
    if (parentElement instanceof PyStringLiteralExpression) {
      return result.withPrefixMatcher(findPrefix((PyElement)parentElement, offset));
    }
    final PyNumericLiteralExpression number = PsiTreeUtil.findElementOfClassAtOffset(original.getContainingFile(),
                                                                                     offset - 1, PyNumericLiteralExpression.class, false);
    if (number != null) {
      return result.withPrefixMatcher(findPrefix(number, offset));
    }
    return result;
  }

  /**
   * Finds the prefix: the file text from the start of {@code element} up to {@code offset}.
   * For *'str'* returns just *'str*.
   *
   * @param element to find prefix of
   * @param offset  end (exclusive) of the prefix within the containing file
   * @return prefix
   */
  private static String findPrefix(final PyElement element, final int offset) {
    return TextRange.create(element.getTextRange().getStartOffset(), offset).substring(element.getContainingFile().getText());
  }

  /**
   * Adds keys to the completion result from a {@code dict(...)} constructor call.
   * <p>
   * Only calls whose callee text is exactly {@code dict} and whose evaluated type is builtin are considered.
   * Each keyword argument {@code name=value} contributes the single-quoted key {@code 'name'}.
   */
  private static void addDictConstructorKeys(final PyCallExpression dictConstructor, final CompletionResultSet result) {
    final PyExpression callee = dictConstructor.getCallee();
    if (callee == null) return;
    if (!"dict".equals(callee.getText())) return;

    final TypeEvalContext context = TypeEvalContext.userInitiated(callee.getContainingFile());
    final PyType type = context.getType(dictConstructor);
    if (type == null || !type.isBuiltin()) return;

    final PyArgumentList list = dictConstructor.getArgumentList();
    if (list == null) return;
    for (final PyExpression argument : list.getArguments()) {
      if (argument instanceof PyKeywordArgument) {
        final String keyword = ((PyKeywordArgument)argument).getKeyword();
        // A malformed keyword argument may lack a name; don't offer a bogus 'null' key for it.
        if (keyword != null) {
          result.addElement(createElement("'" + keyword + "'"));
        }
      }
    }
  }

  /**
   * Adds keys from assignment statements.
   * For instance, dictionary['b']=b
   * <p>
   * Subscriptions are matched to {@code operand} by comparing their operand text.
   *
   * @param file    to get additional keys
   * @param operand is operand of origin element
   * @param result  is completion result set
   */
  private static void addAdditionalKeys(final PsiFile file, final PsiElement operand, final CompletionResultSet result) {
    final String operandText = operand.getText();
    final PySubscriptionExpression[] subscriptionExpressions = PyUtil.getAllChildrenOfType(file, PySubscriptionExpression.class);
    for (PySubscriptionExpression expr : subscriptionExpressions) {
      final PyExpression exprOperand = expr.getOperand();
      if (exprOperand == null || !operandText.equals(exprOperand.getText())) continue;

      // Only subscriptions that are the left-hand side of an assignment define a key.
      final PsiElement parent = expr.getParent();
      if (!(parent instanceof PyAssignmentStatement)) continue;
      if (!expr.equals(((PyAssignmentStatement)parent).getLeftHandSideExpression())) continue;

      final PyExpression key = expr.getIndexExpression();
      if (key == null) continue;
      final TextRange keyRange = key.getTextRange();
      final boolean addHandler = PsiTreeUtil.findElementOfClassAtRange(file, keyRange.getStartOffset(), keyRange.getEndOffset(),
                                                                      PyStringLiteralExpression.class) != null;
      result.addElement(createElement(key.getText(), addHandler));
    }
  }

  /**
   * Adds keys from a dict literal expression.
   * <p>
   * The insert handler is attached only when a string literal starts at the beginning of the key-value pair.
   */
  private static void addDictLiteralKeys(final PyDictLiteralExpression dict, final CompletionResultSet result) {
    final PsiFile file = dict.getContainingFile();
    for (PyKeyValueExpression expression : dict.getElements()) {
      final TextRange range = expression.getTextRange();
      final boolean addHandler = PsiTreeUtil.findElementOfClassAtRange(file, range.getStartOffset(), range.getEndOffset(),
                                                                      PyStringLiteralExpression.class) != null;
      result.addElement(createElement(expression.getKey().getText(), addHandler));
    }
  }

  /**
   * Creates a "dict key" lookup element with the quote-fixing insert handler attached.
   */
  private static LookupElementBuilder createElement(final String key) {
    return createElement(key, true);
  }

  /**
   * Creates a "dict key" lookup element.
   *
   * @param key        text inserted on completion (including quotes for string keys)
   * @param addHandler whether to attach the insert handler that trims the last inserted character when the
   *                   resulting string literal is not cleanly quoted or is not followed by {@code ]}
   * @return the lookup element
   */
  private static LookupElementBuilder createElement(final String key, final boolean addHandler) {
    LookupElementBuilder item = LookupElementBuilder
      .create(key)
      .withTypeText("dict key")
      .withIcon(PlatformIcons.PARAMETER_ICON);

    if (addHandler) {
      item = item.withInsertHandler(new InsertHandler<LookupElement>() {
        @Override
        public void handleInsert(final InsertionContext context, final LookupElement lookupElement) {
          final PyStringLiteralExpression str = PsiTreeUtil.findElementOfClassAtOffset(context.getFile(), context.getStartOffset(),
                                                                                       PyStringLiteralExpression.class, false);
          if (str == null) return;
          if (PsiTreeUtil.getParentOfType(str, PySubscriptionExpression.class) == null) return;

          final int off = context.getStartOffset() + str.getTextLength();
          final PsiElement element = context.getFile().findElementAt(off);
          final boolean atRBracket = element == null || element.getNode().getElementType() == PyTokenTypes.RBRACKET;

          final String text = str.getText();
          final boolean singleQuoted = StringUtil.startsWithChar(text, '\'') && StringUtil.endsWithChar(text, '\'');
          final boolean doubleQuoted = StringUtil.startsWithChar(text, '"') && StringUtil.endsWithChar(text, '"');

          // Presumably the last inserted character is a surplus closing quote in these cases; drop it.
          if (!(singleQuoted || doubleQuoted) || !atRBracket) {
            final Document document = context.getEditor().getDocument();
            final int tail = context.getTailOffset();
            document.deleteString(tail - 1, tail);
          }
        }
      });
    }
    return item;
  }

}