summaryrefslogtreecommitdiff
path: root/python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java
diff options
context:
space:
mode:
Diffstat (limited to 'python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java')
-rw-r--r--python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java39
1 files changed, 23 insertions, 16 deletions
diff --git a/python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java b/python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java
index 31723d3aab4e..db47f17f8ea8 100644
--- a/python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java
+++ b/python/src/com/jetbrains/python/inspections/quickfix/PyDefaultArgumentQuickFix.java
@@ -61,29 +61,36 @@ public class PyDefaultArgumentQuickFix implements LocalQuickFix {
PyStatementList list = function.getStatementList();
PyParameterList paramList = function.getParameterList();
- StringBuilder str = new StringBuilder("def foo(");
+ final StringBuilder functionText = new StringBuilder("def foo(");
int size = paramList.getParameters().length;
for (int i = 0; i != size; ++i) {
PyParameter p = paramList.getParameters()[i];
if (p == param)
- str.append(defName).append("=None");
+ functionText.append(defName).append("=None");
else
- str.append(p.getText());
+ functionText.append(p.getText());
if (i != size-1)
- str.append(", ");
+ functionText.append(", ");
+ }
+ functionText.append("):\n\tif not ").append(defName).append(":\n\t\t").append(defName).append(" = ").append(defaultValue.getText());
+ final PyStatement[] statements = list.getStatements();
+ PyStatement firstStatement = statements.length > 0 ? statements[0] : null;
+ PyFunction newFunction = elementGenerator.createFromText(LanguageLevel.forElement(function), PyFunction.class,
+ functionText.toString());
+ if (firstStatement == null) {
+ function.replace(newFunction);
+ }
+ else {
+ final PyStatement ifStatement = newFunction.getStatementList().getStatements()[0];
+ PyStringLiteralExpression docString = function.getDocStringExpression();
+ if (docString != null)
+ list.addAfter(ifStatement, firstStatement);
+ else {
+ list.addBefore(ifStatement, firstStatement);
+ }
+ paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
+ PyFunction.class, functionText.toString()).getParameterList());
}
- str.append("):\n\tpass");
- PyIfStatement ifStatement = elementGenerator.createFromText(LanguageLevel.forElement(function), PyIfStatement.class,
- "if not " + defName + ": " + defName + " = " + defaultValue.getText());
-
- PyStatement firstStatement = list.getStatements()[0];
- PyStringLiteralExpression docString = function.getDocStringExpression();
- if (docString != null)
- list.addAfter(ifStatement, firstStatement);
- else
- list.addBefore(ifStatement, firstStatement);
- paramList.replace(elementGenerator.createFromText(LanguageLevel.forElement(defaultValue),
- PyFunction.class, str.toString()).getParameterList());
}
}
}