aboutsummaryrefslogtreecommitdiff
path: root/go/callgraph/vta
diff options
context:
space:
mode:
Diffstat (limited to 'go/callgraph/vta')
-rw-r--r--go/callgraph/vta/graph.go126
-rw-r--r--go/callgraph/vta/graph_test.go7
-rw-r--r--go/callgraph/vta/helpers_test.go8
-rw-r--r--go/callgraph/vta/internal/trie/bits.go25
-rw-r--r--go/callgraph/vta/internal/trie/builder.go24
-rw-r--r--go/callgraph/vta/internal/trie/trie.go12
-rw-r--r--go/callgraph/vta/propagation.go57
-rw-r--r--go/callgraph/vta/propagation_test.go61
-rw-r--r--go/callgraph/vta/testdata/src/callgraph_generics.go71
-rw-r--r--go/callgraph/vta/testdata/src/callgraph_issue_57756.go67
-rw-r--r--go/callgraph/vta/testdata/src/callgraph_recursive_types.go56
-rw-r--r--go/callgraph/vta/testdata/src/function_alias.go44
-rw-r--r--go/callgraph/vta/testdata/src/panic.go3
-rw-r--r--go/callgraph/vta/utils.go110
-rw-r--r--go/callgraph/vta/vta.go33
-rw-r--r--go/callgraph/vta/vta_go117_test.go3
-rw-r--r--go/callgraph/vta/vta_test.go28
17 files changed, 543 insertions, 192 deletions
diff --git a/go/callgraph/vta/graph.go b/go/callgraph/vta/graph.go
index ad7ef0e88..2537123f4 100644
--- a/go/callgraph/vta/graph.go
+++ b/go/callgraph/vta/graph.go
@@ -12,6 +12,7 @@ import (
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/types/typeutil"
+ "golang.org/x/tools/internal/typeparams"
)
// node interface for VTA nodes.
@@ -175,9 +176,10 @@ func (f function) String() string {
// We merge such constructs into a single node for simplicity and without
// much precision sacrifice as such variables are rare in practice. Both
// a and b would be represented as the same PtrInterface(I) node in:
-// type I interface
-// var a ***I
-// var b **I
+//
+// type I interface
+// var a ***I
+// var b **I
type nestedPtrInterface struct {
typ types.Type
}
@@ -195,8 +197,9 @@ func (l nestedPtrInterface) String() string {
// constructs into a single node for simplicity and without much precision
// sacrifice as such variables are rare in practice. Both a and b would be
// represented as the same PtrFunction(func()) node in:
-// var a *func()
-// var b **func()
+//
+// var a *func()
+// var b **func()
type nestedPtrFunction struct {
typ types.Type
}
@@ -325,14 +328,16 @@ func (b *builder) instr(instr ssa.Instruction) {
// change type command a := A(b) results in a and b being the
// same value. For concrete type A, there is no interesting flow.
//
- // Note: When A is an interface, most interface casts are handled
+ // When A is an interface, most interface casts are handled
// by the ChangeInterface instruction. The relevant case here is
// when converting a pointer to an interface type. This can happen
// when the underlying interfaces have the same method set.
- // type I interface{ foo() }
- // type J interface{ foo() }
- // var b *I
- // a := (*J)(b)
+ //
+ // type I interface{ foo() }
+ // type J interface{ foo() }
+ // var b *I
+ // a := (*J)(b)
+ //
// When this happens we add flows between a <--> b.
b.addInFlowAliasEdges(b.nodeFromVal(i), b.nodeFromVal(i.X))
case *ssa.TypeAssert:
@@ -371,6 +376,8 @@ func (b *builder) instr(instr ssa.Instruction) {
// SliceToArrayPointer: t1 = slice to array pointer *[4]T <- []T (t0)
// No interesting flow as sliceArrayElem(t1) == sliceArrayElem(t0).
return
+ case *ssa.MultiConvert:
+ b.multiconvert(i)
default:
panic(fmt.Sprintf("unsupported instruction %v\n", instr))
}
@@ -441,7 +448,9 @@ func (b *builder) send(s *ssa.Send) {
}
// selekt generates flows for select statement
-// a = select blocking/nonblocking [c_1 <- t_1, c_2 <- t_2, ..., <- o_1, <- o_2, ...]
+//
+// a = select blocking/nonblocking [c_1 <- t_1, c_2 <- t_2, ..., <- o_1, <- o_2, ...]
+//
// between receiving channel registers c_i and corresponding input register t_i. Further,
// flows are generated between o_i and a[2 + i]. Note that a is a tuple register of type
// <int, bool, r_1, r_2, ...> where the type of r_i is the element type of channel o_i.
@@ -544,8 +553,9 @@ func (b *builder) closure(c *ssa.MakeClosure) {
// panic creates a flow from arguments to panic instructions to return
// registers of all recover statements in the program. Introduces a
// global panic node Panic and
-// 1) for every panic statement p: add p -> Panic
-// 2) for every recover statement r: add Panic -> r (handled in call)
+// 1. for every panic statement p: add p -> Panic
+// 2. for every recover statement r: add Panic -> r (handled in call)
+//
// TODO(zpavlinovic): improve precision by explicitly modeling how panic
// values flow from callees to callers and into deferred recover instructions.
func (b *builder) panic(p *ssa.Panic) {
@@ -563,7 +573,9 @@ func (b *builder) panic(p *ssa.Panic) {
func (b *builder) call(c ssa.CallInstruction) {
// When c is r := recover() call register instruction, we add Recover -> r.
if bf, ok := c.Common().Value.(*ssa.Builtin); ok && bf.Name() == "recover" {
- b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(c.(*ssa.Call)))
+ if v, ok := c.(ssa.Value); ok {
+ b.addInFlowEdge(recoverReturn{}, b.nodeFromVal(v))
+ }
return
}
@@ -581,10 +593,18 @@ func addArgumentFlows(b *builder, c ssa.CallInstruction, f *ssa.Function) {
return
}
cc := c.Common()
- // When c is an unresolved method call (cc.Method != nil), cc.Value contains
- // the receiver object rather than cc.Args[0].
if cc.Method != nil {
- b.addInFlowAliasEdges(b.nodeFromVal(f.Params[0]), b.nodeFromVal(cc.Value))
+ // In principle we don't add interprocedural flows for receiver
+ // objects. At a call site, the receiver object is interface
+ // while the callee object is concrete. The flow from interface
+ // to concrete type in general does not make sense. The exception
+ // is when the concrete type is a named function type (see #57756).
+ //
+ // The flow other way around would bake in information from the
+ // initial call graph.
+ if isFunction(f.Params[0].Type()) {
+ b.addInFlowEdge(b.nodeFromVal(cc.Value), b.nodeFromVal(f.Params[0]))
+ }
}
offset := 0
@@ -638,6 +658,71 @@ func addReturnFlows(b *builder, r *ssa.Return, site ssa.Value) {
}
}
+func (b *builder) multiconvert(c *ssa.MultiConvert) {
+ // TODO(zpavlinovic): decide what to do on MultiConvert long term.
+ // TODO(zpavlinovic): add unit tests.
+ typeSetOf := func(typ types.Type) []*typeparams.Term {
+ // This is a adaptation of x/exp/typeparams.NormalTerms which x/tools cannot depend on.
+ var terms []*typeparams.Term
+ var err error
+ switch typ := typ.(type) {
+ case *typeparams.TypeParam:
+ terms, err = typeparams.StructuralTerms(typ)
+ case *typeparams.Union:
+ terms, err = typeparams.UnionTermSet(typ)
+ case *types.Interface:
+ terms, err = typeparams.InterfaceTermSet(typ)
+ default:
+ // Common case.
+ // Specializing the len=1 case to avoid a slice
+ // had no measurable space/time benefit.
+ terms = []*typeparams.Term{typeparams.NewTerm(false, typ)}
+ }
+
+ if err != nil {
+ return nil
+ }
+ return terms
+ }
+ // isValuePreserving returns true if a conversion from ut_src to
+ // ut_dst is value-preserving, i.e. just a change of type.
+ // Precondition: neither argument is a named type.
+ isValuePreserving := func(ut_src, ut_dst types.Type) bool {
+ // Identical underlying types?
+ if types.IdenticalIgnoreTags(ut_dst, ut_src) {
+ return true
+ }
+
+ switch ut_dst.(type) {
+ case *types.Chan:
+ // Conversion between channel types?
+ _, ok := ut_src.(*types.Chan)
+ return ok
+
+ case *types.Pointer:
+ // Conversion between pointers with identical base types?
+ _, ok := ut_src.(*types.Pointer)
+ return ok
+ }
+ return false
+ }
+ dst_terms := typeSetOf(c.Type())
+ src_terms := typeSetOf(c.X.Type())
+ for _, s := range src_terms {
+ us := s.Type().Underlying()
+ for _, d := range dst_terms {
+ ud := d.Type().Underlying()
+ if isValuePreserving(us, ud) {
+ // This is equivalent to a ChangeType.
+ b.addInFlowAliasEdges(b.nodeFromVal(c), b.nodeFromVal(c.X))
+ return
+ }
+ // This is equivalent to either: SliceToArrayPointer,,
+ // SliceToArrayPointer+Deref, Size 0 Array constant, or a Convert.
+ }
+ }
+}
+
// addInFlowEdge adds s -> d to g if d is node that can have an inflow, i.e., a node
// that represents an interface or an unresolved function value. Otherwise, there
// is no interesting type flow so the edge is omitted.
@@ -649,7 +734,7 @@ func (b *builder) addInFlowEdge(s, d node) {
// Creates const, pointer, global, func, and local nodes based on register instructions.
func (b *builder) nodeFromVal(val ssa.Value) node {
- if p, ok := val.Type().(*types.Pointer); ok && !isInterface(p.Elem()) && !isFunction(p.Elem()) {
+ if p, ok := val.Type().(*types.Pointer); ok && !types.IsInterface(p.Elem()) && !isFunction(p.Elem()) {
// Nested pointer to interfaces are modeled as a special
// nestedPtrInterface node.
if i := interfaceUnderPtr(p.Elem()); i != nil {
@@ -676,14 +761,15 @@ func (b *builder) nodeFromVal(val ssa.Value) node {
default:
panic(fmt.Errorf("unsupported value %v in node creation", val))
}
- return nil
}
// representative returns a unique representative for node `n`. Since
// semantically equivalent types can have different implementations,
// this method guarantees the same implementation is always used.
func (b *builder) representative(n node) node {
- if !hasInitialTypes(n) {
+ if n.Type() == nil {
+ // panicArg and recoverReturn do not have
+ // types and are unique by definition.
return n
}
t := canonicalize(n.Type(), &b.canon)
diff --git a/go/callgraph/vta/graph_test.go b/go/callgraph/vta/graph_test.go
index 8608844dd..8b8c6976f 100644
--- a/go/callgraph/vta/graph_test.go
+++ b/go/callgraph/vta/graph_test.go
@@ -13,6 +13,7 @@ import (
"testing"
"golang.org/x/tools/go/callgraph/cha"
+ "golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
@@ -24,7 +25,7 @@ func TestNodeInterface(t *testing.T) {
// - global variable "gl"
// - "main" function and its
// - first register instruction t0 := *gl
- prog, _, err := testProg("testdata/src/simple.go")
+ prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
}
@@ -78,7 +79,7 @@ func TestNodeInterface(t *testing.T) {
func TestVtaGraph(t *testing.T) {
// Get the basic type int from a real program.
- prog, _, err := testProg("testdata/src/simple.go")
+ prog, _, err := testProg("testdata/src/simple.go", ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load testdata/src/simple.go program: %v", err)
}
@@ -191,7 +192,7 @@ func TestVTAGraphConstruction(t *testing.T) {
"testdata/src/panic.go",
} {
t.Run(file, func(t *testing.T) {
- prog, want, err := testProg(file)
+ prog, want, err := testProg(file, ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load test file '%s': %s", file, err)
}
diff --git a/go/callgraph/vta/helpers_test.go b/go/callgraph/vta/helpers_test.go
index 0e00aeb28..768365f5b 100644
--- a/go/callgraph/vta/helpers_test.go
+++ b/go/callgraph/vta/helpers_test.go
@@ -35,7 +35,7 @@ func want(f *ast.File) []string {
// testProg returns an ssa representation of a program at
// `path`, assumed to define package "testdata," and the
// test want result as list of strings.
-func testProg(path string) (*ssa.Program, []string, error) {
+func testProg(path string, mode ssa.BuilderMode) (*ssa.Program, []string, error) {
content, err := ioutil.ReadFile(path)
if err != nil {
return nil, nil, err
@@ -56,7 +56,7 @@ func testProg(path string) (*ssa.Program, []string, error) {
return nil, nil, err
}
- prog := ssautil.CreateProgram(iprog, 0)
+ prog := ssautil.CreateProgram(iprog, mode)
// Set debug mode to exercise DebugRef instructions.
prog.Package(iprog.Created[0].Pkg).SetDebugMode(true)
prog.Build()
@@ -87,7 +87,9 @@ func funcName(f *ssa.Function) string {
// callGraphStr stringifes `g` into a list of strings where
// each entry is of the form
-// f: cs1 -> f1, f2, ...; ...; csw -> fx, fy, ...
+//
+// f: cs1 -> f1, f2, ...; ...; csw -> fx, fy, ...
+//
// f is a function, cs1, ..., csw are call sites in f, and
// f1, f2, ..., fx, fy, ... are the resolved callees.
func callGraphStr(g *callgraph.Graph) []string {
diff --git a/go/callgraph/vta/internal/trie/bits.go b/go/callgraph/vta/internal/trie/bits.go
index f2fd0ba83..c3aa15985 100644
--- a/go/callgraph/vta/internal/trie/bits.go
+++ b/go/callgraph/vta/internal/trie/bits.go
@@ -19,11 +19,11 @@ type key uint64
// bitpos is the position of a bit. A position is represented by having a 1
// bit in that position.
// Examples:
-// * 0b0010 is the position of the `1` bit in 2.
-// It is the 3rd most specific bit position in big endian encoding
-// (0b0 and 0b1 are more specific).
-// * 0b0100 is the position of the bit that 1 and 5 disagree on.
-// * 0b0 is a special value indicating that all bit agree.
+// - 0b0010 is the position of the `1` bit in 2.
+// It is the 3rd most specific bit position in big endian encoding
+// (0b0 and 0b1 are more specific).
+// - 0b0100 is the position of the bit that 1 and 5 disagree on.
+// - 0b0 is a special value indicating that all bit agree.
type bitpos uint64
// prefixes represent a set of keys that all agree with the
@@ -35,7 +35,8 @@ type bitpos uint64
// A prefix always mask(p, m) == p.
//
// A key is its own prefix for the bit position 64,
-// e.g. seeing a `prefix(key)` is not a problem.
+// e.g. seeing a `prefix(key)` is not a problem.
+//
// Prefixes should never be turned into keys.
type prefix uint64
@@ -64,8 +65,9 @@ func matchPrefix(k prefix, p prefix, b bitpos) bool {
// In big endian encoding, this value is the [64-(m-1)] most significant bits of k
// followed by a `0` bit at bitpos m, followed m-1 `1` bits.
// Examples:
-// prefix(0b1011) for a bitpos 0b0100 represents the keys:
-// 0b1000, 0b1001, 0b1010, 0b1011, 0b1100, 0b1101, 0b1110, 0b1111
+//
+// prefix(0b1011) for a bitpos 0b0100 represents the keys:
+// 0b1000, 0b1001, 0b1010, 0b1011, 0b1100, 0b1101, 0b1110, 0b1111
//
// This mask function has the property that if matchPrefix(k, p, b), then
// k <= p if and only if zeroBit(k, m). This induces binary search tree tries.
@@ -85,9 +87,10 @@ func ord(m, n bitpos) bool {
// can hold that can also be held by a prefix `q` for some bitpos `n`.
//
// This is equivalent to:
-// m ==n && p == q,
-// higher(m, n) && matchPrefix(q, p, m), or
-// higher(n, m) && matchPrefix(p, q, n)
+//
+// m ==n && p == q,
+// higher(m, n) && matchPrefix(q, p, m), or
+// higher(n, m) && matchPrefix(p, q, n)
func prefixesOverlap(p prefix, m bitpos, q prefix, n bitpos) bool {
fbb := n
if ord(m, n) {
diff --git a/go/callgraph/vta/internal/trie/builder.go b/go/callgraph/vta/internal/trie/builder.go
index 25d3805bc..11ff59b1b 100644
--- a/go/callgraph/vta/internal/trie/builder.go
+++ b/go/callgraph/vta/internal/trie/builder.go
@@ -9,7 +9,9 @@ package trie
// will be stored for the key.
//
// Collision functions must be idempotent:
-// collision(x, x) == x for all x.
+//
+// collision(x, x) == x for all x.
+//
// Collisions functions may be applied whenever a value is inserted
// or two maps are merged, or intersected.
type Collision func(lhs interface{}, rhs interface{}) interface{}
@@ -72,7 +74,8 @@ func (b *Builder) Empty() Map { return Map{b.Scope(), b.empty} }
// in the current scope and handle collisions using the collision function c.
//
// This is roughly corresponds to updating a map[uint64]interface{} by:
-// if _, ok := m[k]; ok { m[k] = c(m[k], v} else { m[k] = v}
+//
+// if _, ok := m[k]; ok { m[k] = c(m[k], v} else { m[k] = v}
//
// An insertion or update happened whenever Insert(m, ...) != m .
func (b *Builder) InsertWith(c Collision, m Map, k uint64, v interface{}) Map {
@@ -85,7 +88,8 @@ func (b *Builder) InsertWith(c Collision, m Map, k uint64, v interface{}) Map {
//
// If there was a previous value mapped by key, keep the previously mapped value.
// This is roughly corresponds to updating a map[uint64]interface{} by:
-// if _, ok := m[k]; ok { m[k] = val }
+//
+// if _, ok := m[k]; ok { m[k] = val }
//
// This is equivalent to b.Merge(m, b.Create({k: v})).
func (b *Builder) Insert(m Map, k uint64, v interface{}) Map {
@@ -94,7 +98,8 @@ func (b *Builder) Insert(m Map, k uint64, v interface{}) Map {
// Updates a (key, value) in the map. This is roughly corresponds to
// updating a map[uint64]interface{} by:
-// m[key] = val
+//
+// m[key] = val
func (b *Builder) Update(m Map, key uint64, val interface{}) Map {
return b.InsertWith(TakeRhs, m, key, val)
}
@@ -148,14 +153,17 @@ func (b *Builder) Remove(m Map, k uint64) Map {
// Intersect Maps lhs and rhs and returns a map with all of the keys in
// both lhs and rhs and the value comes from lhs, i.e.
-// {(k, lhs[k]) | k in lhs, k in rhs}.
+//
+// {(k, lhs[k]) | k in lhs, k in rhs}.
func (b *Builder) Intersect(lhs, rhs Map) Map {
return b.IntersectWith(TakeLhs, lhs, rhs)
}
// IntersectWith take lhs and rhs and returns the intersection
// with the value coming from the collision function, i.e.
-// {(k, c(lhs[k], rhs[k]) ) | k in lhs, k in rhs}.
+//
+// {(k, c(lhs[k], rhs[k]) ) | k in lhs, k in rhs}.
+//
// The elements of the resulting map are always { <k, c(lhs[k], rhs[k]) > }
// for each key k that a key in both lhs and rhs.
func (b *Builder) IntersectWith(c Collision, lhs, rhs Map) Map {
@@ -261,7 +269,9 @@ func (b *Builder) mkLeaf(k key, v interface{}) *leaf {
}
// mkBranch returns the hash-consed representative of the tuple
-// (prefix, branch, left, right)
+//
+// (prefix, branch, left, right)
+//
// in the current scope.
func (b *Builder) mkBranch(p prefix, bp bitpos, left node, right node) *branch {
br := &branch{
diff --git a/go/callgraph/vta/internal/trie/trie.go b/go/callgraph/vta/internal/trie/trie.go
index 160eb21be..511fde515 100644
--- a/go/callgraph/vta/internal/trie/trie.go
+++ b/go/callgraph/vta/internal/trie/trie.go
@@ -10,8 +10,10 @@
// environment abstract domains in program analysis).
//
// This implementation closely follows the paper:
-// C. Okasaki and A. Gill, “Fast mergeable integer maps,” in ACM SIGPLAN
-// Workshop on ML, September 1998, pp. 77–86.
+//
+// C. Okasaki and A. Gill, “Fast mergeable integer maps,” in ACM SIGPLAN
+// Workshop on ML, September 1998, pp. 77–86.
+//
// Each Map is immutable and can be read from concurrently. The map does not
// guarantee that the value pointed to by the interface{} value is not updated
// concurrently.
@@ -36,9 +38,9 @@ import (
// Maps are immutable and can be read from concurrently.
//
// Notes on concurrency:
-// - A Map value itself is an interface and assignments to a Map value can race.
-// - Map does not guarantee that the value pointed to by the interface{} value
-// is not updated concurrently.
+// - A Map value itself is an interface and assignments to a Map value can race.
+// - Map does not guarantee that the value pointed to by the interface{} value
+// is not updated concurrently.
type Map struct {
s Scope
n node
diff --git a/go/callgraph/vta/propagation.go b/go/callgraph/vta/propagation.go
index 5934ebc21..5817e8938 100644
--- a/go/callgraph/vta/propagation.go
+++ b/go/callgraph/vta/propagation.go
@@ -20,53 +20,52 @@ import (
// with ids X and Y s.t. X < Y, Y comes before X in the topological order.
func scc(g vtaGraph) (map[node]int, int) {
// standard data structures used by Tarjan's algorithm.
- var index uint64
+ type state struct {
+ index int
+ lowLink int
+ onStack bool
+ }
+ states := make(map[node]*state, len(g))
var stack []node
- indexMap := make(map[node]uint64)
- lowLink := make(map[node]uint64)
- onStack := make(map[node]bool)
- nodeToSccID := make(map[node]int)
+ nodeToSccID := make(map[node]int, len(g))
sccID := 0
var doSCC func(node)
doSCC = func(n node) {
- indexMap[n] = index
- lowLink[n] = index
- index = index + 1
- onStack[n] = true
+ index := len(states)
+ ns := &state{index: index, lowLink: index, onStack: true}
+ states[n] = ns
stack = append(stack, n)
for s := range g[n] {
- if _, ok := indexMap[s]; !ok {
+ if ss, visited := states[s]; !visited {
// Analyze successor s that has not been visited yet.
doSCC(s)
- lowLink[n] = min(lowLink[n], lowLink[s])
- } else if onStack[s] {
+ ss = states[s]
+ ns.lowLink = min(ns.lowLink, ss.lowLink)
+ } else if ss.onStack {
// The successor is on the stack, meaning it has to be
// in the current SCC.
- lowLink[n] = min(lowLink[n], indexMap[s])
+ ns.lowLink = min(ns.lowLink, ss.index)
}
}
// if n is a root node, pop the stack and generate a new SCC.
- if lowLink[n] == indexMap[n] {
- for {
- w := stack[len(stack)-1]
+ if ns.lowLink == index {
+ var w node
+ for w != n {
+ w = stack[len(stack)-1]
stack = stack[:len(stack)-1]
- onStack[w] = false
+ states[w].onStack = false
nodeToSccID[w] = sccID
- if w == n {
- break
- }
}
sccID++
}
}
- index = 0
for n := range g {
- if _, ok := indexMap[n]; !ok {
+ if _, visited := states[n]; !visited {
doSCC(n)
}
}
@@ -74,7 +73,7 @@ func scc(g vtaGraph) (map[node]int, int) {
return nodeToSccID, sccID
}
-func min(x, y uint64) uint64 {
+func min(x, y int) int {
if x < y {
return x
}
@@ -175,6 +174,18 @@ func nodeTypes(nodes []node, builder *trie.Builder, propTypeId func(p propType)
return &typeSet
}
+// hasInitialTypes check if a node can have initial types.
+// Returns true iff `n` is not a panic, recover, nestedPtr*
+// node, nor a node whose type is an interface.
+func hasInitialTypes(n node) bool {
+ switch n.(type) {
+ case panicArg, recoverReturn, nestedPtrFunction, nestedPtrInterface:
+ return false
+ default:
+ return !types.IsInterface(n.Type())
+ }
+}
+
// getPropType creates a propType for `node` based on its type.
// propType.typ is always node.Type(). If node is function, then
// propType.val is the underlying function; nil otherwise.
diff --git a/go/callgraph/vta/propagation_test.go b/go/callgraph/vta/propagation_test.go
index 96707417f..f4a754f96 100644
--- a/go/callgraph/vta/propagation_test.go
+++ b/go/callgraph/vta/propagation_test.go
@@ -58,7 +58,7 @@ func newLocal(name string, t types.Type) local {
// newNamedType creates a bogus type named `name`.
func newNamedType(name string) *types.Named {
- return types.NewNamed(types.NewTypeName(token.NoPos, nil, name, nil), nil, nil)
+ return types.NewNamed(types.NewTypeName(token.NoPos, nil, name, nil), types.Universe.Lookup("int").Type(), nil)
}
// sccString is a utility for stringifying `nodeToScc`. Every
@@ -123,7 +123,8 @@ func sccEqual(sccs1 []string, sccs2 []string) bool {
// isRevTopSorted checks if sccs of `g` are sorted in reverse
// topological order:
-// for every edge x -> y in g, nodeToScc[x] > nodeToScc[y]
+//
+// for every edge x -> y in g, nodeToScc[x] > nodeToScc[y]
func isRevTopSorted(g vtaGraph, nodeToScc map[node]int) bool {
for n, succs := range g {
for s := range succs {
@@ -148,39 +149,39 @@ func setName(f *ssa.Function, name string) {
// parentheses contain node types and F nodes stand for function
// nodes whose content is function named F:
//
-// no-cycles:
-// t0 (A) -> t1 (B) -> t2 (C)
+// no-cycles:
+// t0 (A) -> t1 (B) -> t2 (C)
//
-// trivial-cycle:
-// <-------- <--------
-// | | | |
-// t0 (A) -> t1 (B) ->
+// trivial-cycle:
+// <-------- <--------
+// | | | |
+// t0 (A) -> t1 (B) ->
//
-// circle-cycle:
-// t0 (A) -> t1 (A) -> t2 (B)
-// | |
-// <--------------------
+// circle-cycle:
+// t0 (A) -> t1 (A) -> t2 (B)
+// | |
+// <--------------------
//
-// fully-connected:
-// t0 (A) <-> t1 (B)
-// \ /
-// t2(C)
+// fully-connected:
+// t0 (A) <-> t1 (B)
+// \ /
+// t2(C)
//
-// subsumed-scc:
-// t0 (A) -> t1 (B) -> t2(B) -> t3 (A)
-// | | | |
-// | <--------- |
-// <-----------------------------
+// subsumed-scc:
+// t0 (A) -> t1 (B) -> t2(B) -> t3 (A)
+// | | | |
+// | <--------- |
+// <-----------------------------
//
-// more-realistic:
-// <--------
-// | |
-// t0 (A) -->
-// ---------->
-// | |
-// t1 (A) -> t2 (B) -> F1 -> F2 -> F3 -> F4
-// | | | |
-// <------- <------------
+// more-realistic:
+// <--------
+// | |
+// t0 (A) -->
+// ---------->
+// | |
+// t1 (A) -> t2 (B) -> F1 -> F2 -> F3 -> F4
+// | | | |
+// <------- <------------
func testSuite() map[string]vtaGraph {
a := newNamedType("A")
b := newNamedType("B")
diff --git a/go/callgraph/vta/testdata/src/callgraph_generics.go b/go/callgraph/vta/testdata/src/callgraph_generics.go
new file mode 100644
index 000000000..da3dca52a
--- /dev/null
+++ b/go/callgraph/vta/testdata/src/callgraph_generics.go
@@ -0,0 +1,71 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+func instantiated[X any](x *X) int {
+ print(x)
+ return 0
+}
+
+type I interface {
+ Bar()
+}
+
+func interfaceInstantiated[X I](x X) {
+ x.Bar()
+}
+
+type A struct{}
+
+func (a A) Bar() {}
+
+type B struct{}
+
+func (b B) Bar() {}
+
+func Foo(a A, b B) {
+ x := true
+ instantiated[bool](&x)
+ y := 1
+ instantiated[int](&y)
+
+ interfaceInstantiated[A](a)
+ interfaceInstantiated[B](b)
+}
+
+// Relevant SSA:
+//func Foo(a A, b B):
+// t0 = local A (a)
+// *t0 = a
+// t1 = local B (b)
+// *t1 = b
+// t2 = new bool (x)
+// *t2 = true:bool
+// t3 = instantiated[bool](t2)
+// t4 = new int (y)
+// *t4 = 1:int
+// t5 = instantiated[int](t4)
+// t6 = *t0
+// t7 = interfaceInstantiated[testdata.A](t6)
+// t8 = *t1
+// t9 = interfaceInstantiated[testdata.B](t8)
+// return
+//
+//func interfaceInstantiated[testdata.B](x B):
+// t0 = local B (x)
+// *t0 = x
+// t1 = *t0
+// t2 = (B).Bar(t1)
+// return
+//
+//func interfaceInstantiated[X I](x X):
+// (external)
+
+// WANT:
+// Foo: instantiated[bool](t2) -> instantiated[bool]; instantiated[int](t4) -> instantiated[int]; interfaceInstantiated[testdata.A](t6) -> interfaceInstantiated[testdata.A]; interfaceInstantiated[testdata.B](t8) -> interfaceInstantiated[testdata.B]
+// interfaceInstantiated[testdata.B]: (B).Bar(t1) -> B.Bar
+// interfaceInstantiated[testdata.A]: (A).Bar(t1) -> A.Bar
diff --git a/go/callgraph/vta/testdata/src/callgraph_issue_57756.go b/go/callgraph/vta/testdata/src/callgraph_issue_57756.go
new file mode 100644
index 000000000..e18f16eba
--- /dev/null
+++ b/go/callgraph/vta/testdata/src/callgraph_issue_57756.go
@@ -0,0 +1,67 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+// Test that the values of a named function type are correctly
+// flowing from interface objects i in i.Foo() to the receiver
+// parameters of callees of i.Foo().
+
+type H func()
+
+func (h H) Do() {
+ h()
+}
+
+type I interface {
+ Do()
+}
+
+func Bar() I {
+ return H(func() {})
+}
+
+func For(g G) {
+ b := Bar()
+ b.Do()
+
+ g[0] = b
+ g.Goo()
+}
+
+type G []I
+
+func (g G) Goo() {
+ g[0].Do()
+}
+
+// Relevant SSA:
+// func Bar$1():
+// return
+//
+// func Bar() I:
+// t0 = changetype H <- func() (Bar$1)
+// t1 = make I <- H (t0)
+//
+// func For():
+// t0 = Bar()
+// t1 = invoke t0.Do()
+// t2 = &g[0:int]
+// *t2 = t0
+// t3 = (G).Goo(g)
+//
+// func (h H) Do():
+// t0 = h()
+//
+// func (g G) Goo():
+// t0 = &g[0:int]
+// t1 = *t0
+// t2 = invoke t1.Do()
+
+// WANT:
+// For: (G).Goo(g) -> G.Goo; Bar() -> Bar; invoke t0.Do() -> H.Do
+// H.Do: h() -> Bar$1
+// G.Goo: invoke t1.Do() -> H.Do
diff --git a/go/callgraph/vta/testdata/src/callgraph_recursive_types.go b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go
new file mode 100644
index 000000000..6c3fef6f7
--- /dev/null
+++ b/go/callgraph/vta/testdata/src/callgraph_recursive_types.go
@@ -0,0 +1,56 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// go:build ignore
+
+package testdata
+
+type I interface {
+ Foo() I
+}
+
+type A struct {
+ i int
+ a *A
+}
+
+func (a *A) Foo() I {
+ return a
+}
+
+type B **B
+
+type C *D
+type D *C
+
+func Bar(a *A, b *B, c *C, d *D) {
+ Baz(a)
+ Baz(a.a)
+
+ sink(*b)
+ sink(*c)
+ sink(*d)
+}
+
+func Baz(i I) {
+ i.Foo()
+}
+
+func sink(i interface{}) {
+ print(i)
+}
+
+// Relevant SSA:
+// func Baz(i I):
+// t0 = invoke i.Foo()
+// return
+//
+// func Bar(a *A, b *B):
+// t0 = make I <- *A (a)
+// t1 = Baz(t0)
+// ...
+
+// WANT:
+// Bar: Baz(t0) -> Baz; Baz(t4) -> Baz; sink(t10) -> sink; sink(t13) -> sink; sink(t7) -> sink
+// Baz: invoke i.Foo() -> A.Foo
diff --git a/go/callgraph/vta/testdata/src/function_alias.go b/go/callgraph/vta/testdata/src/function_alias.go
index b38e0e00d..0a8dffe79 100644
--- a/go/callgraph/vta/testdata/src/function_alias.go
+++ b/go/callgraph/vta/testdata/src/function_alias.go
@@ -33,42 +33,42 @@ func Baz(f func()) {
// t2 = *t1
// *t2 = Baz$1
// t3 = local A (a)
-// t4 = &t3.foo [#0]
-// t5 = *t1
-// t6 = *t5
-// *t4 = t6
+// t4 = *t1
+// t5 = *t4
+// t6 = &t3.foo [#0]
+// *t6 = t5
// t7 = &t3.foo [#0]
// t8 = *t7
// t9 = t8()
-// t10 = &t3.do [#1] *Doer
-// t11 = &t3.foo [#0] *func()
-// t12 = *t11 func()
-// t13 = changetype Doer <- func() (t12) Doer
-// *t10 = t13
+// t10 = &t3.foo [#0] *func()
+// t11 = *t10 func()
+// t12 = &t3.do [#1] *Doer
+// t13 = changetype Doer <- func() (t11) Doer
+// *t12 = t13
// t14 = &t3.do [#1] *Doer
// t15 = *t14 Doer
// t16 = t15() ()
// Flow chain showing that Baz$1 reaches t8():
-// Baz$1 -> t2 <-> PtrFunction(func()) <-> t5 -> t6 -> t4 <-> Field(testdata.A:foo) <-> t7 -> t8
+// Baz$1 -> t2 <-> PtrFunction(func()) <-> t4 -> t5 -> t6 <-> Field(testdata.A:foo) <-> t7 -> t8
// Flow chain showing that Baz$1 reaches t15():
-// Field(testdata.A:foo) <-> t11 -> t12 -> t13 -> t10 <-> Field(testdata.A:do) <-> t14 -> t15
+// Field(testdata.A:foo) <-> t10 -> t11 -> t13 -> t12 <-> Field(testdata.A:do) <-> t14 -> t15
// WANT:
// Local(f) -> Local(t0)
// Local(t0) -> PtrFunction(func())
// Function(Baz$1) -> Local(t2)
-// PtrFunction(func()) -> Local(t0), Local(t2), Local(t5)
+// PtrFunction(func()) -> Local(t0), Local(t2), Local(t4)
// Local(t2) -> PtrFunction(func())
-// Local(t4) -> Field(testdata.A:foo)
-// Local(t5) -> Local(t6), PtrFunction(func())
-// Local(t6) -> Local(t4)
+// Local(t6) -> Field(testdata.A:foo)
+// Local(t4) -> Local(t5), PtrFunction(func())
+// Local(t5) -> Local(t6)
// Local(t7) -> Field(testdata.A:foo), Local(t8)
-// Field(testdata.A:foo) -> Local(t11), Local(t4), Local(t7)
-// Local(t4) -> Field(testdata.A:foo)
-// Field(testdata.A:do) -> Local(t10), Local(t14)
-// Local(t10) -> Field(testdata.A:do)
-// Local(t11) -> Field(testdata.A:foo), Local(t12)
-// Local(t12) -> Local(t13)
-// Local(t13) -> Local(t10)
+// Field(testdata.A:foo) -> Local(t10), Local(t6), Local(t7)
+// Local(t6) -> Field(testdata.A:foo)
+// Field(testdata.A:do) -> Local(t12), Local(t14)
+// Local(t12) -> Field(testdata.A:do)
+// Local(t10) -> Field(testdata.A:foo), Local(t11)
+// Local(t11) -> Local(t13)
+// Local(t13) -> Local(t12)
// Local(t14) -> Field(testdata.A:do), Local(t15)
diff --git a/go/callgraph/vta/testdata/src/panic.go b/go/callgraph/vta/testdata/src/panic.go
index 2d39c70ea..5ef354857 100644
--- a/go/callgraph/vta/testdata/src/panic.go
+++ b/go/callgraph/vta/testdata/src/panic.go
@@ -27,12 +27,12 @@ func recover2() {
func Baz(a A) {
defer recover1()
+ defer recover()
panic(a)
}
// Relevant SSA:
// func recover1():
-// 0:
// t0 = print("only this recover...":string)
// t1 = recover()
// t2 = typeassert,ok t1.(I)
@@ -53,6 +53,7 @@ func Baz(a A) {
// t0 = local A (a)
// *t0 = a
// defer recover1()
+// defer recover()
// t1 = *t0
// t2 = make interface{} <- A (t1)
// panic t2
diff --git a/go/callgraph/vta/utils.go b/go/callgraph/vta/utils.go
index e7a97e2d8..d1831983a 100644
--- a/go/callgraph/vta/utils.go
+++ b/go/callgraph/vta/utils.go
@@ -9,6 +9,7 @@ import (
"golang.org/x/tools/go/callgraph"
"golang.org/x/tools/go/ssa"
+ "golang.org/x/tools/internal/typeparams"
)
func canAlias(n1, n2 node) bool {
@@ -32,13 +33,13 @@ func isReferenceNode(n node) bool {
// hasInFlow checks if a concrete type can flow to node `n`.
// Returns yes iff the type of `n` satisfies one the following:
-// 1) is an interface
-// 2) is a (nested) pointer to interface (needed for, say,
+// 1. is an interface
+// 2. is a (nested) pointer to interface (needed for, say,
// slice elements of nested pointers to interface type)
-// 3) is a function type (needed for higher-order type flow)
-// 4) is a (nested) pointer to function (needed for, say,
+// 3. is a function type (needed for higher-order type flow)
+// 4. is a (nested) pointer to function (needed for, say,
// slice elements of nested pointers to function type)
-// 5) is a global Recover or Panic node
+// 5. is a global Recover or Panic node
func hasInFlow(n node) bool {
if _, ok := n.(panicArg); ok {
return true
@@ -56,24 +57,7 @@ func hasInFlow(n node) bool {
return true
}
- return isInterface(t) || isFunction(t)
-}
-
-// hasInitialTypes check if a node can have initial types.
-// Returns true iff `n` is not a panic or recover node as
-// those are artificial.
-func hasInitialTypes(n node) bool {
- switch n.(type) {
- case panicArg, recoverReturn:
- return false
- default:
- return true
- }
-}
-
-func isInterface(t types.Type) bool {
- _, ok := t.Underlying().(*types.Interface)
- return ok
+ return types.IsInterface(t) || isFunction(t)
}
func isFunction(t types.Type) bool {
@@ -85,48 +69,76 @@ func isFunction(t types.Type) bool {
// pointer to interface and if yes, returns the interface type.
// Otherwise, returns nil.
func interfaceUnderPtr(t types.Type) types.Type {
- p, ok := t.Underlying().(*types.Pointer)
- if !ok {
- return nil
- }
+ seen := make(map[types.Type]bool)
+ var visit func(types.Type) types.Type
+ visit = func(t types.Type) types.Type {
+ if seen[t] {
+ return nil
+ }
+ seen[t] = true
- if isInterface(p.Elem()) {
- return p.Elem()
- }
+ p, ok := t.Underlying().(*types.Pointer)
+ if !ok {
+ return nil
+ }
+
+ if types.IsInterface(p.Elem()) {
+ return p.Elem()
+ }
- return interfaceUnderPtr(p.Elem())
+ return visit(p.Elem())
+ }
+ return visit(t)
}
// functionUnderPtr checks if type `t` is a potentially nested
// pointer to function type and if yes, returns the function type.
// Otherwise, returns nil.
func functionUnderPtr(t types.Type) types.Type {
- p, ok := t.Underlying().(*types.Pointer)
- if !ok {
- return nil
- }
+ seen := make(map[types.Type]bool)
+ var visit func(types.Type) types.Type
+ visit = func(t types.Type) types.Type {
+ if seen[t] {
+ return nil
+ }
+ seen[t] = true
- if isFunction(p.Elem()) {
- return p.Elem()
- }
+ p, ok := t.Underlying().(*types.Pointer)
+ if !ok {
+ return nil
+ }
+
+ if isFunction(p.Elem()) {
+ return p.Elem()
+ }
- return functionUnderPtr(p.Elem())
+ return visit(p.Elem())
+ }
+ return visit(t)
}
// sliceArrayElem returns the element type of type `t` that is
-// expected to be a (pointer to) array or slice, consistent with
+// expected to be a (pointer to) array, slice or string, consistent with
// the ssa.Index and ssa.IndexAddr instructions. Panics otherwise.
func sliceArrayElem(t types.Type) types.Type {
- u := t.Underlying()
-
- if p, ok := u.(*types.Pointer); ok {
- u = p.Elem().Underlying()
- }
-
- if a, ok := u.(*types.Array); ok {
- return a.Elem()
+ switch u := t.Underlying().(type) {
+ case *types.Pointer:
+ return u.Elem().Underlying().(*types.Array).Elem()
+ case *types.Array:
+ return u.Elem()
+ case *types.Slice:
+ return u.Elem()
+ case *types.Basic:
+ return types.Typ[types.Byte]
+ case *types.Interface: // type param.
+ terms, err := typeparams.InterfaceTermSet(u)
+ if err != nil || len(terms) == 0 {
+ panic(t)
+ }
+ return sliceArrayElem(terms[0].Type()) // Element types must match.
+ default:
+ panic(t)
}
- return u.(*types.Slice).Elem()
}
// siteCallees computes a set of callees for call site `c` given program `callgraph`.
diff --git a/go/callgraph/vta/vta.go b/go/callgraph/vta/vta.go
index 98fabe58c..583936003 100644
--- a/go/callgraph/vta/vta.go
+++ b/go/callgraph/vta/vta.go
@@ -3,7 +3,7 @@
// license that can be found in the LICENSE file.
// Package vta computes the call graph of a Go program using the Variable
-// Type Analysis (VTA) algorithm originally described in ``Practical Virtual
+// Type Analysis (VTA) algorithm originally described in “Practical Virtual
// Method Call Resolution for Java," Vijay Sundaresan, Laurie Hendren,
// Chrislain Razafimahefa, Raja Vallée-Rai, Patrick Lam, Etienne Gagnon, and
// Charles Godin.
@@ -18,22 +18,23 @@
//
// A type propagation is a directed, labeled graph. A node can represent
// one of the following:
-// - A field of a struct type.
-// - A local (SSA) variable of a method/function.
-// - All pointers to a non-interface type.
-// - The return value of a method.
-// - All elements in an array.
-// - All elements in a slice.
-// - All elements in a map.
-// - All elements in a channel.
-// - A global variable.
+// - A field of a struct type.
+// - A local (SSA) variable of a method/function.
+// - All pointers to a non-interface type.
+// - The return value of a method.
+// - All elements in an array.
+// - All elements in a slice.
+// - All elements in a map.
+// - All elements in a channel.
+// - A global variable.
+//
// In addition, the implementation used in this package introduces
// a few Go specific kinds of nodes:
-// - (De)references of nested pointers to interfaces are modeled
-// as a unique nestedPtrInterface node in the type propagation graph.
-// - Each function literal is represented as a function node whose
-// internal value is the (SSA) representation of the function. This
-// is done to precisely infer flow of higher-order functions.
+// - (De)references of nested pointers to interfaces are modeled
+// as a unique nestedPtrInterface node in the type propagation graph.
+// - Each function literal is represented as a function node whose
+// internal value is the (SSA) representation of the function. This
+// is done to precisely infer flow of higher-order functions.
//
// Edges in the graph represent flow of types (and function literals) through
// the program. That is, the model 1) typing constraints that are induced by
@@ -53,6 +54,8 @@
// reaching the node representing the call site to create a set of callees.
package vta
+// TODO(zpavlinovic): update VTA for how it handles generic function bodies and instantiation wrappers.
+
import (
"go/types"
diff --git a/go/callgraph/vta/vta_go117_test.go b/go/callgraph/vta/vta_go117_test.go
index 9ce6a8864..04f6980e5 100644
--- a/go/callgraph/vta/vta_go117_test.go
+++ b/go/callgraph/vta/vta_go117_test.go
@@ -11,12 +11,13 @@ import (
"testing"
"golang.org/x/tools/go/callgraph/cha"
+ "golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
)
func TestVTACallGraphGo117(t *testing.T) {
file := "testdata/src/go117.go"
- prog, want, err := testProg(file)
+ prog, want, err := testProg(file, ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load test file '%s': %s", file, err)
}
diff --git a/go/callgraph/vta/vta_test.go b/go/callgraph/vta/vta_test.go
index 33ceaf909..549c4af45 100644
--- a/go/callgraph/vta/vta_test.go
+++ b/go/callgraph/vta/vta_test.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/tools/go/callgraph/cha"
"golang.org/x/tools/go/ssa"
"golang.org/x/tools/go/ssa/ssautil"
+ "golang.org/x/tools/internal/typeparams"
)
func TestVTACallGraph(t *testing.T) {
@@ -24,9 +25,11 @@ func TestVTACallGraph(t *testing.T) {
"testdata/src/callgraph_collections.go",
"testdata/src/callgraph_fields.go",
"testdata/src/callgraph_field_funcs.go",
+ "testdata/src/callgraph_recursive_types.go",
+ "testdata/src/callgraph_issue_57756.go",
} {
t.Run(file, func(t *testing.T) {
- prog, want, err := testProg(file)
+ prog, want, err := testProg(file, ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load test file '%s': %s", file, err)
}
@@ -46,7 +49,7 @@ func TestVTACallGraph(t *testing.T) {
// enabled by having an arbitrary function set as input to CallGraph
// instead of the whole program (i.e., ssautil.AllFunctions(prog)).
func TestVTAProgVsFuncSet(t *testing.T) {
- prog, want, err := testProg("testdata/src/callgraph_nested_ptr.go")
+ prog, want, err := testProg("testdata/src/callgraph_nested_ptr.go", ssa.BuilderMode(0))
if err != nil {
t.Fatalf("couldn't load test `testdata/src/callgraph_nested_ptr.go`: %s", err)
}
@@ -111,3 +114,24 @@ func TestVTAPanicMissingDefinitions(t *testing.T) {
}
}
}
+
+func TestVTACallGraphGenerics(t *testing.T) {
+ if !typeparams.Enabled {
+ t.Skip("TestVTACallGraphGenerics requires type parameters")
+ }
+
+ // TODO(zpavlinovic): add more tests
+ file := "testdata/src/callgraph_generics.go"
+ prog, want, err := testProg(file, ssa.InstantiateGenerics)
+ if err != nil {
+ t.Fatalf("couldn't load test file '%s': %s", file, err)
+ }
+ if len(want) == 0 {
+ t.Fatalf("couldn't find want in `%s`", file)
+ }
+
+ g := CallGraph(ssautil.AllFunctions(prog), cha.CallGraph(prog))
+ if got := callGraphStr(g); !subGraph(want, got) {
+ t.Errorf("computed callgraph %v should contain %v", got, want)
+ }
+}