diff options
Diffstat (limited to 'go/callgraph/vta')
-rw-r--r-- | go/callgraph/vta/graph.go | 126 | ||||
-rw-r--r-- | go/callgraph/vta/graph_test.go | 7 | ||||
-rw-r--r-- | go/callgraph/vta/helpers_test.go | 8 | ||||
-rw-r--r-- | go/callgraph/vta/internal/trie/bits.go | 25 | ||||
-rw-r--r-- | go/callgraph/vta/internal/trie/builder.go | 24 | ||||
-rw-r--r-- | go/callgraph/vta/internal/trie/trie.go | 12 | ||||
-rw-r--r-- | go/callgraph/vta/propagation.go | 57 | ||||
-rw-r--r-- | go/callgraph/vta/propagation_test.go | 61 | ||||
-rw-r--r-- | go/callgraph/vta/testdata/src/callgraph_generics.go | 71 | ||||
-rw-r--r-- | go/callgraph/vta/testdata/src/callgraph_issue_57756.go | 67 | ||||
-rw-r--r-- | go/callgraph/vta/testdata/src/callgraph_recursive_types.go | 56 | ||||
-rw-r--r-- | go/callgraph/vta/testdata/src/function_alias.go | 44 | ||||
-rw-r--r-- | go/callgraph/vta/testdata/src/panic.go | 3 | ||||
-rw-r--r-- | go/callgraph/vta/utils.go | 110 | ||||
-rw-r--r-- | go/callgraph/vta/vta.go | 33 | ||||
-rw-r--r-- | go/callgraph/vta/vta_go117_test.go | 3 | ||||
-rw-r--r-- | go/callgraph/vta/vta_test.go | 28 |
17 files changed, 543 insertions, 192 deletions
diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go index ad7ef0e88..2537123f4 100644 --- a/go/callgraph/vta/graph.go +++ b/go/callgraph/vta/graph.go @@ -12,6 +12,7 @@ import ( "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/types/typeutil" + "golang.org/x/tools/internal/typeparams" ) // node interface for VTA nodes. @@ -175,9 +176,10 @@ func (f function) String() string { // We merge such constructs into a single node for simplicity and without // much precision sacrifice as such variables are rare in practice. Both // a and b would be represented as the same PtrInterface(I) node in: -// type I interface -// var a ***I -// var b **I +// +// type I interface +// var a ***I +// var b **I type nestedPtrInterface struct { typ types.Type } @@ -195,8 +197,9 @@ func (l nestedPtrInterface) String() string { // constructs into a single node for simplicity and without much precision // sacrifice as such variables are rare in practice. Both a and b would be // represented as the same PtrFunction(func()) node in: -// var a *func() -// var b **func() +// +// var a *func() +// var b **func() type nestedPtrFunction struct { typ types.Type } @@ -325,14 +328,16 @@ func (b *builder) instr(instr ssa.Instruction) { // change type command a := A(b) results in a and b being the // same value. For concrete type A, there is no interesting flow. // - // Note: When A is an interface, most interface casts are handled + // When A is an interface, most interface casts are handled // by the ChangeInterface instruction. The relevant case here is // when converting a pointer to an interface type. This can happen // when the underlying interfaces have the same method set. - // type I interface{ foo() } - // type J interface{ foo() } - // var b *I - // a := (*J)(b) + // + // type I interface{ foo() } + // type J interface{ foo() } + // var b *I + // a := (*J)(b) + // // When this happens we add flows between a <--> b. b.addInFlowAliasEdges(b.nodeFromVal(i), b.nodeFromVal(i.X)) case *ssa.TypeAssert: @@ -371,6 +376,8 @@ func (b *builder) instr(instr ssa.Instruction) { // SliceToArrayPointer: t1 = slice to array pointer *[4]T <- []T (t0) // No interesting flow as sliceArrayElem(t1) == sliceArrayElem(t0). return + case *ssa.MultiConvert: + b.multiconvert(i) default: panic(fmt.Sprintf("unsupported instruction %v\n", instr)) } @@ -441,7 +448,9 @@ func (b *builder) send(s *ssa.Send) { } // selekt generates flows for select statement -// a = select blocking/nonblocking [c_1 <- t_1, c_2 <- t_2, ..., <- o_1, <- o_2, ...] +// +// a = select blocking/nonblocking [c_1 <- t_1, c_2 <- t_2, ..., <- o_1, <- o_2, ...] +// // between receiving channel registers c_i and corresponding input register t_i. Further, // flows are generated between o_i and a[2 + i]. Note that a is a tuple register of type // <int, bool, r_1, r_2, ...> where the type of r_i is the element type of channel o_i. @@ -544,8 +553,9 @@ func (b *builder) closure(c *ssa.MakeClosure) { // panic creates a flow from arguments to panic instructions to return // registers of all recover statements in the program. Introduces a // global panic node Panic and -// 1) for every panic statement p: add p -> Panic -// 2) for every recover statement r: add Panic -> r (handled in call) +// 1. for every panic statement p: add p -> Panic +// 2. for every recover statement r: add Panic -> r (handled in call) +// // TODO(zpavlinovic): improve precision by explicitly modeling how panic // values flow from callees to callers and into deferred recover instructions. func (b *builder) panic(p *ssa.Panic) { @@ -563,7 +573,9 @@ func (b *builder) panic(p *ssa.Panic) { func (b *builder) call(c ssa.CallInstruction) { // When c is r := recover() call register instruction, we add Recover -> r. if bf, ok := c.Common().Value.(*ssa.Builtin); ok && bf.Name() == "recover" { - b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(c.(*ssa.Call))) + if v, ok := c.(ssa.Value); ok { + b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(v)) + } return } @@ -581,10 +593,18 @@ func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) { return } cc := c.Common() - // When c is an unresolved method call (cc.Method != nil), cc.Value contains - // the receiver object rather than cc.Args[0]. if cc.Method != nil { - b.addInFlowAliasEdges(b.nodeFromVal(f.Params[0]), b.nodeFromVal(cc.Value)) + // In principle we don't add interprocedural flows for receiver + // objects. At a call site, the receiver object is interface + // while the callee object is concrete. The flow from interface + // to concrete type in general does not make sense. The exception + // is when the concrete type is a named function type (see #57756). + // + // The flow other way around would bake in information from the + // initial call graph. + if isFunction(f.Params[0].Type()) { + b.addInFlowEdge(b.nodeFromVal(cc.Value), b.nodeFromVal(f.Params[0])) + } } offset := 0 @@ -638,6 +658,71 @@ func addReturnFlows(b *builder, r *ssa.Return, site ssa.Value) { } } +func (b *builder) multiconvert(c *ssa.MultiConvert) { + // TODO(zpavlinovic): decide what to do on MultiConvert long term. + // TODO(zpavlinovic): add unit tests. + typeSetOf := func(typ types.Type) []*typeparams.Term { + // This is a adaptation of x/exp/typeparams.NormalTerms which x/tools cannot depend on. + var terms []*typeparams.Term + var err error + switch typ := typ.(type) { + case *typeparams.TypeParam: + terms, err = typeparams.StructuralTerms(typ) + case *typeparams.Union: + terms, err = typeparams.UnionTermSet(typ) + case *types.Interface: + terms, err = typeparams.InterfaceTermSet(typ) + default: + // Common case. + // Specializing the len=1 case to avoid a slice + // had no measurable space/time benefit. + terms = []*typeparams.Term{typeparams.NewTerm(false, typ)} + } + + if err != nil { + return nil + } + return terms + } + // isValuePreserving returns true if a conversion from ut_src to + // ut_dst is value-preserving, i.e. just a change of type. + // Precondition: neither argument is a named type. + isValuePreserving := func(ut_src, ut_dst types.Type) bool { + // Identical underlying types? + if types.IdenticalIgnoreTags(ut_dst, ut_src) { + return true + } + + switch ut_dst.(type) { + case *types.Chan: + // Conversion between channel types? + _, ok := ut_src.(*types.Chan) + return ok + + case *types.Pointer: + // Conversion between pointers with identical base types? + _, ok := ut_src.(*types.Pointer) + return ok + } + return false + } + dst_terms := typeSetOf(c.Type()) + src_terms := typeSetOf(c.X.Type()) + for _, s := range src_terms { + us := s.Type().Underlying() + for _, d := range dst_terms { + ud := d.Type().Underlying() + if isValuePreserving(us, ud) { + // This is equivalent to a ChangeType. + b.addInFlowAliasEdges(b.nodeFromVal(c), b.nodeFromVal(c.X)) + return + } + // This is equivalent to either: SliceToArrayPointer,, + // SliceToArrayPointer+Deref, Size 0 Array constant, or a Convert. + } + } +} + // addInFlowEdge adds s -> d to g if d is node that can have an inflow, i.e., a node // that represents an interface or an unresolved function value. Otherwise, there // is no interesting type flow so the edge is omitted. @@ -649,7 +734,7 @@ func (b *builder) addInFlowEdge(s, d node) { // Creates const, pointer, global, func, and local nodes based on register instructions. func (b *builder) nodeFromVal(val ssa.Value) node { - if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) && !isFunction(p.Elem()) { + if p, ok := val.Type().(*types.Pointer); ok && !types.IsInterface(p.Elem()) && !isFunction(p.Elem()) { // Nested pointer to interfaces are modeled as a special // nestedPtrInterface node. if i := interfaceUnderPtr(p.Elem()); i != nil { @@ -676,14 +761,15 @@ func (b *builder) nodeFromVal(val ssa.Value) node { default: panic(fmt.Errorf("unsupported value %v in node creation", val)) } - return nil } // representative returns a unique representative for node `n`. Since // semantically equivalent types can have different implementations, // this method guarantees the same implementation is always used. func (b *builder) representative(n node) node { - if !hasInitialTypes(n) { + if n.Type() == nil { + // panicArg and recoverReturn do not have + // types and are unique by definition. return n } t := canonicalize(n.Type(), &b.canon) diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go index 8608844dd..8b8c6976f 100644 --- a/go/callgraph/vta/graph_test.go +++ b/go/callgraph/vta/graph_test.go @@ -13,6 +13,7 @@ import ( "testing" "golang.org/x/tools/go/callgraph/cha" + "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) @@ -24,7 +25,7 @@ func TestNodeInterface(t *testing.T) { // - global variable "gl" // - "main" function and its // - first register instruction t0 := *gl - prog, _, err := testProg("testdata/src/simple.go") + prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) } @@ -78,7 +79,7 @@ func TestNodeInterface(t *testing.T) { func TestVtaGraph(t *testing.T) { // Get the basic type int from a real program. - prog, _, err := testProg("testdata/src/simple.go") + prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load testdata/src/simple.go program: %v", err) } @@ -191,7 +192,7 @@ func TestVTAGraphConstruction(t *testing.T) { "testdata/src/panic.go", } { t.Run(file, func(t *testing.T) { - prog, want, err := testProg(file) + prog, want, err := testProg(file, ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load test file '%s': %s", file, err) } diff --git a/go/callgraph/vta/helpers_test.go b/go/callgraph/vta/helpers_test.go index 0e00aeb28..768365f5b 100644 --- a/go/callgraph/vta/helpers_test.go +++ b/go/callgraph/vta/helpers_test.go @@ -35,7 +35,7 @@ func want(f *ast.File) []string { // testProg returns an ssa representation of a program at // `path`, assumed to define package "testdata," and the // test want result as list of strings. -func testProg(path string) (*ssa.Program, []string, error) { +func testProg(path string, mode ssa.BuilderMode) (*ssa.Program, []string, error) { content, err := ioutil.ReadFile(path) if err != nil { return nil, nil, err @@ -56,7 +56,7 @@ func testProg(path string) (*ssa.Program, []string, error) { return nil, nil, err } - prog := ssautil.CreateProgram(iprog, 0) + prog := ssautil.CreateProgram(iprog, mode) // Set debug mode to exercise DebugRef instructions. prog.Package(iprog.Created[0].Pkg).SetDebugMode(true) prog.Build() @@ -87,7 +87,9 @@ func funcName(f *ssa.Function) string { // callGraphStr stringifes `g` into a list of strings where // each entry is of the form -// f: cs1 -> f1, f2, ...; ...; csw -> fx, fy, ... +// +// f: cs1 -> f1, f2, ...; ...; csw -> fx, fy, ... +// // f is a function, cs1, ..., csw are call sites in f, and // f1, f2, ..., fx, fy, ... are the resolved callees. func callGraphStr(g *callgraph.Graph) []string { diff --git a/go/callgraph/vta/internal/trie/bits.go b/go/callgraph/vta/internal/trie/bits.go index f2fd0ba83..c3aa15985 100644 --- a/go/callgraph/vta/internal/trie/bits.go +++ b/go/callgraph/vta/internal/trie/bits.go @@ -19,11 +19,11 @@ type key uint64 // bitpos is the position of a bit. A position is represented by having a 1 // bit in that position. // Examples: -// * 0b0010 is the position of the `1` bit in 2. -// It is the 3rd most specific bit position in big endian encoding -// (0b0 and 0b1 are more specific). -// * 0b0100 is the position of the bit that 1 and 5 disagree on. -// * 0b0 is a special value indicating that all bit agree. +// - 0b0010 is the position of the `1` bit in 2. +// It is the 3rd most specific bit position in big endian encoding +// (0b0 and 0b1 are more specific). +// - 0b0100 is the position of the bit that 1 and 5 disagree on. +// - 0b0 is a special value indicating that all bit agree. type bitpos uint64 // prefixes represent a set of keys that all agree with the @@ -35,7 +35,8 @@ type bitpos uint64 // A prefix always mask(p, m) == p. // // A key is its own prefix for the bit position 64, -// e.g. seeing a `prefix(key)` is not a problem. +// e.g. seeing a `prefix(key)` is not a problem. +// // Prefixes should never be turned into keys. type prefix uint64 @@ -64,8 +65,9 @@ func matchPrefix(k prefix, p prefix, b bitpos) bool { // In big endian encoding, this value is the [64-(m-1)] most significant bits of k // followed by a `0` bit at bitpos m, followed m-1 `1` bits. // Examples: -// prefix(0b1011) for a bitpos 0b0100 represents the keys: -// 0b1000, 0b1001, 0b1010, 0b1011, 0b1100, 0b1101, 0b1110, 0b1111 +// +// prefix(0b1011) for a bitpos 0b0100 represents the keys: +// 0b1000, 0b1001, 0b1010, 0b1011, 0b1100, 0b1101, 0b1110, 0b1111 // // This mask function has the property that if matchPrefix(k, p, b), then // k <= p if and only if zeroBit(k, m). This induces binary search tree tries. @@ -85,9 +87,10 @@ func ord(m, n bitpos) bool { // can hold that can also be held by a prefix `q` for some bitpos `n`. // // This is equivalent to: -// m ==n && p == q, -// higher(m, n) && matchPrefix(q, p, m), or -// higher(n, m) && matchPrefix(p, q, n) +// +// m ==n && p == q, +// higher(m, n) && matchPrefix(q, p, m), or +// higher(n, m) && matchPrefix(p, q, n) func prefixesOverlap(p prefix, m bitpos, q prefix, n bitpos) bool { fbb := n if ord(m, n) { diff --git a/go/callgraph/vta/internal/trie/builder.go b/go/callgraph/vta/internal/trie/builder.go index 25d3805bc..11ff59b1b 100644 --- a/go/callgraph/vta/internal/trie/builder.go +++ b/go/callgraph/vta/internal/trie/builder.go @@ -9,7 +9,9 @@ package trie // will be stored for the key. // // Collision functions must be idempotent: -// collision(x, x) == x for all x. +// +// collision(x, x) == x for all x. +// // Collisions functions may be applied whenever a value is inserted // or two maps are merged, or intersected. type Collision func(lhs interface{}, rhs interface{}) interface{} @@ -72,7 +74,8 @@ func (b *Builder) Empty() Map { return Map{b.Scope(), b.empty} } // in the current scope and handle collisions using the collision function c. // // This is roughly corresponds to updating a map[uint64]interface{} by: -// if _, ok := m[k]; ok { m[k] = c(m[k], v} else { m[k] = v} +// +// if _, ok := m[k]; ok { m[k] = c(m[k], v} else { m[k] = v} // // An insertion or update happened whenever Insert(m, ...) != m . func (b *Builder) InsertWith(c Collision, m Map, k uint64, v interface{}) Map { @@ -85,7 +88,8 @@ func (b *Builder) InsertWith(c Collision, m Map, k uint64, v interface{}) Map { // // If there was a previous value mapped by key, keep the previously mapped value. // This is roughly corresponds to updating a map[uint64]interface{} by: -// if _, ok := m[k]; ok { m[k] = val } +// +// if _, ok := m[k]; ok { m[k] = val } // // This is equivalent to b.Merge(m, b.Create({k: v})). func (b *Builder) Insert(m Map, k uint64, v interface{}) Map { @@ -94,7 +98,8 @@ func (b *Builder) Insert(m Map, k uint64, v interface{}) Map { // Updates a (key, value) in the map. This is roughly corresponds to // updating a map[uint64]interface{} by: -// m[key] = val +// +// m[key] = val func (b *Builder) Update(m Map, key uint64, val interface{}) Map { return b.InsertWith(TakeRhs, m, key, val) } @@ -148,14 +153,17 @@ func (b *Builder) Remove(m Map, k uint64) Map { // Intersect Maps lhs and rhs and returns a map with all of the keys in // both lhs and rhs and the value comes from lhs, i.e. -// {(k, lhs[k]) | k in lhs, k in rhs}. +// +// {(k, lhs[k]) | k in lhs, k in rhs}. func (b *Builder) Intersect(lhs, rhs Map) Map { return b.IntersectWith(TakeLhs, lhs, rhs) } // IntersectWith take lhs and rhs and returns the intersection // with the value coming from the collision function, i.e. -// {(k, c(lhs[k], rhs[k]) ) | k in lhs, k in rhs}. +// +// {(k, c(lhs[k], rhs[k]) ) | k in lhs, k in rhs}. +// // The elements of the resulting map are always { <k, c(lhs[k], rhs[k]) > } // for each key k that a key in both lhs and rhs. func (b *Builder) IntersectWith(c Collision, lhs, rhs Map) Map { @@ -261,7 +269,9 @@ func (b *Builder) mkLeaf(k key, v interface{}) *leaf { } // mkBranch returns the hash-consed representative of the tuple -// (prefix, branch, left, right) +// +// (prefix, branch, left, right) +// // in the current scope. func (b *Builder) mkBranch(p prefix, bp bitpos, left node, right node) *branch { br := &branch{ diff --git a/go/callgraph/vta/internal/trie/trie.go b/go/callgraph/vta/internal/trie/trie.go index 160eb21be..511fde515 100644 --- a/go/callgraph/vta/internal/trie/trie.go +++ b/go/callgraph/vta/internal/trie/trie.go @@ -10,8 +10,10 @@ // environment abstract domains in program analysis). // // This implementation closely follows the paper: -// C. Okasaki and A. Gill, “Fast mergeable integer maps,” in ACM SIGPLAN -// Workshop on ML, September 1998, pp. 77–86. +// +// C. Okasaki and A. Gill, “Fast mergeable integer maps,” in ACM SIGPLAN +// Workshop on ML, September 1998, pp. 77–86. +// // Each Map is immutable and can be read from concurrently. The map does not // guarantee that the value pointed to by the interface{} value is not updated // concurrently. @@ -36,9 +38,9 @@ import ( // Maps are immutable and can be read from concurrently. // // Notes on concurrency: -// - A Map value itself is an interface and assignments to a Map value can race. -// - Map does not guarantee that the value pointed to by the interface{} value -// is not updated concurrently. +// - A Map value itself is an interface and assignments to a Map value can race. +// - Map does not guarantee that the value pointed to by the interface{} value +// is not updated concurrently. type Map struct { s Scope n node diff --git a/go/callgraph/vta/propagation.go b/go/callgraph/vta/propagation.go index 5934ebc21..5817e8938 100644 --- a/go/callgraph/vta/propagation.go +++ b/go/callgraph/vta/propagation.go @@ -20,53 +20,52 @@ import ( // with ids X and Y s.t. X < Y, Y comes before X in the topological order. func scc(g vtaGraph) (map[node]int, int) { // standard data structures used by Tarjan's algorithm. - var index uint64 + type state struct { + index int + lowLink int + onStack bool + } + states := make(map[node]*state, len(g)) var stack []node - indexMap := make(map[node]uint64) - lowLink := make(map[node]uint64) - onStack := make(map[node]bool) - nodeToSccID := make(map[node]int) + nodeToSccID := make(map[node]int, len(g)) sccID := 0 var doSCC func(node) doSCC = func(n node) { - indexMap[n] = index - lowLink[n] = index - index = index + 1 - onStack[n] = true + index := len(states) + ns := &state{index: index, lowLink: index, onStack: true} + states[n] = ns stack = append(stack, n) for s := range g[n] { - if _, ok := indexMap[s]; !ok { + if ss, visited := states[s]; !visited { // Analyze successor s that has not been visited yet. doSCC(s) - lowLink[n] = min(lowLink[n], lowLink[s]) - } else if onStack[s] { + ss = states[s] + ns.lowLink = min(ns.lowLink, ss.lowLink) + } else if ss.onStack { // The successor is on the stack, meaning it has to be // in the current SCC. - lowLink[n] = min(lowLink[n], indexMap[s]) + ns.lowLink = min(ns.lowLink, ss.index) } } // if n is a root node, pop the stack and generate a new SCC. - if lowLink[n] == indexMap[n] { - for { - w := stack[len(stack)-1] + if ns.lowLink == index { + var w node + for w != n { + w = stack[len(stack)-1] stack = stack[:len(stack)-1] - onStack[w] = false + states[w].onStack = false nodeToSccID[w] = sccID - if w == n { - break - } } sccID++ } } - index = 0 for n := range g { - if _, ok := indexMap[n]; !ok { + if _, visited := states[n]; !visited { doSCC(n) } } @@ -74,7 +73,7 @@ func scc(g vtaGraph) (map[node]int, int) { return nodeToSccID, sccID } -func min(x, y uint64) uint64 { +func min(x, y int) int { if x < y { return x } @@ -175,6 +174,18 @@ func nodeTypes(nodes []node, builder *trie.Builder, propTypeId func(p propType) return &typeSet } +// hasInitialTypes check if a node can have initial types. +// Returns true iff `n` is not a panic, recover, nestedPtr* +// node, nor a node whose type is an interface. +func hasInitialTypes(n node) bool { + switch n.(type) { + case panicArg, recoverReturn, nestedPtrFunction, nestedPtrInterface: + return false + default: + return !types.IsInterface(n.Type()) + } +} + // getPropType creates a propType for `node` based on its type. // propType.typ is always node.Type(). If node is function, then // propType.val is the underlying function; nil otherwise. diff --git a/go/callgraph/vta/propagation_test.go b/go/callgraph/vta/propagation_test.go index 96707417f..f4a754f96 100644 --- a/go/callgraph/vta/propagation_test.go +++ b/go/callgraph/vta/propagation_test.go @@ -58,7 +58,7 @@ func newLocal(name string, t types.Type) local { // newNamedType creates a bogus type named `name`. func newNamedType(name string) *types.Named { - return types.NewNamed(types.NewTypeName(token.NoPos, nil, name, nil), nil, nil) + return types.NewNamed(types.NewTypeName(token.NoPos, nil, name, nil), types.Universe.Lookup("int").Type(), nil) } // sccString is a utility for stringifying `nodeToScc`. Every @@ -123,7 +123,8 @@ func sccEqual(sccs1 []string, sccs2 []string) bool { // isRevTopSorted checks if sccs of `g` are sorted in reverse // topological order: -// for every edge x -> y in g, nodeToScc[x] > nodeToScc[y] +// +// for every edge x -> y in g, nodeToScc[x] > nodeToScc[y] func isRevTopSorted(g vtaGraph, nodeToScc map[node]int) bool { for n, succs := range g { for s := range succs { @@ -148,39 +149,39 @@ func setName(f *ssa.Function, name string) { // parentheses contain node types and F nodes stand for function // nodes whose content is function named F: // -// no-cycles: -// t0 (A) -> t1 (B) -> t2 (C) +// no-cycles: +// t0 (A) -> t1 (B) -> t2 (C) // -// trivial-cycle: -// <-------- <-------- -// | | | | -// t0 (A) -> t1 (B) -> +// trivial-cycle: +// <-------- <-------- +// | | | | +// t0 (A) -> t1 (B) -> // -// circle-cycle: -// t0 (A) -> t1 (A) -> t2 (B) -// | | -// <-------------------- +// circle-cycle: +// t0 (A) -> t1 (A) -> t2 (B) +// | | +// <-------------------- // -// fully-connected: -// t0 (A) <-> t1 (B) -// \ / -// t2(C) +// fully-connected: +// t0 (A) <-> t1 (B) +// \ / +// t2(C) // -// subsumed-scc: -// t0 (A) -> t1 (B) -> t2(B) -> t3 (A) -// | | | | -// | <--------- | -// <----------------------------- +// subsumed-scc: +// t0 (A) -> t1 (B) -> t2(B) -> t3 (A) +// | | | | +// | <--------- | +// <----------------------------- // -// more-realistic: -// <-------- -// | | -// t0 (A) --> -// ----------> -// | | -// t1 (A) -> t2 (B) -> F1 -> F2 -> F3 -> F4 -// | | | | -// <------- <------------ +// more-realistic: +// <-------- +// | | +// t0 (A) --> +// ----------> +// | | +// t1 (A) -> t2 (B) -> F1 -> F2 -> F3 -> F4 +// | | | | +// <------- <------------ func testSuite() map[string]vtaGraph { a := newNamedType("A") b := newNamedType("B") diff --git a/go/callgraph/vta/testdata/src/callgraph_generics.go b/go/callgraph/vta/testdata/src/callgraph_generics.go new file mode 100644 index 000000000..da3dca52a --- /dev/null +++ b/go/callgraph/vta/testdata/src/callgraph_generics.go @@ -0,0 +1,71 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +func instantiated[X any](x *X) int { + print(x) + return 0 +} + +type I interface { + Bar() +} + +func interfaceInstantiated[X I](x X) { + x.Bar() +} + +type A struct{} + +func (a A) Bar() {} + +type B struct{} + +func (b B) Bar() {} + +func Foo(a A, b B) { + x := true + instantiated[bool](&x) + y := 1 + instantiated[int](&y) + + interfaceInstantiated[A](a) + interfaceInstantiated[B](b) +} + +// Relevant SSA: +//func Foo(a A, b B): +// t0 = local A (a) +// *t0 = a +// t1 = local B (b) +// *t1 = b +// t2 = new bool (x) +// *t2 = true:bool +// t3 = instantiated[bool](t2) +// t4 = new int (y) +// *t4 = 1:int +// t5 = instantiated[int](t4) +// t6 = *t0 +// t7 = interfaceInstantiated[testdata.A](t6) +// t8 = *t1 +// t9 = interfaceInstantiated[testdata.B](t8) +// return +// +//func interfaceInstantiated[testdata.B](x B): +// t0 = local B (x) +// *t0 = x +// t1 = *t0 +// t2 = (B).Bar(t1) +// return +// +//func interfaceInstantiated[X I](x X): +// (external) + +// WANT: +// Foo: instantiated[bool](t2) -> instantiated[bool]; instantiated[int](t4) -> instantiated[int]; interfaceInstantiated[testdata.A](t6) -> interfaceInstantiated[testdata.A]; interfaceInstantiated[testdata.B](t8) -> interfaceInstantiated[testdata.B] +// interfaceInstantiated[testdata.B]: (B).Bar(t1) -> B.Bar +// interfaceInstantiated[testdata.A]: (A).Bar(t1) -> A.Bar diff --git a/go/callgraph/vta/testdata/src/callgraph_issue_57756.go b/go/callgraph/vta/testdata/src/callgraph_issue_57756.go new file mode 100644 index 000000000..e18f16eba --- /dev/null +++ b/go/callgraph/vta/testdata/src/callgraph_issue_57756.go @@ -0,0 +1,67 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +// Test that the values of a named function type are correctly +// flowing from interface objects i in i.Foo() to the receiver +// parameters of callees of i.Foo(). + +type H func() + +func (h H) Do() { + h() +} + +type I interface { + Do() +} + +func Bar() I { + return H(func() {}) +} + +func For(g G) { + b := Bar() + b.Do() + + g[0] = b + g.Goo() +} + +type G []I + +func (g G) Goo() { + g[0].Do() +} + +// Relevant SSA: +// func Bar$1(): +// return +// +// func Bar() I: +// t0 = changetype H <- func() (Bar$1) +// t1 = make I <- H (t0) +// +// func For(): +// t0 = Bar() +// t1 = invoke t0.Do() +// t2 = &g[0:int] +// *t2 = t0 +// t3 = (G).Goo(g) +// +// func (h H) Do(): +// t0 = h() +// +// func (g G) Goo(): +// t0 = &g[0:int] +// t1 = *t0 +// t2 = invoke t1.Do() + +// WANT: +// For: (G).Goo(g) -> G.Goo; Bar() -> Bar; invoke t0.Do() -> H.Do +// H.Do: h() -> Bar$1 +// G.Goo: invoke t1.Do() -> H.Do diff --git a/go/callgraph/vta/testdata/src/callgraph_recursive_types.go b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go new file mode 100644 index 000000000..6c3fef6f7 --- /dev/null +++ b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go @@ -0,0 +1,56 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// go:build ignore + +package testdata + +type I interface { + Foo() I +} + +type A struct { + i int + a *A +} + +func (a *A) Foo() I { + return a +} + +type B **B + +type C *D +type D *C + +func Bar(a *A, b *B, c *C, d *D) { + Baz(a) + Baz(a.a) + + sink(*b) + sink(*c) + sink(*d) +} + +func Baz(i I) { + i.Foo() +} + +func sink(i interface{}) { + print(i) +} + +// Relevant SSA: +// func Baz(i I): +// t0 = invoke i.Foo() +// return +// +// func Bar(a *A, b *B): +// t0 = make I <- *A (a) +// t1 = Baz(t0) +// ... + +// WANT: +// Bar: Baz(t0) -> Baz; Baz(t4) -> Baz; sink(t10) -> sink; sink(t13) -> sink; sink(t7) -> sink +// Baz: invoke i.Foo() -> A.Foo diff --git a/go/callgraph/vta/testdata/src/function_alias.go b/go/callgraph/vta/testdata/src/function_alias.go index b38e0e00d..0a8dffe79 100644 --- a/go/callgraph/vta/testdata/src/function_alias.go +++ b/go/callgraph/vta/testdata/src/function_alias.go @@ -33,42 +33,42 @@ func Baz(f func()) { // t2 = *t1 // *t2 = Baz$1 // t3 = local A (a) -// t4 = &t3.foo [#0] -// t5 = *t1 -// t6 = *t5 -// *t4 = t6 +// t4 = *t1 +// t5 = *t4 +// t6 = &t3.foo [#0] +// *t6 = t5 // t7 = &t3.foo [#0] // t8 = *t7 // t9 = t8() -// t10 = &t3.do [#1] *Doer -// t11 = &t3.foo [#0] *func() -// t12 = *t11 func() -// t13 = changetype Doer <- func() (t12) Doer -// *t10 = t13 +// t10 = &t3.foo [#0] *func() +// t11 = *t10 func() +// t12 = &t3.do [#1] *Doer +// t13 = changetype Doer <- func() (t11) Doer +// *t12 = t13 // t14 = &t3.do [#1] *Doer // t15 = *t14 Doer // t16 = t15() () // Flow chain showing that Baz$1 reaches t8(): -// Baz$1 -> t2 <-> PtrFunction(func()) <-> t5 -> t6 -> t4 <-> Field(testdata.A:foo) <-> t7 -> t8 +// Baz$1 -> t2 <-> PtrFunction(func()) <-> t4 -> t5 -> t6 <-> Field(testdata.A:foo) <-> t7 -> t8 // Flow chain showing that Baz$1 reaches t15(): -// Field(testdata.A:foo) <-> t11 -> t12 -> t13 -> t10 <-> Field(testdata.A:do) <-> t14 -> t15 +// Field(testdata.A:foo) <-> t10 -> t11 -> t13 -> t12 <-> Field(testdata.A:do) <-> t14 -> t15 // WANT: // Local(f) -> Local(t0) // Local(t0) -> PtrFunction(func()) // Function(Baz$1) -> Local(t2) -// PtrFunction(func()) -> Local(t0), Local(t2), Local(t5) +// PtrFunction(func()) -> Local(t0), Local(t2), Local(t4) // Local(t2) -> PtrFunction(func()) -// Local(t4) -> Field(testdata.A:foo) -// Local(t5) -> Local(t6), PtrFunction(func()) -// Local(t6) -> Local(t4) +// Local(t6) -> Field(testdata.A:foo) +// Local(t4) -> Local(t5), PtrFunction(func()) +// Local(t5) -> Local(t6) // Local(t7) -> Field(testdata.A:foo), Local(t8) -// Field(testdata.A:foo) -> Local(t11), Local(t4), Local(t7) -// Local(t4) -> Field(testdata.A:foo) -// Field(testdata.A:do) -> Local(t10), Local(t14) -// Local(t10) -> Field(testdata.A:do) -// Local(t11) -> Field(testdata.A:foo), Local(t12) -// Local(t12) -> Local(t13) -// Local(t13) -> Local(t10) +// Field(testdata.A:foo) -> Local(t10), Local(t6), Local(t7) +// Local(t6) -> Field(testdata.A:foo) +// Field(testdata.A:do) -> Local(t12), Local(t14) +// Local(t12) -> Field(testdata.A:do) +// Local(t10) -> Field(testdata.A:foo), Local(t11) +// Local(t11) -> Local(t13) +// Local(t13) -> Local(t12) // Local(t14) -> Field(testdata.A:do), Local(t15) diff --git a/go/callgraph/vta/testdata/src/panic.go b/go/callgraph/vta/testdata/src/panic.go index 2d39c70ea..5ef354857 100644 --- a/go/callgraph/vta/testdata/src/panic.go +++ b/go/callgraph/vta/testdata/src/panic.go @@ -27,12 +27,12 @@ func recover2() { func Baz(a A) { defer recover1() + defer recover() panic(a) } // Relevant SSA: // func recover1(): -// 0: // t0 = print("only this recover...":string) // t1 = recover() // t2 = typeassert,ok t1.(I) @@ -53,6 +53,7 @@ func Baz(a A) { // t0 = local A (a) // *t0 = a // defer recover1() +// defer recover() // t1 = *t0 // t2 = make interface{} <- A (t1) // panic t2 diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go index e7a97e2d8..d1831983a 100644 --- a/go/callgraph/vta/utils.go +++ b/go/callgraph/vta/utils.go @@ -9,6 +9,7 @@ import ( "golang.org/x/tools/go/callgraph" "golang.org/x/tools/go/ssa" + "golang.org/x/tools/internal/typeparams" ) func canAlias(n1, n2 node) bool { @@ -32,13 +33,13 @@ func isReferenceNode(n node) bool { // hasInFlow checks if a concrete type can flow to node `n`. // Returns yes iff the type of `n` satisfies one the following: -// 1) is an interface -// 2) is a (nested) pointer to interface (needed for, say, +// 1. is an interface +// 2. is a (nested) pointer to interface (needed for, say, // slice elements of nested pointers to interface type) -// 3) is a function type (needed for higher-order type flow) -// 4) is a (nested) pointer to function (needed for, say, +// 3. is a function type (needed for higher-order type flow) +// 4. is a (nested) pointer to function (needed for, say, // slice elements of nested pointers to function type) -// 5) is a global Recover or Panic node +// 5. is a global Recover or Panic node func hasInFlow(n node) bool { if _, ok := n.(panicArg); ok { return true @@ -56,24 +57,7 @@ func hasInFlow(n node) bool { return true } - return isInterface(t) || isFunction(t) -} - -// hasInitialTypes check if a node can have initial types. -// Returns true iff `n` is not a panic or recover node as -// those are artificial. -func hasInitialTypes(n node) bool { - switch n.(type) { - case panicArg, recoverReturn: - return false - default: - return true - } -} - -func isInterface(t types.Type) bool { - _, ok := t.Underlying().(*types.Interface) - return ok + return types.IsInterface(t) || isFunction(t) } func isFunction(t types.Type) bool { @@ -85,48 +69,76 @@ func isFunction(t types.Type) bool { // pointer to interface and if yes, returns the interface type. // Otherwise, returns nil. func interfaceUnderPtr(t types.Type) types.Type { - p, ok := t.Underlying().(*types.Pointer) - if !ok { - return nil - } + seen := make(map[types.Type]bool) + var visit func(types.Type) types.Type + visit = func(t types.Type) types.Type { + if seen[t] { + return nil + } + seen[t] = true - if isInterface(p.Elem()) { - return p.Elem() - } + p, ok := t.Underlying().(*types.Pointer) + if !ok { + return nil + } + + if types.IsInterface(p.Elem()) { + return p.Elem() + } - return interfaceUnderPtr(p.Elem()) + return visit(p.Elem()) + } + return visit(t) } // functionUnderPtr checks if type `t` is a potentially nested // pointer to function type and if yes, returns the function type. // Otherwise, returns nil. func functionUnderPtr(t types.Type) types.Type { - p, ok := t.Underlying().(*types.Pointer) - if !ok { - return nil - } + seen := make(map[types.Type]bool) + var visit func(types.Type) types.Type + visit = func(t types.Type) types.Type { + if seen[t] { + return nil + } + seen[t] = true - if isFunction(p.Elem()) { - return p.Elem() - } + p, ok := t.Underlying().(*types.Pointer) + if !ok { + return nil + } + + if isFunction(p.Elem()) { + return p.Elem() + } - return functionUnderPtr(p.Elem()) + return visit(p.Elem()) + } + return visit(t) } // sliceArrayElem returns the element type of type `t` that is -// expected to be a (pointer to) array or slice, consistent with +// expected to be a (pointer to) array, slice or string, consistent with // the ssa.Index and ssa.IndexAddr instructions. Panics otherwise. func sliceArrayElem(t types.Type) types.Type { - u := t.Underlying() - - if p, ok := u.(*types.Pointer); ok { - u = p.Elem().Underlying() - } - - if a, ok := u.(*types.Array); ok { - return a.Elem() + switch u := t.Underlying().(type) { + case *types.Pointer: + return u.Elem().Underlying().(*types.Array).Elem() + case *types.Array: + return u.Elem() + case *types.Slice: + return u.Elem() + case *types.Basic: + return types.Typ[types.Byte] + case *types.Interface: // type param. + terms, err := typeparams.InterfaceTermSet(u) + if err != nil || len(terms) == 0 { + panic(t) + } + return sliceArrayElem(terms[0].Type()) // Element types must match. + default: + panic(t) } - return u.(*types.Slice).Elem() } // siteCallees computes a set of callees for call site `c` given program `callgraph`. diff --git a/go/callgraph/vta/vta.go b/go/callgraph/vta/vta.go index 98fabe58c..583936003 100644 --- a/go/callgraph/vta/vta.go +++ b/go/callgraph/vta/vta.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // Package vta computes the call graph of a Go program using the Variable -// Type Analysis (VTA) algorithm originally described in ``Practical Virtual +// Type Analysis (VTA) algorithm originally described in “Practical Virtual // Method Call Resolution for Java," Vijay Sundaresan, Laurie Hendren, // Chrislain Razafimahefa, Raja Vallée-Rai, Patrick Lam, Etienne Gagnon, and // Charles Godin. @@ -18,22 +18,23 @@ // // A type propagation is a directed, labeled graph. A node can represent // one of the following: -// - A field of a struct type. -// - A local (SSA) variable of a method/function. -// - All pointers to a non-interface type. -// - The return value of a method. -// - All elements in an array. -// - All elements in a slice. -// - All elements in a map. -// - All elements in a channel. -// - A global variable. +// - A field of a struct type. +// - A local (SSA) variable of a method/function. +// - All pointers to a non-interface type. +// - The return value of a method. +// - All elements in an array. +// - All elements in a slice. +// - All elements in a map. +// - All elements in a channel. +// - A global variable. +// // In addition, the implementation used in this package introduces // a few Go specific kinds of nodes: -// - (De)references of nested pointers to interfaces are modeled -// as a unique nestedPtrInterface node in the type propagation graph. -// - Each function literal is represented as a function node whose -// internal value is the (SSA) representation of the function. This -// is done to precisely infer flow of higher-order functions. +// - (De)references of nested pointers to interfaces are modeled +// as a unique nestedPtrInterface node in the type propagation graph. +// - Each function literal is represented as a function node whose +// internal value is the (SSA) representation of the function. This +// is done to precisely infer flow of higher-order functions. // // Edges in the graph represent flow of types (and function literals) through // the program. That is, the model 1) typing constraints that are induced by @@ -53,6 +54,8 @@ // reaching the node representing the call site to create a set of callees. package vta +// TODO(zpavlinovic): update VTA for how it handles generic function bodies and instantiation wrappers. + import ( "go/types" diff --git a/go/callgraph/vta/vta_go117_test.go b/go/callgraph/vta/vta_go117_test.go index 9ce6a8864..04f6980e5 100644 --- a/go/callgraph/vta/vta_go117_test.go +++ b/go/callgraph/vta/vta_go117_test.go @@ -11,12 +11,13 @@ import ( "testing" "golang.org/x/tools/go/callgraph/cha" + "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" ) func TestVTACallGraphGo117(t *testing.T) { file := "testdata/src/go117.go" - prog, want, err := testProg(file) + prog, want, err := testProg(file, ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load test file '%s': %s", file, err) } diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go index 33ceaf909..549c4af45 100644 --- a/go/callgraph/vta/vta_test.go +++ b/go/callgraph/vta/vta_test.go @@ -13,6 +13,7 @@ import ( "golang.org/x/tools/go/callgraph/cha" "golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa/ssautil" + "golang.org/x/tools/internal/typeparams" ) func TestVTACallGraph(t *testing.T) { @@ -24,9 +25,11 @@ func TestVTACallGraph(t *testing.T) { "testdata/src/callgraph_collections.go", "testdata/src/callgraph_fields.go", "testdata/src/callgraph_field_funcs.go", + "testdata/src/callgraph_recursive_types.go", + "testdata/src/callgraph_issue_57756.go", } { t.Run(file, func(t *testing.T) { - prog, want, err := testProg(file) + prog, want, err := testProg(file, ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load test file '%s': %s", file, err) } @@ -46,7 +49,7 @@ func TestVTACallGraph(t *testing.T) { // enabled by having an arbitrary function set as input to CallGraph // instead of the whole program (i.e., ssautil.AllFunctions(prog)). func TestVTAProgVsFuncSet(t *testing.T) { - prog, want, err := testProg("testdata/src/callgraph_nested_ptr.go") + prog, want, err := testProg("testdata/src/callgraph_nested_ptr.go", ssa.BuilderMode(0)) if err != nil { t.Fatalf("couldn't load test `testdata/src/callgraph_nested_ptr.go`: %s", err) } @@ -111,3 +114,24 @@ func TestVTAPanicMissingDefinitions(t *testing.T) { } } } + +func TestVTACallGraphGenerics(t *testing.T) { + if !typeparams.Enabled { + t.Skip("TestVTACallGraphGenerics requires type parameters") + } + + // TODO(zpavlinovic): add more tests + file := "testdata/src/callgraph_generics.go" + prog, want, err := testProg(file, ssa.InstantiateGenerics) + if err != nil { + t.Fatalf("couldn't load test file '%s': %s", file, err) + } + if len(want) == 0 { + t.Fatalf("couldn't find want in `%s`", file) + } + + g := CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog)) + if got := callGraphStr(g); !subGraph(want, got) { + t.Errorf("computed callgraph %v should contain %v", got, want) + } +} |