aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCole Faust <colefaust@google.com>2024-01-25 14:45:22 -0800
committerCole Faust <colefaust@google.com>2024-01-30 15:18:24 -0800
commit7add62142d6d75bdef9eb788ecb59bae9bb04421 (patch)
tree663d34c4720a76711376162a805205fc5bd76dbe
parent1b3fe6bd2c1d534558ae169d6d9c84bd0abdd9b2 (diff)
downloadblueprint-7add62142d6d75bdef9eb788ecb59bae9bb04421.tar.gz
Enforce that providers are not changed
When setProvider() is called, hash the provider and store the hash in the module. Then after the build is done, hash all the providers again and compare the hashes. It's an error if they don't match. Also add a flag to control it in case this check gets slow as we convert more things to providers. However right now it's fast (unnoticable in terms of whole seconds) so just have the flag always enabled. Bug: 322069292 Test: m nothing Change-Id: Ie4e806a6a9f20542ffcc7439eef376d3fb6a98ca
-rw-r--r--Android.bp4
-rw-r--r--bootstrap/command.go22
-rw-r--r--context.go75
-rw-r--r--proptools/hash_provider.go205
-rw-r--r--proptools/hash_provider_test.go112
-rw-r--r--provider.go13
6 files changed, 415 insertions, 16 deletions
diff --git a/Android.bp b/Android.bp
index 246207a..d0a16ad 100644
--- a/Android.bp
+++ b/Android.bp
@@ -52,7 +52,7 @@ bootstrap_go_package {
"provider.go",
"scope.go",
"singleton_ctx.go",
- "source_file_provider.go"
+ "source_file_provider.go",
],
testSrcs: [
"context_test.go",
@@ -120,6 +120,7 @@ bootstrap_go_package {
"proptools/escape.go",
"proptools/extend.go",
"proptools/filter.go",
+ "proptools/hash_provider.go",
"proptools/proptools.go",
"proptools/tag.go",
"proptools/typeequal.go",
@@ -130,6 +131,7 @@ bootstrap_go_package {
"proptools/escape_test.go",
"proptools/extend_test.go",
"proptools/filter_test.go",
+ "proptools/hash_provider_test.go",
"proptools/tag_test.go",
"proptools/typeequal_test.go",
"proptools/unpack_test.go",
diff --git a/bootstrap/command.go b/bootstrap/command.go
index bc1d32d..d7dcc27 100644
--- a/bootstrap/command.go
+++ b/bootstrap/command.go
@@ -25,6 +25,7 @@ import (
"runtime/debug"
"runtime/pprof"
"runtime/trace"
+ "strings"
"github.com/google/blueprint"
)
@@ -134,6 +135,15 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf
return ninjaDeps, nil
}
+ providersValidationChan := make(chan []error, 1)
+ if ctx.GetVerifyProvidersAreUnchanged() {
+ go func() {
+ providersValidationChan <- ctx.VerifyProvidersWereUnchanged()
+ }()
+ } else {
+ providersValidationChan <- nil
+ }
+
const outFilePermissions = 0666
var out blueprint.StringWriterWriter
var f *os.File
@@ -172,6 +182,18 @@ func RunBlueprint(args Args, stopBefore StopBefore, ctx *blueprint.Context, conf
}
}
+ providerValidationErrors := <-providersValidationChan
+ if providerValidationErrors != nil {
+ var sb strings.Builder
+ for i, err := range providerValidationErrors {
+ if i != 0 {
+ sb.WriteString("\n")
+ }
+ sb.WriteString(err.Error())
+ }
+ return nil, errors.New(sb.String())
+ }
+
if args.Memprofile != "" {
f, err := os.Create(joinPath(ctx.SrcDir(), args.Memprofile))
if err != nil {
diff --git a/context.go b/context.go
index 28f0cc5..4130700 100644
--- a/context.go
+++ b/context.go
@@ -101,6 +101,8 @@ type Context struct {
// set by SetAllowMissingDependencies
allowMissingDependencies bool
+ verifyProvidersAreUnchanged bool
+
// set during PrepareBuildActions
nameTracker *nameTracker
liveGlobals *liveTracker
@@ -351,7 +353,8 @@ type moduleInfo struct {
// set during PrepareBuildActions
actionDefs localBuildActions
- providers []interface{}
+ providers []interface{}
+ providerInitialValueHashes []uint64
startedMutator *mutatorInfo
finishedMutator *mutatorInfo
@@ -463,20 +466,21 @@ type mutatorInfo struct {
func newContext() *Context {
eventHandler := metrics.EventHandler{}
return &Context{
- Context: context.Background(),
- EventHandler: &eventHandler,
- moduleFactories: make(map[string]ModuleFactory),
- nameInterface: NewSimpleNameInterface(),
- moduleInfo: make(map[Module]*moduleInfo),
- globs: make(map[globKey]pathtools.GlobResult),
- fs: pathtools.OsFs,
- finishedMutators: make(map[*mutatorInfo]bool),
- includeTags: &IncludeTags{},
- sourceRootDirs: &SourceRootDirs{},
- outDir: nil,
- requiredNinjaMajor: 1,
- requiredNinjaMinor: 7,
- requiredNinjaMicro: 0,
+ Context: context.Background(),
+ EventHandler: &eventHandler,
+ moduleFactories: make(map[string]ModuleFactory),
+ nameInterface: NewSimpleNameInterface(),
+ moduleInfo: make(map[Module]*moduleInfo),
+ globs: make(map[globKey]pathtools.GlobResult),
+ fs: pathtools.OsFs,
+ finishedMutators: make(map[*mutatorInfo]bool),
+ includeTags: &IncludeTags{},
+ sourceRootDirs: &SourceRootDirs{},
+ outDir: nil,
+ requiredNinjaMajor: 1,
+ requiredNinjaMinor: 7,
+ requiredNinjaMicro: 0,
+ verifyProvidersAreUnchanged: true,
}
}
@@ -952,6 +956,18 @@ func (c *Context) SetAllowMissingDependencies(allowMissingDependencies bool) {
c.allowMissingDependencies = allowMissingDependencies
}
+// SetVerifyProvidersAreUnchanged makes blueprint hash all providers immediately
+// after SetProvider() is called, and then hash them again after the build finished.
+// If the hashes change, it's an error. Providers are supposed to be immutable, but
+// we don't have any more direct way to enforce that in go.
+func (c *Context) SetVerifyProvidersAreUnchanged(verifyProvidersAreUnchanged bool) {
+ c.verifyProvidersAreUnchanged = verifyProvidersAreUnchanged
+}
+
+func (c *Context) GetVerifyProvidersAreUnchanged() bool {
+ return c.verifyProvidersAreUnchanged
+}
+
func (c *Context) SetModuleListFile(listFile string) {
c.moduleListFile = listFile
}
@@ -1730,6 +1746,7 @@ func (c *Context) createVariations(origModule *moduleInfo, mutatorName string,
newModule.variant = newVariant(origModule, mutatorName, variationName, local)
newModule.properties = newProperties
newModule.providers = append([]interface{}(nil), origModule.providers...)
+ newModule.providerInitialValueHashes = append([]uint64(nil), origModule.providerInitialValueHashes...)
newModules = append(newModules, newModule)
@@ -4207,6 +4224,34 @@ func (c *Context) SingletonName(singleton Singleton) string {
return ""
}
+// Checks that the hashes of all the providers match the hashes from when they were first set.
+// Does nothing on success, returns a list of errors otherwise. It's recommended to run this
+// in a goroutine.
+func (c *Context) VerifyProvidersWereUnchanged() []error {
+ if !c.buildActionsReady {
+ return []error{ErrBuildActionsNotReady}
+ }
+ var errors []error
+ for _, m := range c.modulesSorted {
+ for i, provider := range m.providers {
+ if provider != nil {
+ hash, err := proptools.HashProvider(provider)
+ if err != nil {
+ errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set, and no longer hashable afterwards: %s", providerRegistry[i].typ, m.Name(), err.Error()))
+ continue
+ }
+ if provider != nil && m.providerInitialValueHashes[i] != hash {
+ errors = append(errors, fmt.Errorf("provider %q on module %q was modified after being set", providerRegistry[i].typ, m.Name()))
+ }
+ } else if m.providerInitialValueHashes[i] != 0 {
+ // This should be unreachable, because in setProvider we check if the provider has already been set.
+ errors = append(errors, fmt.Errorf("provider %q on module %q was unset somehow, this is an internal error", providerRegistry[i].typ, m.Name()))
+ }
+ }
+ }
+ return errors
+}
+
// WriteBuildFile writes the Ninja manifest text for the generated build
// actions to w. If this is called before PrepareBuildActions successfully
// completes then ErrBuildActionsNotReady is returned.
diff --git a/proptools/hash_provider.go b/proptools/hash_provider.go
new file mode 100644
index 0000000..b52a10e
--- /dev/null
+++ b/proptools/hash_provider.go
@@ -0,0 +1,205 @@
+// Copyright 2023 Google Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package proptools
+
+import (
+ "cmp"
+ "encoding/binary"
+ "fmt"
+ "hash/maphash"
+ "io"
+ "math"
+ "reflect"
+ "sort"
+)
+
+var seed maphash.Seed = maphash.MakeSeed()
+
+// byte to insert between elements of lists, fields of structs/maps, etc in order
+// to try and make sure the hash is different when values are moved around between
+// elements. 36 is arbitrary, but it's the ascii code for a record separator
+var recordSeparator []byte = []byte{36}
+
+func HashProvider(provider interface{}) (uint64, error) {
+ hasher := maphash.Hash{}
+ hasher.SetSeed(seed)
+ ptrs := make(map[uintptr]bool)
+ v := reflect.ValueOf(provider)
+ var err error
+ if v.IsValid() {
+ err = hashProviderInternal(&hasher, v, ptrs)
+ }
+ return hasher.Sum64(), err
+}
+
+func hashProviderInternal(hasher io.Writer, v reflect.Value, ptrs map[uintptr]bool) error {
+ var int64Array [8]byte
+ int64Buf := int64Array[:]
+ binary.LittleEndian.PutUint64(int64Buf, uint64(v.Kind()))
+ hasher.Write(int64Buf)
+ v.IsValid()
+ switch v.Kind() {
+ case reflect.Struct:
+ binary.LittleEndian.PutUint64(int64Buf, uint64(v.NumField()))
+ hasher.Write(int64Buf)
+ for i := 0; i < v.NumField(); i++ {
+ hasher.Write(recordSeparator)
+ err := hashProviderInternal(hasher, v.Field(i), ptrs)
+ if err != nil {
+ return fmt.Errorf("in field %s: %s", v.Type().Field(i).Name, err.Error())
+ }
+ }
+ case reflect.Map:
+ binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
+ hasher.Write(int64Buf)
+ indexes := make([]int, v.Len())
+ keys := make([]reflect.Value, v.Len())
+ values := make([]reflect.Value, v.Len())
+ iter := v.MapRange()
+ for i := 0; iter.Next(); i++ {
+ indexes[i] = i
+ keys[i] = iter.Key()
+ values[i] = iter.Value()
+ }
+ sort.SliceStable(indexes, func(i, j int) bool {
+ return compare_values(keys[indexes[i]], keys[indexes[j]]) < 0
+ })
+ for i := 0; i < v.Len(); i++ {
+ hasher.Write(recordSeparator)
+ err := hashProviderInternal(hasher, keys[indexes[i]], ptrs)
+ if err != nil {
+ return fmt.Errorf("in map: %s", err.Error())
+ }
+ hasher.Write(recordSeparator)
+ err = hashProviderInternal(hasher, keys[indexes[i]], ptrs)
+ if err != nil {
+ return fmt.Errorf("in map: %s", err.Error())
+ }
+ }
+ case reflect.Slice, reflect.Array:
+ binary.LittleEndian.PutUint64(int64Buf, uint64(v.Len()))
+ hasher.Write(int64Buf)
+ for i := 0; i < v.Len(); i++ {
+ hasher.Write(recordSeparator)
+ err := hashProviderInternal(hasher, v.Index(i), ptrs)
+ if err != nil {
+ return fmt.Errorf("in %s at index %d: %s", v.Kind().String(), i, err.Error())
+ }
+ }
+ case reflect.Pointer:
+ if v.IsNil() {
+ int64Buf[0] = 0
+ hasher.Write(int64Buf[:1])
+ return nil
+ }
+ addr := v.Pointer()
+ binary.LittleEndian.PutUint64(int64Buf, uint64(addr))
+ hasher.Write(int64Buf)
+ if _, ok := ptrs[addr]; ok {
+ // We could make this an error if we want to disallow pointer cycles in the future
+ return nil
+ }
+ ptrs[addr] = true
+ err := hashProviderInternal(hasher, v.Elem(), ptrs)
+ if err != nil {
+ return fmt.Errorf("in pointer: %s", err.Error())
+ }
+ case reflect.Interface:
+ if v.IsNil() {
+ int64Buf[0] = 0
+ hasher.Write(int64Buf[:1])
+ } else {
+ // The only way get the pointer out of an interface to hash it or check for cycles
+ // would be InterfaceData(), but that's deprecated and seems like it has undefined behavior.
+ err := hashProviderInternal(hasher, v.Elem(), ptrs)
+ if err != nil {
+ return fmt.Errorf("in interface: %s", err.Error())
+ }
+ }
+ case reflect.String:
+ hasher.Write([]byte(v.String()))
+ case reflect.Bool:
+ if v.Bool() {
+ int64Buf[0] = 1
+ } else {
+ int64Buf[0] = 0
+ }
+ hasher.Write(int64Buf[:1])
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ binary.LittleEndian.PutUint64(int64Buf, v.Uint())
+ hasher.Write(int64Buf)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ binary.LittleEndian.PutUint64(int64Buf, uint64(v.Int()))
+ hasher.Write(int64Buf)
+ case reflect.Float32, reflect.Float64:
+ binary.LittleEndian.PutUint64(int64Buf, math.Float64bits(v.Float()))
+ hasher.Write(int64Buf)
+ default:
+ return fmt.Errorf("providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers, found: %s", v.Kind().String())
+ }
+ return nil
+}
+
+func compare_values(x, y reflect.Value) int {
+ if x.Type() != y.Type() {
+ panic("Expected equal types")
+ }
+
+ switch x.Kind() {
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
+ return cmp.Compare(x.Uint(), y.Uint())
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return cmp.Compare(x.Int(), y.Int())
+ case reflect.Float32, reflect.Float64:
+ return cmp.Compare(x.Float(), y.Float())
+ case reflect.String:
+ return cmp.Compare(x.String(), y.String())
+ case reflect.Bool:
+ if x.Bool() == y.Bool() {
+ return 0
+ } else if x.Bool() {
+ return 1
+ } else {
+ return -1
+ }
+ case reflect.Pointer:
+ return cmp.Compare(x.Pointer(), y.Pointer())
+ case reflect.Array:
+ for i := 0; i < x.Len(); i++ {
+ if result := compare_values(x.Index(i), y.Index(i)); result != 0 {
+ return result
+ }
+ }
+ return 0
+ case reflect.Struct:
+ for i := 0; i < x.NumField(); i++ {
+ if result := compare_values(x.Field(i), y.Field(i)); result != 0 {
+ return result
+ }
+ }
+ return 0
+ case reflect.Interface:
+ if x.IsNil() && y.IsNil() {
+ return 0
+ } else if x.IsNil() {
+ return 1
+ } else if y.IsNil() {
+ return -1
+ }
+ return compare_values(x.Elem(), y.Elem())
+ default:
+ panic(fmt.Sprintf("Could not compare types %s and %s", x.Type().String(), y.Type().String()))
+ }
+}
diff --git a/proptools/hash_provider_test.go b/proptools/hash_provider_test.go
new file mode 100644
index 0000000..1c97aec
--- /dev/null
+++ b/proptools/hash_provider_test.go
@@ -0,0 +1,112 @@
+package proptools
+
+import (
+ "strings"
+ "testing"
+)
+
+func mustHash(t *testing.T, provider interface{}) uint64 {
+ t.Helper()
+ result, err := HashProvider(provider)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return result
+}
+
+func TestHashingMapGetsSameResults(t *testing.T) {
+ provider := map[string]string{"foo": "bar", "baz": "qux"}
+ first := mustHash(t, provider)
+ second := mustHash(t, provider)
+ third := mustHash(t, provider)
+ fourth := mustHash(t, provider)
+ if first != second || second != third || third != fourth {
+ t.Fatal("Did not get the same result every time for a map")
+ }
+}
+
+func TestHashingNonSerializableTypesFails(t *testing.T) {
+ testCases := []struct {
+ name string
+ provider interface{}
+ }{
+ {
+ name: "function pointer",
+ provider: []func(){nil},
+ },
+ {
+ name: "channel",
+ provider: []chan int{make(chan int)},
+ },
+ {
+ name: "list with non-serializable type",
+ provider: []interface{}{"foo", make(chan int)},
+ },
+ }
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ _, err := HashProvider(testCase)
+ if err == nil {
+ t.Fatal("Expected hashing error but didn't get one")
+ }
+ expected := "providers may only contain primitives, strings, arrays, slices, structs, maps, and pointers"
+ if !strings.Contains(err.Error(), expected) {
+ t.Fatalf("Expected %q, got %q", expected, err.Error())
+ }
+ })
+ }
+}
+
+func TestHashSuccessful(t *testing.T) {
+ testCases := []struct {
+ name string
+ provider interface{}
+ }{
+ {
+ name: "int",
+ provider: 5,
+ },
+ {
+ name: "string",
+ provider: "foo",
+ },
+ {
+ name: "*string",
+ provider: StringPtr("foo"),
+ },
+ {
+ name: "array",
+ provider: [3]string{"foo", "bar", "baz"},
+ },
+ {
+ name: "slice",
+ provider: []string{"foo", "bar", "baz"},
+ },
+ {
+ name: "struct",
+ provider: struct {
+ foo string
+ bar int
+ }{
+ foo: "foo",
+ bar: 3,
+ },
+ },
+ {
+ name: "map",
+ provider: map[string]int{
+ "foo": 3,
+ "bar": 4,
+ },
+ },
+ {
+ name: "list of interfaces with different types",
+ provider: []interface{}{"foo", 3, []string{"bar", "baz"}},
+ },
+ }
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ mustHash(t, testCase.provider)
+ })
+ }
+}
diff --git a/provider.go b/provider.go
index 48527b1..297861e 100644
--- a/provider.go
+++ b/provider.go
@@ -16,6 +16,8 @@ package blueprint
import (
"fmt"
+
+ "github.com/google/blueprint/proptools"
)
// This file implements Providers, modelled after Bazel
@@ -151,6 +153,17 @@ func (c *Context) setProvider(m *moduleInfo, provider *providerKey, value any) {
}
m.providers[provider.id] = value
+
+ if c.verifyProvidersAreUnchanged {
+ if m.providerInitialValueHashes == nil {
+ m.providerInitialValueHashes = make([]uint64, len(providerRegistry))
+ }
+ hash, err := proptools.HashProvider(value)
+ if err != nil {
+ panic(fmt.Sprintf("Can't set value of provider %s: %s", provider.typ, err.Error()))
+ }
+ m.providerInitialValueHashes[provider.id] = hash
+ }
}
// provider returns the value, if any, for a given provider for a module. Verifies that it is