diff options
Diffstat (limited to 'refactor')
-rw-r--r-- | refactor/eg/eg.go | 32 | ||||
-rw-r--r-- | refactor/eg/eg_test.go | 6 | ||||
-rw-r--r-- | refactor/eg/rewrite.go | 152 | ||||
-rw-r--r-- | refactor/eg/testdata/I.template | 14 | ||||
-rw-r--r-- | refactor/eg/testdata/I1.go | 9 | ||||
-rw-r--r-- | refactor/eg/testdata/I1.golden | 14 | ||||
-rw-r--r-- | refactor/eg/testdata/J.template | 11 | ||||
-rw-r--r-- | refactor/eg/testdata/J1.go | 10 | ||||
-rw-r--r-- | refactor/eg/testdata/J1.golden | 11 | ||||
-rw-r--r-- | refactor/eg/testdata/no_after_return.template | 2 |
10 files changed, 211 insertions, 50 deletions
diff --git a/refactor/eg/eg.go b/refactor/eg/eg.go index bf0ee2909..02c6dc951 100644 --- a/refactor/eg/eg.go +++ b/refactor/eg/eg.go @@ -145,6 +145,7 @@ type Transformer struct { env map[string]ast.Expr // maps parameter name to wildcard binding importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after(). before, after ast.Expr + afterStmts []ast.Stmt allowWildcards bool // Working state of Transform(): @@ -198,7 +199,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F if err != nil { return nil, fmt.Errorf("before: %s", err) } - after, err := soleExpr(afterDecl) + afterStmts, after, err := stmtAndExpr(afterDecl) if err != nil { return nil, fmt.Errorf("after: %s", err) } @@ -242,6 +243,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F importedObjs: make(map[types.Object]*ast.SelectorExpr), before: before, after: after, + afterStmts: afterStmts, } // Combine type info from the template and input packages, and @@ -279,6 +281,7 @@ func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) { if err != nil { return err } + defer func() { if err2 := fh.Close(); err != nil { err = err2 // prefer earlier error @@ -319,6 +322,33 @@ func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) { return nil, fmt.Errorf("must contain a single return or expression statement") } +// stmtAndExpr returns the expression in the last return statement as well as the preceeding lines. +func stmtAndExpr(fn *ast.FuncDecl) ([]ast.Stmt, ast.Expr, error) { + if fn.Body == nil { + return nil, nil, fmt.Errorf("no body") + } + + n := len(fn.Body.List) + if n == 0 { + return nil, nil, fmt.Errorf("must contain at least one statement") + } + + stmts, last := fn.Body.List[:n-1], fn.Body.List[n-1] + + switch last := last.(type) { + case *ast.ReturnStmt: + if len(last.Results) != 1 { + return nil, nil, fmt.Errorf("return statement must have a single operand") + } + return stmts, last.Results[0], nil + + case *ast.ExprStmt: + return stmts, last.X, nil + } + + return nil, nil, fmt.Errorf("must end with a single return or expression statement") +} + // mergeTypeInfo adds type info from src to dst. func mergeTypeInfo(dst, src *types.Info) { for k, v := range src.Types { diff --git a/refactor/eg/eg_test.go b/refactor/eg/eg_test.go index 896bc9b08..89b7d08fb 100644 --- a/refactor/eg/eg_test.go +++ b/refactor/eg/eg_test.go @@ -78,6 +78,12 @@ func Test(t *testing.T) { "testdata/H.template", "testdata/H1.go", + "testdata/I.template", + "testdata/I1.go", + + "testdata/J.template", + "testdata/J1.go", + "testdata/bad_type.template", "testdata/no_before.template", "testdata/no_after_return.template", diff --git a/refactor/eg/rewrite.go b/refactor/eg/rewrite.go index dd28aa578..a1281582b 100644 --- a/refactor/eg/rewrite.go +++ b/refactor/eg/rewrite.go @@ -22,6 +22,52 @@ import ( "golang.org/x/tools/go/ast/astutil" ) +// transformItem takes a reflect.Value representing a variable of type ast.Node +// transforms its child elements recursively with apply, and then transforms the +// actual element if it contains an expression. +func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { + // don't bother if val is invalid to start with + if !rv.IsValid() { + return reflect.Value{}, false, nil + } + + rv, changed, newEnv := tr.apply(tr.transformItem, rv) + + e := rvToExpr(rv) + if e == nil { + return rv, changed, newEnv + } + + savedEnv := tr.env + tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs + + if tr.matchExpr(tr.before, e) { + if tr.verbose { + fmt.Fprintf(os.Stderr, "%s matches %s", + astString(tr.fset, tr.before), astString(tr.fset, e)) + if len(tr.env) > 0 { + fmt.Fprintf(os.Stderr, " with:") + for name, ast := range tr.env { + fmt.Fprintf(os.Stderr, " %s->%s", + name, astString(tr.fset, ast)) + } + } + fmt.Fprintf(os.Stderr, "\n") + } + tr.nsubsts++ + + // Clone the replacement tree, performing parameter substitution. + // We update all positions to n.Pos() to aid comment placement. + rv = tr.subst(tr.env, reflect.ValueOf(tr.after), + reflect.ValueOf(e.Pos())) + changed = true + newEnv = tr.env + } + tr.env = savedEnv + + return rv, changed, newEnv +} + // Transform applies the transformation to the specified parsed file, // whose type information is supplied in info, and returns the number // of replacements that were made. @@ -43,48 +89,14 @@ func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast if tr.verbose { fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before)) fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after)) + fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts) } - var f func(rv reflect.Value) reflect.Value - f = func(rv reflect.Value) reflect.Value { - // don't bother if val is invalid to start with - if !rv.IsValid() { - return reflect.Value{} - } - - rv = apply(f, rv) - - e := rvToExpr(rv) - if e != nil { - savedEnv := tr.env - tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs - - if tr.matchExpr(tr.before, e) { - if tr.verbose { - fmt.Fprintf(os.Stderr, "%s matches %s", - astString(tr.fset, tr.before), astString(tr.fset, e)) - if len(tr.env) > 0 { - fmt.Fprintf(os.Stderr, " with:") - for name, ast := range tr.env { - fmt.Fprintf(os.Stderr, " %s->%s", - name, astString(tr.fset, ast)) - } - } - fmt.Fprintf(os.Stderr, "\n") - } - tr.nsubsts++ - - // Clone the replacement tree, performing parameter substitution. - // We update all positions to n.Pos() to aid comment placement. - rv = tr.subst(tr.env, reflect.ValueOf(tr.after), - reflect.ValueOf(e.Pos())) - } - tr.env = savedEnv - } - - return rv + o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file)) + if changed { + panic("BUG") } - file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File) + file2 := o.Interface().(*ast.File) // By construction, the root node is unchanged. if file != file2 { @@ -150,45 +162,91 @@ var ( identType = reflect.TypeOf((*ast.Ident)(nil)) selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil)) objectPtrType = reflect.TypeOf((*ast.Object)(nil)) + statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem() positionType = reflect.TypeOf(token.NoPos) scopePtrType = reflect.TypeOf((*ast.Scope)(nil)) ) // apply replaces each AST field x in val with f(x), returning val. // To avoid extra conversions, f operates on the reflect.Value form. -func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { +// f takes a reflect.Value representing the variable to modify of type ast.Node. +// It returns a reflect.Value containing the transformed value of type ast.Node, +// whether any change was made, and a map of identifiers to ast.Expr (so we can +// do contextually correct substitutions in the parent statements). +func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) { if !val.IsValid() { - return reflect.Value{} + return reflect.Value{}, false, nil } // *ast.Objects introduce cycles and are likely incorrect after // rewrite; don't follow them but replace with nil instead if val.Type() == objectPtrType { - return objectPtrNil + return objectPtrNil, false, nil } // similarly for scopes: they are likely incorrect after a rewrite; // replace them with nil if val.Type() == scopePtrType { - return scopePtrNil + return scopePtrNil, false, nil } switch v := reflect.Indirect(val); v.Kind() { case reflect.Slice: + // no possible rewriting of statements. + if v.Type().Elem() != statementType { + changed := false + var envp map[string]ast.Expr + for i := 0; i < v.Len(); i++ { + e := v.Index(i) + o, localchanged, env := f(e) + if localchanged { + changed = true + // we clobber envp here, + // which means if we have two sucessive + // replacements inside the same statement + // we will only generate the setup for one of them. + envp = env + } + setValue(e, o) + } + return val, changed, envp + } + + // statements are rewritten. + var out []ast.Stmt for i := 0; i < v.Len(); i++ { e := v.Index(i) - setValue(e, f(e)) + o, changed, env := f(e) + if changed { + for _, s := range tr.afterStmts { + t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface() + out = append(out, t.(ast.Stmt)) + } + } + setValue(e, o) + out = append(out, e.Interface().(ast.Stmt)) } + return reflect.ValueOf(out), false, nil case reflect.Struct: + changed := false + var envp map[string]ast.Expr for i := 0; i < v.NumField(); i++ { e := v.Field(i) - setValue(e, f(e)) + o, localchanged, env := f(e) + if localchanged { + changed = true + envp = env + } + setValue(e, o) } + return val, changed, envp case reflect.Interface: e := v.Elem() - setValue(v, f(e)) + o, changed, env := f(e) + setValue(v, o) + return val, changed, env } - return val + return val, false, nil } // subst returns a copy of (replacement) pattern with values from env diff --git a/refactor/eg/testdata/I.template b/refactor/eg/testdata/I.template new file mode 100644 index 000000000..d03e77482 --- /dev/null +++ b/refactor/eg/testdata/I.template @@ -0,0 +1,14 @@ +// +build ignore + +package templates + +import ( + "errors" + "fmt" +) + +func before(s string) error { return fmt.Errorf("%s", s) } +func after(s string) error { + n := fmt.Sprintf("error - %s", s) + return errors.New(n) +} diff --git a/refactor/eg/testdata/I1.go b/refactor/eg/testdata/I1.go new file mode 100644 index 000000000..d1762ebb3 --- /dev/null +++ b/refactor/eg/testdata/I1.go @@ -0,0 +1,9 @@ +// +build ignore + +package I1 + +import "fmt" + +func example() { + _ = fmt.Errorf("%s", "foo") +} diff --git a/refactor/eg/testdata/I1.golden b/refactor/eg/testdata/I1.golden new file mode 100644 index 000000000..f33b3e106 --- /dev/null +++ b/refactor/eg/testdata/I1.golden @@ -0,0 +1,14 @@ +// +build ignore + +package I1 + +import ( + "errors" + "fmt" +) + +func example() { + + n := fmt.Sprintf("error - %s", "foo") + _ = errors.New(n) +} diff --git a/refactor/eg/testdata/J.template b/refactor/eg/testdata/J.template new file mode 100644 index 000000000..6f82cdfe8 --- /dev/null +++ b/refactor/eg/testdata/J.template @@ -0,0 +1,11 @@ +// +build ignore + +package templates + +import () + +func before(x int) int { return x + x + x } +func after(x int) int { + temp := x + x + return temp + x +} diff --git a/refactor/eg/testdata/J1.go b/refactor/eg/testdata/J1.go new file mode 100644 index 000000000..2fbeee801 --- /dev/null +++ b/refactor/eg/testdata/J1.go @@ -0,0 +1,10 @@ +// +build ignore + +package I1 + +import "fmt" + +func example() { + temp := 5 + fmt.Print(temp + temp + temp) +} diff --git a/refactor/eg/testdata/J1.golden b/refactor/eg/testdata/J1.golden new file mode 100644 index 000000000..bb2f11c60 --- /dev/null +++ b/refactor/eg/testdata/J1.golden @@ -0,0 +1,11 @@ +// +build ignore + +package I1 + +import "fmt" + +func example() { + temp := 5 + temp := temp + temp + fmt.Print(temp + temp) +} diff --git a/refactor/eg/testdata/no_after_return.template b/refactor/eg/testdata/no_after_return.template index 536b01e67..dd2cbf61e 100644 --- a/refactor/eg/testdata/no_after_return.template +++ b/refactor/eg/testdata/no_after_return.template @@ -1,6 +1,4 @@ package template -const shouldFail = "after: must contain a single statement" - func before() int { return 0 } func after() int { println(); return 0 } |