diff options
Diffstat (limited to 'source/opt/function.h')
-rw-r--r-- | source/opt/function.h | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/source/opt/function.h b/source/opt/function.h index 917bf584..146cbe34 100644 --- a/source/opt/function.h +++ b/source/opt/function.h @@ -19,6 +19,7 @@ #include <functional> #include <memory> #include <string> +#include <unordered_set> #include <utility> #include <vector> @@ -180,7 +181,19 @@ class Function { // Returns true is a function declaration and not a function definition. bool IsDeclaration() { return begin() == end(); } + // Reorders the basic blocks in the function to match the structured order. + void ReorderBasicBlocksInStructuredOrder(); + private: + // Reorders the basic blocks in the function to match the order given by the + // range |{begin,end}|. The range must contain every basic block in the + // function, and no extras. + template <class It> + void ReorderBasicBlocks(It begin, It end); + + template <class It> + bool ContainsAllBlocksInTheFunction(It begin, It end); + // The OpFunction instruction that begins the definition of this function. std::unique_ptr<Instruction> def_inst_; // All parameters to this function. @@ -262,6 +275,34 @@ inline void Function::AddNonSemanticInstruction( non_semantic_.emplace_back(std::move(non_semantic)); } +template <class It> +void Function::ReorderBasicBlocks(It begin, It end) { + // Asserts to make sure every node in the function is in new_order. + assert(ContainsAllBlocksInTheFunction(begin, end)); + + // We have a pointer to all the elements in order, so we can release all + // pointers in |block_|, and then create the new unique pointers from |{begin, + // end}|. + std::for_each(blocks_.begin(), blocks_.end(), + [](std::unique_ptr<BasicBlock>& bb) { bb.release(); }); + std::transform(begin, end, blocks_.begin(), [](BasicBlock* bb) { + return std::unique_ptr<BasicBlock>(bb); + }); +} + +template <class It> +bool Function::ContainsAllBlocksInTheFunction(It begin, It end) { + std::unordered_multiset<BasicBlock*> range(begin, end); + if (range.size() != blocks_.size()) { + return false; + } + + for (auto& bb : blocks_) { + if (range.count(bb.get()) == 0) return false; + } + return true; +} + } // namespace opt } // namespace spvtools |