diff options
author | Sid Nayyar <sidnayyar@google.com> | 2024-04-16 12:33:15 +0100 |
---|---|---|
committer | Giuliano Procida <gprocida@google.com> | 2024-04-25 22:06:08 +0100 |
commit | cee934c475d281bee29eb13541311531dfb64103 (patch) | |
tree | 3c2ae2dfb5f93bcb2723a33f2082eeea02bf6c9d | |
parent | 5ef4e13addf4037dc4f83ebcd76e926fb59884d7 (diff) | |
download | stg-cee934c475d281bee29eb13541311531dfb64103.tar.gz |
rust: add `Variant` node
These nodes will be used to represent 'fieldful' or tagged Rust enums.
STG Rust ABI representation is unstable and is not yet subject to
format versioning.
PiperOrigin-RevId: 625282784
Change-Id: Ife43024161f7d47cc98584c1f9d7afd08f08a345
-rw-r--r-- | comparison.cc | 17 | ||||
-rw-r--r-- | comparison.h | 1 | ||||
-rw-r--r-- | equality.h | 7 | ||||
-rw-r--r-- | fidelity.cc | 6 | ||||
-rw-r--r-- | fingerprint.cc | 6 | ||||
-rw-r--r-- | graph.h | 31 | ||||
-rw-r--r-- | naming.cc | 6 | ||||
-rw-r--r-- | naming.h | 1 | ||||
-rw-r--r-- | proto_reader.cc | 7 | ||||
-rw-r--r-- | proto_writer.cc | 13 | ||||
-rw-r--r-- | stable_hash.cc | 7 | ||||
-rw-r--r-- | stable_hash.h | 1 | ||||
-rw-r--r-- | stg.proto | 17 | ||||
-rw-r--r-- | substitution.h | 5 | ||||
-rw-r--r-- | type_normalisation.cc | 5 | ||||
-rw-r--r-- | type_resolution.cc | 10 | ||||
-rw-r--r-- | unification.cc | 8 |
17 files changed, 143 insertions, 5 deletions
diff --git a/comparison.cc b/comparison.cc index d5067b6..8e0dacb 100644 --- a/comparison.cc +++ b/comparison.cc @@ -560,6 +560,23 @@ Result Compare::operator()(const Enumeration& x1, const Enumeration& x2) { return result; } +Result Compare::operator()(const Variant& x1, const Variant& x2) { + Result result; + // Compare two identically named variants recursively, holding diffs. + // Everything else treated as distinct. No recursion. + if (x1.name != x2.name) { + return result.MarkIncomparable(); + } + result.diff_.holds_changes = true; // Anonymous variants are not allowed. + + result.MaybeAddNodeDiff("bytesize", x1.bytesize, x2.bytesize); + const auto type_diff = + (*this)(x1.discriminant_type_id, x2.discriminant_type_id); + result.MaybeAddEdgeDiff("discriminant", type_diff); + CompareNodes(result, *this, x1.members, x2.members); + return result; +} + Result Compare::operator()(const Function& x1, const Function& x2) { Result result; const auto type_diff = (*this)(x1.return_type_id, x2.return_type_id); diff --git a/comparison.h b/comparison.h index 7d49803..034c16a 100644 --- a/comparison.h +++ b/comparison.h @@ -288,6 +288,7 @@ struct Compare { Result operator()(const VariantMember&, const VariantMember&); Result operator()(const StructUnion&, const StructUnion&); Result operator()(const Enumeration&, const Enumeration&); + Result operator()(const Variant&, const Variant&); Result operator()(const Function&, const Function&); Result operator()(const ElfSymbol&, const ElfSymbol&); Result operator()(const Interface&, const Interface&); @@ -194,6 +194,13 @@ struct Equals { return result; } + bool operator()(const Variant& x1, const Variant& x2) { + return x1.name == x2.name + && x1.bytesize == x2.bytesize + && (*this)(x1.discriminant_type_id, x2.discriminant_type_id) + && (*this)(x1.members, x2.members); + } + bool operator()(const Function& x1, const Function& x2) { return (*this)(x1.parameters, x2.parameters) && (*this)(x1.return_type_id, x2.return_type_id); diff --git a/fidelity.cc b/fidelity.cc index 250c9d9..66ab9d2 100644 --- a/fidelity.cc +++ b/fidelity.cc @@ -56,6 +56,7 @@ struct Fidelity { void operator()(const VariantMember&, Id); void operator()(const StructUnion&, Id); void operator()(const Enumeration&, Id); + void operator()(const Variant&, Id); void operator()(const Function&, Id); void operator()(const ElfSymbol&, Id); void operator()(const Interface&, Id); @@ -151,6 +152,11 @@ void Fidelity::operator()(const Enumeration& x, Id id) { } } +void Fidelity::operator()(const Variant& x, Id id) { + types.emplace(describe(id).ToString(), TypeFidelity::FULLY_DEFINED); + (*this)(x.members); +} + void Fidelity::operator()(const Function& x, Id) { (*this)(x.return_type_id); (*this)(x.parameters); diff --git a/fingerprint.cc b/fingerprint.cc index 6f1c265..05a4cdc 100644 --- a/fingerprint.cc +++ b/fingerprint.cc @@ -138,6 +138,12 @@ struct Hasher { return h; } + HashValue operator()(const Variant& x) { + auto h = hash('v', x.name, x.bytesize, (*this)(x.discriminant_type_id)); + ToDo(x.members); + return h; + } + HashValue operator()(const Function& x) { auto h = hash('F', (*this)(x.return_type_id)); for (const auto& parameter : x.parameters) { @@ -243,6 +243,20 @@ struct Enumeration { std::optional<Definition> definition; }; +struct Variant { + Variant(const std::string& name, uint64_t bytesize, Id discriminant_type_id, + const std::vector<Id>& members) + : name(name), + bytesize(bytesize), + discriminant_type_id(discriminant_type_id), + members(members) {} + + std::string name; + uint64_t bytesize; + Id discriminant_type_id; + std::vector<Id> members; +}; + struct Function { Function(Id return_type_id, const std::vector<Id>& parameters) : return_type_id(return_type_id), parameters(parameters) {} @@ -382,6 +396,9 @@ class Graph { } else if constexpr (std::is_same_v<Node, Enumeration>) { reference = {Which::ENUMERATION, enumeration_.size()}; enumeration_.emplace_back(std::forward<Args>(args)...); + } else if constexpr (std::is_same_v<Node, Variant>) { + reference = {Which::VARIANT, variant_.size()}; + variant_.emplace_back(std::forward<Args>(args)...); } else if constexpr (std::is_same_v<Node, Function>) { reference = {Which::FUNCTION, function_.size()}; function_.emplace_back(std::forward<Args>(args)...); @@ -456,6 +473,7 @@ class Graph { VARIANT_MEMBER, STRUCT_UNION, ENUMERATION, + VARIANT, FUNCTION, ELF_SYMBOL, INTERFACE, @@ -476,6 +494,7 @@ class Graph { std::vector<VariantMember> variant_member_; std::vector<StructUnion> struct_union_; std::vector<Enumeration> enumeration_; + std::vector<Variant> variant_; std::vector<Function> function_; std::vector<ElfSymbol> elf_symbol_; std::vector<Interface> interface_; @@ -513,6 +532,8 @@ Result Graph::Apply(FunctionObject& function, Id id, Args&&... args) const { return function(struct_union_[ix], std::forward<Args>(args)...); case Which::ENUMERATION: return function(enumeration_[ix], std::forward<Args>(args)...); + case Which::VARIANT: + return function(variant_[ix], std::forward<Args>(args)...); case Which::FUNCTION: return function(function_[ix], std::forward<Args>(args)...); case Which::ELF_SYMBOL: @@ -572,6 +593,9 @@ Result Graph::Apply2( case Which::ENUMERATION: return function(enumeration_[ix1], enumeration_[ix2], std::forward<Args>(args)...); + case Which::VARIANT: + return function(variant_[ix1], variant_[ix2], + std::forward<Args>(args)...); case Which::FUNCTION: return function(function_[ix1], function_[ix2], std::forward<Args>(args)...); @@ -628,6 +652,13 @@ struct InterfaceKey { return "enum " + x.name; } + std::string operator()(const stg::Variant& x) const { + if (x.name.empty()) { + Die() << "anonymous variant interface type"; + } + return "variant " + x.name; + } + std::string operator()(const stg::ElfSymbol& x) const { return VersionedSymbolName(x); } @@ -225,6 +225,12 @@ Name Describe::operator()(const Enumeration& x) { return Name{os.str()}; } +Name Describe::operator()(const Variant& x) { + std::ostringstream os; + os << "variant " << x.name; + return Name{os.str()}; +} + Name Describe::operator()(const Function& x) { std::ostringstream os; os << '('; @@ -71,6 +71,7 @@ struct Describe { Name operator()(const VariantMember&); Name operator()(const StructUnion&); Name operator()(const Enumeration&); + Name operator()(const Variant&); Name operator()(const Function&); Name operator()(const ElfSymbol&); Name operator()(const Interface&); diff --git a/proto_reader.cc b/proto_reader.cc index 88ec127..f2683ef 100644 --- a/proto_reader.cc +++ b/proto_reader.cc @@ -65,6 +65,7 @@ struct Transformer { void AddNode(const BaseClass&); void AddNode(const Method&); void AddNode(const Member&); + void AddNode(const Variant&); void AddNode(const StructUnion&); void AddNode(const Enumeration&); void AddNode(const VariantMember&); @@ -115,6 +116,7 @@ Id Transformer::Transform(const proto::STG& x) { AddNodes(x.variant_member()); AddNodes(x.struct_union()); AddNodes(x.enumeration()); + AddNodes(x.variant()); AddNodes(x.function()); AddNodes(x.elf_symbol()); AddNodes(x.symbols()); @@ -224,6 +226,11 @@ void Transformer::AddNode(const Enumeration& x) { } } +void Transformer::AddNode(const Variant& x) { + AddNode<stg::Variant>(GetId(x.id()), x.name(), x.bytesize(), + GetId(x.discriminant_type_id()), x.member_id()); +} + void Transformer::AddNode(const Function& x) { AddNode<stg::Function>(GetId(x.id()), GetId(x.return_type_id()), x.parameter_id()); diff --git a/proto_writer.cc b/proto_writer.cc index a4ee972..ccbd683 100644 --- a/proto_writer.cc +++ b/proto_writer.cc @@ -77,6 +77,7 @@ struct Transform { void operator()(const stg::VariantMember&, uint32_t); void operator()(const stg::StructUnion&, uint32_t); void operator()(const stg::Enumeration&, uint32_t); + void operator()(const stg::Variant&, uint32_t); void operator()(const stg::Function&, uint32_t); void operator()(const stg::ElfSymbol&, uint32_t); void operator()(const stg::Interface&, uint32_t); @@ -263,6 +264,18 @@ void Transform<MapId>::operator()(const stg::Enumeration& x, uint32_t id) { } template <typename MapId> +void Transform<MapId>::operator()(const stg::Variant& x, uint32_t id) { + auto& variant = *stg.add_variant(); + variant.set_id(id); + variant.set_name(x.name); + variant.set_bytesize(x.bytesize); + variant.set_discriminant_type_id((*this)(x.discriminant_type_id)); + for (const auto id : x.members) { + variant.add_member_id((*this)(id)); + } +} + +template <typename MapId> void Transform<MapId>::operator()(const stg::Function& x, uint32_t id) { auto& function = *stg.add_function(); function.set_id(id); diff --git a/stable_hash.cc b/stable_hash.cc index 725cf61..a8f9366 100644 --- a/stable_hash.cc +++ b/stable_hash.cc @@ -159,6 +159,13 @@ HashValue StableHash::operator()(const Enumeration& x) { hash, DecayHashCombineInReverse<8>(x.definition->enumerators, hash_enum)); } +HashValue StableHash::operator()(const Variant& x) { + HashValue hash = hash_('V', x.name, x.bytesize); + hash = DecayHashCombine<8>(hash, (*this)(x.discriminant_type_id)); + return DecayHashCombine<2>(hash, + DecayHashCombineInReverse<8>(x.members, *this)); +} + HashValue StableHash::operator()(const Function& x) { return DecayHashCombine<2>(hash_('f', (*this)(x.return_type_id)), DecayHashCombineInReverse<4>(x.parameters, *this)); diff --git a/stable_hash.h b/stable_hash.h index 75cbb92..b0b9265 100644 --- a/stable_hash.h +++ b/stable_hash.h @@ -48,6 +48,7 @@ class StableHash { HashValue operator()(const VariantMember&); HashValue operator()(const StructUnion&); HashValue operator()(const Enumeration&); + HashValue operator()(const Variant&); HashValue operator()(const Function&); HashValue operator()(const ElfSymbol&); HashValue operator()(const Interface&); @@ -201,6 +201,14 @@ message Enumeration { optional Definition definition = 3; } +message Variant { + fixed32 id = 1; + string name = 2; + uint64 bytesize = 3; + fixed32 discriminant_type_id = 4; + repeated fixed32 member_id = 5; +} + message Function { fixed32 id = 1; fixed32 return_type_id = 2; @@ -278,8 +286,9 @@ message STG { repeated VariantMember variant_member = 15; repeated StructUnion struct_union = 16; repeated Enumeration enumeration = 17; - repeated Function function = 18; - repeated ElfSymbol elf_symbol = 19; - repeated Symbols symbols = 20; - repeated Interface interface = 21; + repeated Variant variant = 18; + repeated Function function = 19; + repeated ElfSymbol elf_symbol = 20; + repeated Symbols symbols = 21; + repeated Interface interface = 22; } diff --git a/substitution.h b/substitution.h index de0dad0..863115f 100644 --- a/substitution.h +++ b/substitution.h @@ -119,6 +119,11 @@ struct Substitute { } } + void operator()(Variant& x) { + Update(x.discriminant_type_id); + Update(x.members); + } + void operator()(Function& x) { Update(x.parameters); Update(x.return_type_id); diff --git a/type_normalisation.cc b/type_normalisation.cc index 70b0493..aeac699 100644 --- a/type_normalisation.cc +++ b/type_normalisation.cc @@ -145,6 +145,11 @@ struct FindQualifiedTypesAndFunctions { } } + void operator()(const Variant& x, Id) { + (*this)(x.discriminant_type_id); + (*this)(x.members); + } + void operator()(const Function& x, Id node_id) { functions.emplace(node_id); for (auto& id : x.parameters) { diff --git a/type_resolution.cc b/type_resolution.cc index d28db66..c9fe51d 100644 --- a/type_resolution.cc +++ b/type_resolution.cc @@ -45,7 +45,7 @@ struct NamedTypes { seen.Reserve(graph.Limit()); } - enum class Tag { STRUCT, UNION, ENUM, TYPEDEF }; + enum class Tag { STRUCT, UNION, ENUM, TYPEDEF, VARIANT }; using Type = std::pair<Tag, std::string>; struct Info { std::vector<Id> definitions; @@ -160,6 +160,14 @@ struct NamedTypes { } } + void operator()(const Variant& x, Id id) { + const auto& name = x.name; + auto& info = GetInfo(Tag::VARIANT, name); + info.definitions.push_back(id); + ++definitions; + (*this)(x.members); + } + void operator()(const Function& x, Id) { (*this)(x.return_type_id); (*this)(x.parameters); diff --git a/unification.cc b/unification.cc index 77891ea..0d77012 100644 --- a/unification.cc +++ b/unification.cc @@ -205,6 +205,14 @@ struct Unifier { return result ? definition2.has_value() ? Right : Left : Neither; } + Winner operator()(const Variant& x1, const Variant& x2) { + return x1.name == x2.name + && x1.bytesize == x2.bytesize + && (*this)(x1.discriminant_type_id, x2.discriminant_type_id) + && (*this)(x1.members, x2.members) + ? Right : Neither; + } + Winner operator()(const Function& x1, const Function& x2) { return (*this)(x1.parameters, x2.parameters) && (*this)(x1.return_type_id, x2.return_type_id) |