// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// -*- mode: C++ -*-
//
// Copyright 2022 Google LLC
//
// Licensed under the Apache License v2.0 with LLVM Exceptions (the
// "License"); you may not use this file except in compliance with the
// License.  You may obtain a copy of the License at
//
//     https://llvm.org/LICENSE.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Author: Giuliano Procida

#ifndef STG_EQUALITY_CACHE_H_
#define STG_EQUALITY_CACHE_H_

// NOTE(review): the standard header names below were reconstructed from what
// the code in this file uses (size_t, std::optional, std::unordered_map,
// std::unordered_set, std::swap, std::vector) - confirm against upstream.
#include <cstddef>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "graph.h"
#include "hashing.h"
#include "runtime.h"

namespace stg {

// Equality cache - for use with the Equals function object
//
// This supports many features, some of probably limited long-term utility.
//
// It caches equalities (symmetrically) using union-find with path halving and
// union by rank.
//
// It caches inequalities (symmetrically); the inequalities are updated as part
// of the union operation.
//
// Node hashes such as those generated by the Fingerprint function object may be
// supplied to avoid equality testing when hashes differ.
struct EqualityCache { EqualityCache(Runtime& runtime, const std::unordered_map& hashes) : hashes(hashes), query_count(runtime, "cache.query_count"), query_equal_ids(runtime, "cache.query_equal_ids"), query_unequal_hashes(runtime, "cache.query_unequal_hashes"), query_equal_representatives(runtime, "cache.query_equal_representatives"), query_inequality_found(runtime, "cache.query_inequality_found"), query_not_found(runtime, "cache.query_not_found"), find_halved(runtime, "cache.find_halved"), union_known(runtime, "cache.union_known"), union_rank_swap(runtime, "cache.union_rank_swap"), union_rank_increase(runtime, "cache.union_rank_increase"), union_rank_zero(runtime, "cache.union_rank_zero"), union_unknown(runtime, "cache.union_unknown"), disunion_known_hash(runtime, "cache.disunion_known_hash"), disunion_known_inequality(runtime, "cache.disunion_known_inequality"), disunion_unknown(runtime, "cache.disunion_unknown") {} std::optional Query(const Pair& comparison) { ++query_count; const auto& [id1, id2] = comparison; if (id1 == id2) { ++query_equal_ids; return std::make_optional(true); } if (DistinctHashes(id1, id2)) { ++query_unequal_hashes; return std::make_optional(false); } const Id fid1 = Find(id1); const Id fid2 = Find(id2); if (fid1 == fid2) { ++query_equal_representatives; return std::make_optional(true); } auto not_it = inequalities.find(fid1); if (not_it != inequalities.end()) { auto not_it2 = not_it->second.find(fid2); if (not_it2 != not_it->second.end()) { ++query_inequality_found; return std::make_optional(false); } } ++query_not_found; return std::nullopt; } void AllSame(const std::vector& comparisons) { for (const auto& [id1, id2] : comparisons) { Union(id1, id2); } } void AllDifferent(const std::vector& comparisons) { for (const auto& [id1, id2] : comparisons) { Disunion(id1, id2); } } bool DistinctHashes(Id id1, Id id2) { const auto it1 = hashes.find(id1); const auto it2 = hashes.find(id2); return it1 != hashes.end() && it2 != hashes.end() && 
it1->second != it2->second; } Id Find(Id id) { // path halving while (true) { auto it = mapping.find(id); if (it == mapping.end()) { return id; } auto& parent = it->second; auto parent_it = mapping.find(parent); if (parent_it == mapping.end()) { return parent; } auto parent_parent = parent_it->second; id = parent = parent_parent; ++find_halved; } } size_t GetRank(Id id) { auto it = rank.find(id); return it == rank.end() ? 0 : it->second; } void SetRank(Id id, size_t r) { if (r) { rank[id] = r; } else { rank.erase(id); } } void Union(Id id1, Id id2) { Check(!DistinctHashes(id1, id2)) << "union with distinct hashes"; Id fid1 = Find(id1); Id fid2 = Find(id2); if (fid1 == fid2) { ++union_known; return; } size_t rank1 = GetRank(fid1); size_t rank2 = GetRank(fid2); if (rank1 > rank2) { std::swap(fid1, fid2); std::swap(rank1, rank2); ++union_rank_swap; } // rank1 <= rank2 if (rank1 == rank2) { SetRank(fid2, rank2 + 1); ++union_rank_increase; } if (rank1) { SetRank(fid1, 0); ++union_rank_zero; } mapping.insert({fid1, fid2}); ++union_unknown; // move inequalities from fid1 to fid2 auto not_it = inequalities.find(fid1); if (not_it != inequalities.end()) { auto& source = not_it->second; auto& target = inequalities[fid2]; for (auto fid : source) { Check(fid != fid2) << "union of unequal"; target.insert(fid); auto& target2 = inequalities[fid]; target2.erase(fid1); target2.insert(fid2); } } } void Disunion(Id id1, Id id2) { if (DistinctHashes(id1, id2)) { ++disunion_known_hash; return; } const Id fid1 = Find(id1); const Id fid2 = Find(id2); Check(fid1 != fid2) << "disunion of equal"; if (inequalities[fid1].insert(fid2).second) { inequalities[fid2].insert(fid1); ++disunion_unknown; } else { ++disunion_known_inequality; } } const std::unordered_map& hashes; std::unordered_map mapping; std::unordered_map rank; std::unordered_map> inequalities; Counter query_count; Counter query_equal_ids; Counter query_unequal_hashes; Counter query_equal_representatives; Counter 
query_inequality_found; Counter query_not_found; Counter find_halved; Counter union_known; Counter union_rank_swap; Counter union_rank_increase; Counter union_rank_zero; Counter union_unknown; Counter disunion_known_hash; Counter disunion_known_inequality; Counter disunion_unknown; }; struct SimpleEqualityCache { explicit SimpleEqualityCache(Runtime& runtime) : query_count(runtime, "simple_cache.query_count"), query_equal_ids(runtime, "simple_cache.query_equal_ids"), query_known_equality(runtime, "simple_cache.query_known_equality"), known_equality_inserts(runtime, "simple_cache.known_equality_inserts") { } std::optional Query(const Pair& comparison) { ++query_count; const auto& [id1, id2] = comparison; if (id1 == id2) { ++query_equal_ids; return {true}; } if (known_equalities.count(comparison)) { ++query_known_equality; return {true}; } return std::nullopt; } void AllSame(const std::vector& comparisons) { for (const auto& comparison : comparisons) { ++known_equality_inserts; known_equalities.insert(comparison); } } void AllDifferent(const std::vector&) {} std::unordered_set known_equalities; Counter query_count; Counter query_equal_ids; Counter query_known_equality; Counter known_equality_inserts; }; } // namespace stg #endif // STG_EQUALITY_CACHE_H_