aboutsummaryrefslogtreecommitdiff
path: root/refactor
diff options
context:
space:
mode:
Diffstat (limited to 'refactor')
-rw-r--r--refactor/eg/eg.go32
-rw-r--r--refactor/eg/eg_test.go6
-rw-r--r--refactor/eg/rewrite.go152
-rw-r--r--refactor/eg/testdata/I.template14
-rw-r--r--refactor/eg/testdata/I1.go9
-rw-r--r--refactor/eg/testdata/I1.golden14
-rw-r--r--refactor/eg/testdata/J.template11
-rw-r--r--refactor/eg/testdata/J1.go10
-rw-r--r--refactor/eg/testdata/J1.golden11
-rw-r--r--refactor/eg/testdata/no_after_return.template2
10 files changed, 211 insertions, 50 deletions
diff --git a/refactor/eg/eg.go b/refactor/eg/eg.go
index bf0ee2909..02c6dc951 100644
--- a/refactor/eg/eg.go
+++ b/refactor/eg/eg.go
@@ -145,6 +145,7 @@ type Transformer struct {
env map[string]ast.Expr // maps parameter name to wildcard binding
importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after().
before, after ast.Expr
+ afterStmts []ast.Stmt
allowWildcards bool
// Working state of Transform():
@@ -198,7 +199,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
if err != nil {
return nil, fmt.Errorf("before: %s", err)
}
- after, err := soleExpr(afterDecl)
+ afterStmts, after, err := stmtAndExpr(afterDecl)
if err != nil {
return nil, fmt.Errorf("after: %s", err)
}
@@ -242,6 +243,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
importedObjs: make(map[types.Object]*ast.SelectorExpr),
before: before,
after: after,
+ afterStmts: afterStmts,
}
// Combine type info from the template and input packages, and
@@ -279,6 +281,7 @@ func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) {
if err != nil {
return err
}
+
defer func() {
if err2 := fh.Close(); err != nil {
err = err2 // prefer earlier error
@@ -319,6 +322,33 @@ func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) {
return nil, fmt.Errorf("must contain a single return or expression statement")
}
+// stmtAndExpr returns the expression in the last return statement as well as the preceeding lines.
+func stmtAndExpr(fn *ast.FuncDecl) ([]ast.Stmt, ast.Expr, error) {
+ if fn.Body == nil {
+ return nil, nil, fmt.Errorf("no body")
+ }
+
+ n := len(fn.Body.List)
+ if n == 0 {
+ return nil, nil, fmt.Errorf("must contain at least one statement")
+ }
+
+ stmts, last := fn.Body.List[:n-1], fn.Body.List[n-1]
+
+ switch last := last.(type) {
+ case *ast.ReturnStmt:
+ if len(last.Results) != 1 {
+ return nil, nil, fmt.Errorf("return statement must have a single operand")
+ }
+ return stmts, last.Results[0], nil
+
+ case *ast.ExprStmt:
+ return stmts, last.X, nil
+ }
+
+ return nil, nil, fmt.Errorf("must end with a single return or expression statement")
+}
+
// mergeTypeInfo adds type info from src to dst.
func mergeTypeInfo(dst, src *types.Info) {
for k, v := range src.Types {
diff --git a/refactor/eg/eg_test.go b/refactor/eg/eg_test.go
index 896bc9b08..89b7d08fb 100644
--- a/refactor/eg/eg_test.go
+++ b/refactor/eg/eg_test.go
@@ -78,6 +78,12 @@ func Test(t *testing.T) {
"testdata/H.template",
"testdata/H1.go",
+ "testdata/I.template",
+ "testdata/I1.go",
+
+ "testdata/J.template",
+ "testdata/J1.go",
+
"testdata/bad_type.template",
"testdata/no_before.template",
"testdata/no_after_return.template",
diff --git a/refactor/eg/rewrite.go b/refactor/eg/rewrite.go
index dd28aa578..a1281582b 100644
--- a/refactor/eg/rewrite.go
+++ b/refactor/eg/rewrite.go
@@ -22,6 +22,52 @@ import (
"golang.org/x/tools/go/ast/astutil"
)
+// transformItem takes a reflect.Value representing a variable of type ast.Node
+// transforms its child elements recursively with apply, and then transforms the
+// actual element if it contains an expression.
+func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
+ // don't bother if val is invalid to start with
+ if !rv.IsValid() {
+ return reflect.Value{}, false, nil
+ }
+
+ rv, changed, newEnv := tr.apply(tr.transformItem, rv)
+
+ e := rvToExpr(rv)
+ if e == nil {
+ return rv, changed, newEnv
+ }
+
+ savedEnv := tr.env
+ tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
+
+ if tr.matchExpr(tr.before, e) {
+ if tr.verbose {
+ fmt.Fprintf(os.Stderr, "%s matches %s",
+ astString(tr.fset, tr.before), astString(tr.fset, e))
+ if len(tr.env) > 0 {
+ fmt.Fprintf(os.Stderr, " with:")
+ for name, ast := range tr.env {
+ fmt.Fprintf(os.Stderr, " %s->%s",
+ name, astString(tr.fset, ast))
+ }
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+ }
+ tr.nsubsts++
+
+ // Clone the replacement tree, performing parameter substitution.
+ // We update all positions to n.Pos() to aid comment placement.
+ rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
+ reflect.ValueOf(e.Pos()))
+ changed = true
+ newEnv = tr.env
+ }
+ tr.env = savedEnv
+
+ return rv, changed, newEnv
+}
+
// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
@@ -43,48 +89,14 @@ func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast
if tr.verbose {
fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
+ fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
}
- var f func(rv reflect.Value) reflect.Value
- f = func(rv reflect.Value) reflect.Value {
- // don't bother if val is invalid to start with
- if !rv.IsValid() {
- return reflect.Value{}
- }
-
- rv = apply(f, rv)
-
- e := rvToExpr(rv)
- if e != nil {
- savedEnv := tr.env
- tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
-
- if tr.matchExpr(tr.before, e) {
- if tr.verbose {
- fmt.Fprintf(os.Stderr, "%s matches %s",
- astString(tr.fset, tr.before), astString(tr.fset, e))
- if len(tr.env) > 0 {
- fmt.Fprintf(os.Stderr, " with:")
- for name, ast := range tr.env {
- fmt.Fprintf(os.Stderr, " %s->%s",
- name, astString(tr.fset, ast))
- }
- }
- fmt.Fprintf(os.Stderr, "\n")
- }
- tr.nsubsts++
-
- // Clone the replacement tree, performing parameter substitution.
- // We update all positions to n.Pos() to aid comment placement.
- rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
- reflect.ValueOf(e.Pos()))
- }
- tr.env = savedEnv
- }
-
- return rv
+ o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
+ if changed {
+ panic("BUG")
}
- file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)
+ file2 := o.Interface().(*ast.File)
// By construction, the root node is unchanged.
if file != file2 {
@@ -150,45 +162,91 @@ var (
identType = reflect.TypeOf((*ast.Ident)(nil))
selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
objectPtrType = reflect.TypeOf((*ast.Object)(nil))
+ statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
positionType = reflect.TypeOf(token.NoPos)
scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
-func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
+// f takes a reflect.Value representing the variable to modify of type ast.Node.
+// It returns a reflect.Value containing the transformed value of type ast.Node,
+// whether any change was made, and a map of identifiers to ast.Expr (so we can
+// do contextually correct substitutions in the parent statements).
+func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
if !val.IsValid() {
- return reflect.Value{}
+ return reflect.Value{}, false, nil
}
// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if val.Type() == objectPtrType {
- return objectPtrNil
+ return objectPtrNil, false, nil
}
// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if val.Type() == scopePtrType {
- return scopePtrNil
+ return scopePtrNil, false, nil
}
switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
+ // no possible rewriting of statements.
+ if v.Type().Elem() != statementType {
+ changed := false
+ var envp map[string]ast.Expr
+ for i := 0; i < v.Len(); i++ {
+ e := v.Index(i)
+ o, localchanged, env := f(e)
+ if localchanged {
+ changed = true
+ // we clobber envp here,
+ // which means if we have two sucessive
+ // replacements inside the same statement
+ // we will only generate the setup for one of them.
+ envp = env
+ }
+ setValue(e, o)
+ }
+ return val, changed, envp
+ }
+
+ // statements are rewritten.
+ var out []ast.Stmt
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
- setValue(e, f(e))
+ o, changed, env := f(e)
+ if changed {
+ for _, s := range tr.afterStmts {
+ t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
+ out = append(out, t.(ast.Stmt))
+ }
+ }
+ setValue(e, o)
+ out = append(out, e.Interface().(ast.Stmt))
}
+ return reflect.ValueOf(out), false, nil
case reflect.Struct:
+ changed := false
+ var envp map[string]ast.Expr
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
- setValue(e, f(e))
+ o, localchanged, env := f(e)
+ if localchanged {
+ changed = true
+ envp = env
+ }
+ setValue(e, o)
}
+ return val, changed, envp
case reflect.Interface:
e := v.Elem()
- setValue(v, f(e))
+ o, changed, env := f(e)
+ setValue(v, o)
+ return val, changed, env
}
- return val
+ return val, false, nil
}
// subst returns a copy of (replacement) pattern with values from env
diff --git a/refactor/eg/testdata/I.template b/refactor/eg/testdata/I.template
new file mode 100644
index 000000000..d03e77482
--- /dev/null
+++ b/refactor/eg/testdata/I.template
@@ -0,0 +1,14 @@
+// +build ignore
+
+package templates
+
+import (
+ "errors"
+ "fmt"
+)
+
+func before(s string) error { return fmt.Errorf("%s", s) }
+func after(s string) error {
+ n := fmt.Sprintf("error - %s", s)
+ return errors.New(n)
+}
diff --git a/refactor/eg/testdata/I1.go b/refactor/eg/testdata/I1.go
new file mode 100644
index 000000000..d1762ebb3
--- /dev/null
+++ b/refactor/eg/testdata/I1.go
@@ -0,0 +1,9 @@
+// +build ignore
+
+package I1
+
+import "fmt"
+
+func example() {
+ _ = fmt.Errorf("%s", "foo")
+}
diff --git a/refactor/eg/testdata/I1.golden b/refactor/eg/testdata/I1.golden
new file mode 100644
index 000000000..f33b3e106
--- /dev/null
+++ b/refactor/eg/testdata/I1.golden
@@ -0,0 +1,14 @@
+// +build ignore
+
+package I1
+
+import (
+ "errors"
+ "fmt"
+)
+
+func example() {
+
+ n := fmt.Sprintf("error - %s", "foo")
+ _ = errors.New(n)
+}
diff --git a/refactor/eg/testdata/J.template b/refactor/eg/testdata/J.template
new file mode 100644
index 000000000..6f82cdfe8
--- /dev/null
+++ b/refactor/eg/testdata/J.template
@@ -0,0 +1,11 @@
+// +build ignore
+
+package templates
+
+import ()
+
+func before(x int) int { return x + x + x }
+func after(x int) int {
+ temp := x + x
+ return temp + x
+}
diff --git a/refactor/eg/testdata/J1.go b/refactor/eg/testdata/J1.go
new file mode 100644
index 000000000..2fbeee801
--- /dev/null
+++ b/refactor/eg/testdata/J1.go
@@ -0,0 +1,10 @@
+// +build ignore
+
+package I1
+
+import "fmt"
+
+func example() {
+ temp := 5
+ fmt.Print(temp + temp + temp)
+}
diff --git a/refactor/eg/testdata/J1.golden b/refactor/eg/testdata/J1.golden
new file mode 100644
index 000000000..bb2f11c60
--- /dev/null
+++ b/refactor/eg/testdata/J1.golden
@@ -0,0 +1,11 @@
+// +build ignore
+
+package I1
+
+import "fmt"
+
+func example() {
+ temp := 5
+ temp := temp + temp
+ fmt.Print(temp + temp)
+}
diff --git a/refactor/eg/testdata/no_after_return.template b/refactor/eg/testdata/no_after_return.template
index 536b01e67..dd2cbf61e 100644
--- a/refactor/eg/testdata/no_after_return.template
+++ b/refactor/eg/testdata/no_after_return.template
@@ -1,6 +1,4 @@
package template
-const shouldFail = "after: must contain a single statement"
-
func before() int { return 0 }
func after() int { println(); return 0 }