aboutsummaryrefslogtreecommitdiff
path: root/source/opt/function.h
diff options
context:
space:
mode:
Diffstat (limited to 'source/opt/function.h')
-rw-r--r--source/opt/function.h41
1 files changed, 41 insertions, 0 deletions
diff --git a/source/opt/function.h b/source/opt/function.h
index 917bf584..146cbe34 100644
--- a/source/opt/function.h
+++ b/source/opt/function.h
@@ -19,6 +19,7 @@
#include <functional>
#include <memory>
#include <string>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -180,7 +181,19 @@ class Function {
// Returns true is a function declaration and not a function definition.
bool IsDeclaration() { return begin() == end(); }
+ // Reorders the basic blocks in the function to match the structured order.
+ void ReorderBasicBlocksInStructuredOrder();
+
private:
+ // Reorders the basic blocks in the function to match the order given by the
+ // range |{begin,end}|. The range must contain every basic block in the
+ // function, and no extras.
+ template <class It>
+ void ReorderBasicBlocks(It begin, It end);
+
+ template <class It>
+ bool ContainsAllBlocksInTheFunction(It begin, It end);
+
// The OpFunction instruction that begins the definition of this function.
std::unique_ptr<Instruction> def_inst_;
// All parameters to this function.
@@ -262,6 +275,34 @@ inline void Function::AddNonSemanticInstruction(
non_semantic_.emplace_back(std::move(non_semantic));
}
+template <class It>
+void Function::ReorderBasicBlocks(It begin, It end) {
+ // Asserts to make sure every node in the function is in new_order.
+ assert(ContainsAllBlocksInTheFunction(begin, end));
+
+ // We have a pointer to all the elements in order, so we can release all
+ // pointers in |block_|, and then create the new unique pointers from |{begin,
+ // end}|.
+ std::for_each(blocks_.begin(), blocks_.end(),
+ [](std::unique_ptr<BasicBlock>& bb) { bb.release(); });
+ std::transform(begin, end, blocks_.begin(), [](BasicBlock* bb) {
+ return std::unique_ptr<BasicBlock>(bb);
+ });
+}
+
+template <class It>
+bool Function::ContainsAllBlocksInTheFunction(It begin, It end) {
+ std::unordered_multiset<BasicBlock*> range(begin, end);
+ if (range.size() != blocks_.size()) {
+ return false;
+ }
+
+ for (auto& bb : blocks_) {
+ if (range.count(bb.get()) == 0) return false;
+ }
+ return true;
+}
+
} // namespace opt
} // namespace spvtools