diff options
author | Treehugger Robot <treehugger-gerrit@google.com> | 2017-02-28 21:01:28 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2017-02-28 21:01:30 +0000 |
commit | e3fe99e687d6c1315967251372704bd61663a9d4 (patch) | |
tree | a07cfac007c65aedbfe6bd36883b82071eb55ac4 /rsov | |
parent | 0fc0d1c3a7a5600f24e2979d360bfa1a7a80f5e3 (diff) | |
parent | 3f30b6202dd5ad6ff66959131d216405850ed152 (diff) | |
download | rs-e3fe99e687d6c1315967251372704bd61663a9d4.tar.gz |
Merge "Added Pass and PassQueue to Spirit"
Diffstat (limited to 'rsov')
28 files changed, 956 insertions, 266 deletions
diff --git a/rsov/compiler/Builtin.cpp b/rsov/compiler/Builtin.cpp index 106e6309..3b871448 100644 --- a/rsov/compiler/Builtin.cpp +++ b/rsov/compiler/Builtin.cpp @@ -16,13 +16,16 @@ #include "Builtin.h" -#include <map> -#include <string> - #include "cxxabi.h" #include "spirit.h" #include "transformer.h" +#include <stdint.h> + +#include <map> +#include <string> +#include <vector> + namespace android { namespace spirit { @@ -271,13 +274,16 @@ BuiltinLookupTable::sNameCode constexpr BuiltinLookupTable::mFPMathFuncOpCode[]; class BuiltinTransformer : public Transformer { public: - BuiltinTransformer(Builder *b, Module *m) : mBuilder(b), mModule(m) {} - // BEGIN: cleanup unrelated to builtin functions, but necessary for LLVM-SPIRV // converter generated code. // TODO: Move these in its own pass + std::vector<uint32_t> runAndSerialize(Module *module, int *error) override { + module->addExtInstImport("GLSL.std.450"); + return Transformer::runAndSerialize(module, error); + } + Instruction *transform(CapabilityInst *inst) override { if (inst->mOperand1 == Capability::Linkage || inst->mOperand1 == Capability::Kernel) { @@ -287,7 +293,7 @@ public: } Instruction *transform(ExtInstImportInst *inst) override { - if (strcmp(inst->mOperand1, "OpenCL.std") == 0) { + if (inst->mOperand1.compare("OpenCL.std") == 0) { return nullptr; } return inst; @@ -316,7 +322,7 @@ public: // TODO: attach name to the instruction to avoid linear search in the debug // section, i.e., // const char *name = func->getName(); - const char *name = mModule->lookupNameByInstruction(func); + const char *name = getModule()->lookupNameByInstruction(func); if (!name) { return call; } @@ -327,7 +333,7 @@ public: if (!fpTranslate) { return call; } - Instruction *inst = fpTranslate(name, call, this, mBuilder, mModule); + Instruction *inst = fpTranslate(name, call, this, &mBuilder, getModule()); if (inst) { inst->setId(call->getId()); @@ -337,8 +343,7 @@ public: } private: - Builder *mBuilder; - Module *mModule; + Builder mBuilder; }; } // namespace spirit @@ -346,12 +351,9 @@ private: namespace rs2spirv { -std::vector<uint32_t> TranslateBuiltins(android::spirit::Builder &b, - android::spirit::Module *m, - int *error) { - android::spirit::BuiltinTransformer trans(&b, m); - *error = 0; - return trans.transformSerialize(m); +android::spirit::Pass *CreateBuiltinPass() { + return new android::spirit::BuiltinTransformer(); } } // namespace rs2spirv + diff --git a/rsov/compiler/Builtin.h b/rsov/compiler/Builtin.h index d293c9da..2a280979 100644 --- a/rsov/compiler/Builtin.h +++ b/rsov/compiler/Builtin.h @@ -17,23 +17,17 @@ #ifndef BUILTIN_H #define BUILTIN_H -#include <stdint.h> - -#include <vector> - namespace android { namespace spirit { -class Builder; -class Module; +class Pass; } // namespace spirit } // namespace android namespace rs2spirv { -std::vector<uint32_t> TranslateBuiltins(android::spirit::Builder &b, - android::spirit::Module *m, int *error); +android::spirit::Pass *CreateBuiltinPass(); } // namespace rs2spirv diff --git a/rsov/compiler/Builtin_test.cpp b/rsov/compiler/Builtin_test.cpp index 3a5f0c44..c180add9 100644 --- a/rsov/compiler/Builtin_test.cpp +++ b/rsov/compiler/Builtin_test.cpp @@ -17,6 +17,7 @@ #include "Builtin.h" #include "file_utils.h" +#include "pass_queue.h" #include "spirit.h" #include "test_utils.h" #include "gtest/gtest.h" @@ -30,19 +31,10 @@ TEST(BuiltinTest, testBuiltinTranslation) { "frameworks/rs/rsov/compiler/spirit/test_data/"); const std::string &fullPath = getAbsolutePath(testDataPath + testFile); auto words = readFile<uint32_t>(fullPath); - std::unique_ptr<InputWordStream> IS( - InputWordStream::Create(std::move(words))); - std::unique_ptr<Module> m(Deserialize<Module>(*IS)); - ASSERT_NE(nullptr, m); - - Builder b; - m->setBuilder(&b); - - int error; - auto words1 = rs2spirv::TranslateBuiltins(b, m.get(), &error); - - ASSERT_EQ(0, error); + PassQueue passes; + passes.append(rs2spirv::CreateBuiltinPass()); + auto words1 = passes.run(words); std::unique_ptr<InputWordStream> IS1( InputWordStream::Create(std::move(words1))); diff --git a/rsov/compiler/GlobalAllocSPIRITPass.cpp b/rsov/compiler/GlobalAllocSPIRITPass.cpp index f9b1066e..61a75129 100644 --- a/rsov/compiler/GlobalAllocSPIRITPass.cpp +++ b/rsov/compiler/GlobalAllocSPIRITPass.cpp @@ -22,6 +22,53 @@ namespace android { namespace spirit { +namespace { + +// Metadata buffer for global allocations +// struct metadata { +// uint32_t element_size; +// uint32_t x_size; +// uint32_t y_size; +// uint32_t ?? +// }; +VariableInst *AddGAMetadata(Builder &b, Module *m) { + TypeIntInst *UInt32Ty = m->getUnsignedIntType(32); + std::vector<Instruction *> metadata{ + UInt32Ty, + UInt32Ty, + UInt32Ty, + UInt32Ty + }; + auto MetadataStructTy = m->getStructType(metadata.data(), metadata.size()); + // FIXME: workaround on a weird OpAccessChain member offset problem. Somehow + // when given constant indices, OpAccessChain returns pointers that are 4 bytes + // less than what are supposed to be (at runtime). + // For now workaround this with +4 the member offsets. + MetadataStructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(4); + MetadataStructTy->memberDecorate(1, Decoration::Offset)->addExtraOperand(8); + MetadataStructTy->memberDecorate(2, Decoration::Offset)->addExtraOperand(12); + MetadataStructTy->memberDecorate(3, Decoration::Offset)->addExtraOperand(16); + // TBD: Implement getArrayType. RuntimeArray requires buffers and hence we + // cannot use PushConstant underneath + auto MetadataBufSTy = m->getRuntimeArrayType(MetadataStructTy); + // Stride of metadata. + MetadataBufSTy->decorate(Decoration::ArrayStride)->addExtraOperand( + metadata.size()*sizeof(uint32_t)); + auto MetadataSSBO = m->getStructType(MetadataBufSTy); + MetadataSSBO->decorate(Decoration::BufferBlock); + auto MetadataPtrTy = m->getPointerType(StorageClass::Uniform, MetadataSSBO); + + + VariableInst *MetadataVar = b.MakeVariable(MetadataPtrTy, StorageClass::Uniform); + MetadataVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0); + MetadataVar->decorate(Decoration::Binding)->addExtraOperand(0); + m->addVariable(MetadataVar); + + return MetadataVar; +} + +} // anonymous namespace + // Replacing calls to lowered accessors, e.g., __rsov_rsAllocationGetDimX // which was created from rsAllocationGetDimX by replacing the allocation // with an ID in an earlier LLVM pass (see GlobalAllocationPass.cpp), @@ -54,13 +101,15 @@ namespace spirit { class GAAccessorTransformer : public Transformer { public: - GAAccessorTransformer(Builder *b, Module *m, VariableInst *metadata) - : mBuilder(b), mModule(m), mMetadata(metadata) {} + std::vector<uint32_t> runAndSerialize(Module *module, int *error) override { + mMetadata = AddGAMetadata(mBuilder, module); + return Transformer::runAndSerialize(module, error); + } Instruction *transform(FunctionCallInst *call) { FunctionInst *func = static_cast<FunctionInst *>(call->mOperand1.mInstruction); - const char *name = mModule->lookupNameByInstruction(func); + const char *name = getModule()->lookupNameByInstruction(func); if (!name) { return call; } @@ -69,19 +118,19 @@ public: // Maps name into a SPIR-V instruction // TODO: generalize it to support more accessors if (!strcmp(name, "__rsov_rsAllocationGetDimX")) { - TypeIntInst *UInt32Ty = mModule->getUnsignedIntType(32); + TypeIntInst *UInt32Ty = getModule()->getUnsignedIntType(32); // TODO: hardcoded layout - auto ConstZero = mModule->getConstant(UInt32Ty, 0U); - auto ConstOne = mModule->getConstant(UInt32Ty, 1U); + auto ConstZero = getModule()->getConstant(UInt32Ty, 0U); + auto ConstOne = getModule()->getConstant(UInt32Ty, 1U); // TODO: Use constant memory later auto resultPtrType = - mModule->getPointerType(StorageClass::Uniform, UInt32Ty); - AccessChainInst *LoadPtr = mBuilder->MakeAccessChain( + getModule()->getPointerType(StorageClass::Uniform, UInt32Ty); + AccessChainInst *LoadPtr = mBuilder.MakeAccessChain( resultPtrType, mMetadata, {ConstZero, ConstZero, ConstOne}); insert(LoadPtr); - inst = mBuilder->MakeLoad(UInt32Ty, LoadPtr); + inst = mBuilder.MakeLoad(UInt32Ty, LoadPtr); inst->setId(call->getId()); } else { inst = call; @@ -90,8 +139,7 @@ public: } private: - Builder *mBuilder; - Module *mModule; + Builder mBuilder; VariableInst *mMetadata; }; @@ -100,13 +148,8 @@ private: namespace rs2spirv { -// android::spirit::Module * -std::vector<uint32_t> -TranslateGAAccessors(android::spirit::Builder &b, android::spirit::Module *m, - android::spirit::VariableInst *metadata, int *error) { - android::spirit::GAAccessorTransformer trans(&b, m, metadata); - *error = 0; - return trans.transformSerialize(m); +android::spirit::Pass *CreateGAPass() { + return new android::spirit::GAAccessorTransformer(); } } // namespace rs2spirv diff --git a/rsov/compiler/GlobalAllocSPIRITPass.h b/rsov/compiler/GlobalAllocSPIRITPass.h index e979bfa5..b3a6ecb0 100644 --- a/rsov/compiler/GlobalAllocSPIRITPass.h +++ b/rsov/compiler/GlobalAllocSPIRITPass.h @@ -17,26 +17,17 @@ #ifndef GLOBALALLOCSPIRITPASS_H #define GLOBALALLOCSPIRITPASS_H -#include <stdint.h> - -#include <vector> - namespace android { namespace spirit { -class Builder; -class Module; -class VariableInst; +class Pass; } // namespace spirit } // namespace android namespace rs2spirv { -// android::spirit::Module * -std::vector<uint32_t> -TranslateGAAccessors(android::spirit::Builder &b, android::spirit::Module *m, - android::spirit::VariableInst *metadata, int *error); +android::spirit::Pass *CreateGAPass(); } // namespace rs2spirv diff --git a/rsov/compiler/RSSPIRVWriter.cpp b/rsov/compiler/RSSPIRVWriter.cpp index 19b5703f..541f5bc7 100644 --- a/rsov/compiler/RSSPIRVWriter.cpp +++ b/rsov/compiler/RSSPIRVWriter.cpp @@ -30,11 +30,14 @@ #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" +#include "Builtin.h" #include "GlobalAllocPass.h" +#include "GlobalAllocSPIRITPass.h" #include "GlobalMergePass.h" #include "InlinePreparationPass.h" #include "RemoveNonkernelsPass.h" #include "Wrapper.h" +#include "pass_queue.h" #include <fstream> #include <sstream> @@ -137,8 +140,13 @@ bool WriteSPIRV(llvm::Module *M, llvm::raw_ostream &OS, std::string &ErrMsg) { memcpy(words.data(), str.data(), str.size()); + android::spirit::PassQueue spiritPasses; + spiritPasses.append(CreateWrapperPass(ME, *M)); + spiritPasses.append(CreateBuiltinPass()); + spiritPasses.append(CreateGAPass()); + int error; - auto wordsOut = AddGLComputeWrappers(words, ME, *M, &error); + auto wordsOut = spiritPasses.run(words, &error); if (error != 0) { OS << *BM; diff --git a/rsov/compiler/Wrapper.cpp b/rsov/compiler/Wrapper.cpp index 8f1a1d65..66a00fc1 100644 --- a/rsov/compiler/Wrapper.cpp +++ b/rsov/compiler/Wrapper.cpp @@ -16,6 +16,8 @@ #include "Wrapper.h" +#include "llvm/IR/Module.h" + #include "Builtin.h" #include "GlobalAllocSPIRITPass.h" #include "RSAllocationUtils.h" @@ -23,57 +25,16 @@ #include "builder.h" #include "instructions.h" #include "module.h" +#include "pass.h" #include "word_stream.h" -#include "llvm/IR/Module.h" + +#include <vector> using bcinfo::MetadataExtractor; + namespace android { namespace spirit { -// Metadata buffer for global allocations -// struct metadata { -// uint32_t element_size; -// uint32_t x_size; -// uint32_t y_size; -// uint32_t ?? -// }; -VariableInst *AddGAMetadata(/*Instruction *elementType, uint32_t binding, */ Builder &b, - Module *m) { - TypeIntInst *UInt32Ty = m->getUnsignedIntType(32); - std::vector<Instruction *> metadata{ - UInt32Ty, - UInt32Ty, - UInt32Ty, - UInt32Ty - }; - auto MetadataStructTy = m->getStructType(metadata.data(), metadata.size()); - // FIXME: workaround on a weird OpAccessChain member offset problem. Somehow - // when given constant indices, OpAccessChain returns pointers that are 4 bytes - // less than what are supposed to be (at runtime). - // For now workaround this with +4 the member offsets. - MetadataStructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(4); - MetadataStructTy->memberDecorate(1, Decoration::Offset)->addExtraOperand(8); - MetadataStructTy->memberDecorate(2, Decoration::Offset)->addExtraOperand(12); - MetadataStructTy->memberDecorate(3, Decoration::Offset)->addExtraOperand(16); - // TBD: Implement getArrayType. RuntimeArray requires buffers and hence we - // cannot use PushConstant underneath - auto MetadataBufSTy = m->getRuntimeArrayType(MetadataStructTy); - // Stride of metadata. - MetadataBufSTy->decorate(Decoration::ArrayStride)->addExtraOperand( - metadata.size()*sizeof(uint32_t)); - auto MetadataSSBO = m->getStructType(MetadataBufSTy); - MetadataSSBO->decorate(Decoration::BufferBlock); - auto MetadataPtrTy = m->getPointerType(StorageClass::Uniform, MetadataSSBO); - - - VariableInst *MetadataVar = b.MakeVariable(MetadataPtrTy, StorageClass::Uniform); - MetadataVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0); - MetadataVar->decorate(Decoration::Binding)->addExtraOperand(0); - m->addVariable(MetadataVar); - - return MetadataVar; -} - VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b, Module *m) { auto ArrTy = m->getRuntimeArrayType(elementType); @@ -356,8 +317,6 @@ void AddHeader(Module *m) { // m->addCapability(Capability::Addresses); m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450); - m->addExtInstImport("GLSL.std.450"); - m->addSource(SourceLanguage::GLSL, 450); m->addSourceExtension("GL_ARB_separate_shader_objects"); m->addSourceExtension("GL_ARB_shading_language_420pack"); @@ -394,45 +353,18 @@ void FixGlobalStorageClass(Module *m) { } // anonymous namespace -} // namespace spirit -} // namespace android - -using android::spirit::AddHeader; -using android::spirit::AddWrapper; -using android::spirit::DecorateGlobalBuffer; -using android::spirit::InputWordStream; -using android::spirit::FixGlobalStorageClass; - -namespace rs2spirv { - -std::vector<uint32_t> -AddGLComputeWrappers(const std::vector<uint32_t> &kernel_spirv, - const bcinfo::MetadataExtractor &metadata, - llvm::Module &LM, int *error) { - std::unique_ptr<InputWordStream> IS( - InputWordStream::Create(std::move(kernel_spirv))); - std::unique_ptr<android::spirit::Module> m( - android::spirit::Deserialize<android::spirit::Module>(*IS)); - - if (!m) { - *error = -1; - return std::vector<uint32_t>(); - } - - if (!m->resolveIds()) { - *error = -2; - return std::vector<uint32_t>(); - } - +bool AddWrappers(const bcinfo::MetadataExtractor &metadata, + llvm::Module &LM, + android::spirit::Module *m) { android::spirit::Builder b; m->setBuilder(&b); - FixGlobalStorageClass(m.get()); + FixGlobalStorageClass(m); - AddHeader(m.get()); + AddHeader(m); - DecorateGlobalBuffer(LM, b, m.get()); + DecorateGlobalBuffer(LM, b, m); const size_t numKernel = metadata.getExportForEachSignatureCount(); const char **kernelName = metadata.getExportForEachNameList(); @@ -441,53 +373,43 @@ AddGLComputeWrappers(const std::vector<uint32_t> &kernel_spirv, for (size_t i = 0; i < numKernel; i++) { bool success = - AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m.get()); + AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m); if (!success) { - *error = -3; - return std::vector<uint32_t>(); + return false; } } m->consolidateAnnotations(); - auto words = rs2spirv::TranslateBuiltins(b, m.get(), error); - - // Recreate a module in known state after TranslateBuiltins - std::unique_ptr<InputWordStream> IS1( - InputWordStream::Create(std::move(words))); - std::unique_ptr<android::spirit::Module> m1( - android::spirit::Deserialize<android::spirit::Module>(*IS1)); - - if (!m1) { - *error = -1; - return std::vector<uint32_t>(); - } + return true; +} - if (!m1->resolveIds()) { - *error = -2; - return std::vector<uint32_t>(); +class WrapperPass : public Pass { +public: + WrapperPass(const bcinfo::MetadataExtractor &Metadata, + const llvm::Module &LM) : mLLVMMetadata(Metadata), + mLLVMModule(const_cast<llvm::Module&>(LM)) {} + + Module *run(Module *m, int *error) override { + bool success = AddWrappers(mLLVMMetadata, mLLVMModule, m); + if (error) { + *error = success ? 0 : -1; + } + return m; } - // Builders can be reused - m1->setBuilder(&b); - - // Create types and variable declarations for global allocation metadata - android::spirit::VariableInst *GAmetadata = AddGAMetadata(b, m1.get()); +private: + const bcinfo::MetadataExtractor &mLLVMMetadata; + llvm::Module &mLLVMModule; +}; - // Adding types on-the-fly inside a transformer is not well suported now; - // creating them here before we enter transformer to avoid problems. - // TODO: Fix the transformer - android::spirit::TypeIntInst *UInt32Ty = m1->getUnsignedIntType(32); - m1->getConstant(UInt32Ty, 0U); - m1->getConstant(UInt32Ty, 1U); - // TODO: Use constant memory for metadata - m1->getPointerType(android::spirit::StorageClass::Uniform, - UInt32Ty); +} // namespace spirit +} // namespace android - // Transform calls to lowered allocation accessors to use metadata - // TODO: implement the lowering pass in LLVM - m1->consolidateAnnotations(); - return rs2spirv::TranslateGAAccessors(b, m1.get(), GAmetadata, error); +namespace rs2spirv { +android::spirit::Pass* CreateWrapperPass(const bcinfo::MetadataExtractor &metadata, + const llvm::Module &LLVMModule) { + return new android::spirit::WrapperPass(metadata, LLVMModule); } } // namespace rs2spirv diff --git a/rsov/compiler/Wrapper.h b/rsov/compiler/Wrapper.h index ead4e360..ca2a116e 100644 --- a/rsov/compiler/Wrapper.h +++ b/rsov/compiler/Wrapper.h @@ -17,7 +17,7 @@ #ifndef WRAPPER_H #define WRAPPER_H -#include <vector> +#include <stdint.h> namespace bcinfo { class MetadataExtractor; @@ -27,21 +27,13 @@ namespace llvm { class Module; } -namespace rs2spirv { - -std::vector<uint32_t> -AddGLComputeWrappers(const std::vector<uint32_t> &kernel_spirv, - const bcinfo::MetadataExtractor &metadata, llvm::Module &M, - int *error); - -} // namespace rs2spirv - namespace android { namespace spirit { class Builder; class Instruction; class Module; +class Pass; class VariableInst; // TODO: avoid exposing these methods while still unit testing them @@ -59,4 +51,11 @@ bool DecorateGlobalBuffer(llvm::Module &M, Builder &b, Module *m); } // namespace spirit } // namespace android +namespace rs2spirv { + +android::spirit::Pass* CreateWrapperPass(const bcinfo::MetadataExtractor &metadata, + const llvm::Module &LLVMModule); + +} // namespace rs2spirv + #endif diff --git a/rsov/compiler/spirit/Android.mk b/rsov/compiler/spirit/Android.mk index d0dba157..ad1c6479 100644 --- a/rsov/compiler/spirit/Android.mk +++ b/rsov/compiler/spirit/Android.mk @@ -21,6 +21,8 @@ SPIRIT_SRCS := \ entity.cpp\ instructions.cpp\ module.cpp\ + pass.cpp\ + pass_queue.cpp\ transformer.cpp\ visitor.cpp\ word_stream.cpp\ @@ -148,6 +150,29 @@ LOCAL_C_INCLUDES += $(PATH_TO_GENERATED) include $(BUILD_HOST_NATIVE_TEST) #===================================================================== +# Tests for host module pass queue +#===================================================================== + +include $(CLEAR_VARS) + +LOCAL_SRC_FILES := \ + pass.cpp \ + pass_queue.cpp \ + pass_queue_test.cpp \ + +LOCAL_STATIC_LIBRARIES := libgtest_host + +LOCAL_SHARED_LIBRARIES := $(LIBNAME) + +LOCAL_MODULE := pass_queue_test +LOCAL_MULTILIB := first +LOCAL_MODULE_TAGS := tests +LOCAL_MODULE_CLASS := NATIVE_TESTS +LOCAL_IS_HOST_MODULE := true + +include $(BUILD_HOST_NATIVE_TEST) + +#===================================================================== # Tests for host shared library #===================================================================== diff --git a/rsov/compiler/spirit/builder_test.cpp b/rsov/compiler/spirit/builder_test.cpp index f4d69b4b..111bddb1 100644 --- a/rsov/compiler/spirit/builder_test.cpp +++ b/rsov/compiler/spirit/builder_test.cpp @@ -15,6 +15,7 @@ */ #include "builder.h" + #include "file_utils.h" #include "instructions.h" #include "module.h" diff --git a/rsov/compiler/spirit/core_defs.h b/rsov/compiler/spirit/core_defs.h index d37f41d1..06b1d1b0 100644 --- a/rsov/compiler/spirit/core_defs.h +++ b/rsov/compiler/spirit/core_defs.h @@ -17,13 +17,15 @@ #ifndef CORE_DEFS_H #define CORE_DEFS_H +#include <string> + namespace android { namespace spirit { class Instruction; typedef int32_t LiteralInteger; -typedef const char *LiteralString; +typedef std::string LiteralString; typedef union { int32_t intValue; int64_t longValue; diff --git a/rsov/compiler/spirit/instructions.h b/rsov/compiler/spirit/instructions.h index 88d188c0..6e162a1d 100644 --- a/rsov/compiler/spirit/instructions.h +++ b/rsov/compiler/spirit/instructions.h @@ -20,6 +20,7 @@ #include <stdint.h> #include <iostream> +#include <string> #include <vector> #include "core_defs.h" @@ -38,8 +39,8 @@ template <typename T> uint16_t WordCount(T) { return 1; } inline uint16_t WordCount(PairLiteralIntegerIdRef) { return 2; } inline uint16_t WordCount(PairIdRefLiteralInteger) { return 2; } inline uint16_t WordCount(PairIdRefIdRef) { return 2; } -inline uint16_t WordCount(const char *operand) { - return strlen(operand) / 4 + 1; +inline uint16_t WordCount(const std::string &operand) { + return operand.length() / 4 + 1; } class Instruction : public Entity { diff --git a/rsov/compiler/spirit/instructions_test.cpp b/rsov/compiler/spirit/instructions_test.cpp index 1156591c..7b089caa 100644 --- a/rsov/compiler/spirit/instructions_test.cpp +++ b/rsov/compiler/spirit/instructions_test.cpp @@ -14,13 +14,14 @@ * limitations under the License. */ -#include <memory> -#include <vector> - #include "instructions.h" + #include "word_stream.h" #include "gtest/gtest.h" +#include <memory> +#include <vector> + namespace android { namespace spirit { @@ -39,7 +40,7 @@ TEST(InstructionTest, testOpExtension) { std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); auto *i = Deserialize<ExtensionInst>(*IS); ASSERT_NE(nullptr, i); - EXPECT_STREQ("ABCDEFG", i->mOperand1); + EXPECT_STREQ("ABCDEFG", i->mOperand1.c_str()); } TEST(InstructionTest, testOpExtInstImport) { @@ -51,7 +52,7 @@ TEST(InstructionTest, testOpExtInstImport) { std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); auto *i = Deserialize<ExtInstImportInst>(*IS); ASSERT_NE(nullptr, i); - EXPECT_STREQ("GLSL.std.450", i->mOperand1); + EXPECT_STREQ("GLSL.std.450", i->mOperand1.c_str()); } } // namespace spirit diff --git a/rsov/compiler/spirit/module.cpp b/rsov/compiler/spirit/module.cpp index ac413b3f..98001124 100644 --- a/rsov/compiler/spirit/module.cpp +++ b/rsov/compiler/spirit/module.cpp @@ -519,7 +519,7 @@ Instruction *DebugInfoSection::lookupByName(const char *name) const { for (auto inst : mNames) { if (inst->getOpCode() == OpName) { NameInst *nameInst = static_cast<NameInst *>(inst); - if (strcmp(nameInst->mOperand2, name) == 0) { + if (nameInst->mOperand2.compare(name) == 0) { return nameInst->mOperand1.mInstruction; } } @@ -534,7 +534,7 @@ DebugInfoSection::lookupNameByInstruction(const Instruction *target) const { if (inst->getOpCode() == OpName) { NameInst *nameInst = static_cast<NameInst *>(inst); if (nameInst->mOperand1.mInstruction == target) { - return nameInst->mOperand2; + return nameInst->mOperand2.c_str(); } } // Ignore member names @@ -684,12 +684,8 @@ GlobalSection::getConstantComposite(TypeVectorInst *type, TypeVoidInst *GlobalSection::getVoidType() { return findOrCreate<TypeVoidInst>( - [=](TypeVoidInst *) -> bool { - return true; - }, - [=]() -> TypeVoidInst * { - return mBuilder->MakeTypeVoid(); - }, + [=](TypeVoidInst *) -> bool { return true; }, + [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); }, &mGlobalDefs); } @@ -698,14 +694,14 @@ TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) { switch (bits) { #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \ case BITS: { \ - return findOrCreate<TypeIntInst>( \ - [=](TypeIntInst *intTy) -> bool {\ - return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \ - }, \ - [=]() -> TypeIntInst * { \ - return mBuilder->MakeTypeInt(BITS, SIGNED); \ - }, \ - &mGlobalDefs); \ + return findOrCreate<TypeIntInst>( \ + [=](TypeIntInst *intTy) -> bool { \ + return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \ + }, \ + [=]() -> TypeIntInst * { \ + return mBuilder->MakeTypeInt(BITS, SIGNED); \ + }, \ + &mGlobalDefs); \ } HANDLE_INT_SIZE(Int, 8, 1); HANDLE_INT_SIZE(Int, 16, 1); @@ -732,14 +728,12 @@ TypeFloatInst *GlobalSection::getFloatType(int bits) { switch (bits) { #define HANDLE_FLOAT_SIZE(BITS) \ case BITS: { \ - return findOrCreate<TypeFloatInst>( \ - [=](TypeFloatInst *floatTy) -> bool {\ - return floatTy->mOperand1 == BITS; \ - }, \ - [=]() -> TypeFloatInst * { \ - return mBuilder->MakeTypeFloat(BITS); \ - }, \ - &mGlobalDefs); \ + return findOrCreate<TypeFloatInst>( \ + [=](TypeFloatInst *floatTy) -> bool { \ + return floatTy->mOperand1 == BITS; \ + }, \ + [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \ + &mGlobalDefs); \ } HANDLE_FLOAT_SIZE(16); HANDLE_FLOAT_SIZE(32); diff --git a/rsov/compiler/spirit/module.h b/rsov/compiler/spirit/module.h index 7ec41ff2..a69d255f 100644 --- a/rsov/compiler/spirit/module.h +++ b/rsov/compiler/spirit/module.h @@ -23,6 +23,7 @@ #include "core_defs.h" #include "entity.h" +#include "instructions.h" #include "stl_util.h" #include "types_generated.h" #include "visitor.h" @@ -134,6 +135,10 @@ public: getFunctionDefinitionFromInstruction(FunctionInst *) const; FunctionDefinition *lookupFunctionDefinitionByName(const char *name) const; + // Find the name of the instruction, e.g., the name of a function (OpFunction + // instruction). + // The returned string is owned by the OpName instruction, whose first operand + // is the instruction being queried on. const char *lookupNameByInstruction(const Instruction *) const; VariableInst *getInvocationId(); diff --git a/rsov/compiler/spirit/module_test.cpp b/rsov/compiler/spirit/module_test.cpp index 37f52e04..f24268fa 100644 --- a/rsov/compiler/spirit/module_test.cpp +++ b/rsov/compiler/spirit/module_test.cpp @@ -14,16 +14,17 @@ * limitations under the License. */ -#include <fstream> -#include <memory> +#include "module.h" #include "file_utils.h" #include "instructions.h" -#include "module.h" #include "test_utils.h" #include "word_stream.h" #include "gtest/gtest.h" +#include <fstream> +#include <memory> + namespace android { namespace spirit { diff --git a/rsov/compiler/spirit/pass.cpp b/rsov/compiler/spirit/pass.cpp new file mode 100644 index 00000000..6653f96c --- /dev/null +++ b/rsov/compiler/spirit/pass.cpp @@ -0,0 +1,54 @@ +/* + * Copyright 2017, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pass.h" + +#include "module.h" +#include "word_stream.h" + +namespace android { +namespace spirit { + +Module *Pass::run(Module *module, int *error) { + int intermediateError; + auto words = runAndSerialize(module, &intermediateError); + if (intermediateError) { + if (error) { + *error = intermediateError; + } + return nullptr; + } + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); + return Deserialize<Module>(*IS); +} + +std::vector<uint32_t> Pass::runAndSerialize(Module *module, int *error) { + int intermediateError; + auto m1 = run(module, &intermediateError); + if (intermediateError) { + if (error) { + *error = intermediateError; + } + return std::vector<uint32_t>(); + } + std::unique_ptr<OutputWordStream> OS(OutputWordStream::Create()); + m1->Serialize(*OS); + return OS->getWords(); +} + +} // namespace spirit +} // namespace android + diff --git a/rsov/compiler/spirit/pass.h b/rsov/compiler/spirit/pass.h new file mode 100644 index 00000000..1d493a1d --- /dev/null +++ b/rsov/compiler/spirit/pass.h @@ -0,0 +1,51 @@ +/* + * Copyright 2017, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSOV_COMPILER_SPIRIT_PASS_H +#define RSOV_COMPILER_SPIRIT_PASS_H + +#include <stdint.h> + +#include <vector> + +namespace android { +namespace spirit { + +class Module; + +// The base class for a pass, either an analysis or a transformation of a +// Module. An instanace of a derived class can be added to a PassQueue and +// applied to a Module, and produce a result Module with other passes. +class Pass { +public: + virtual ~Pass() {} + + // Runs the pass on the input module and returns the result module. + // If argument error is not null, set the error code. On a successful run, + // error code is set to zero. + virtual Module *run(Module *module, int *error); + + // Runs the pass on the input module, serializes the result module, and + // returns the words as a vector. + // If argument error is not null, set the error code. On a successful run, + // error code is set to zero. + virtual std::vector<uint32_t> runAndSerialize(Module *module, int *error); +}; + +} // namespace spirit +} // namespace android + +#endif // RSOV_COMPILER_SPIRIT_PASS_H diff --git a/rsov/compiler/spirit/pass_queue.cpp b/rsov/compiler/spirit/pass_queue.cpp new file mode 100644 index 00000000..0a37e0a0 --- /dev/null +++ b/rsov/compiler/spirit/pass_queue.cpp @@ -0,0 +1,122 @@ +/* + * Copyright 2017, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pass_queue.h" + +#include "module.h" +#include "word_stream.h" + +namespace android { +namespace spirit { + +bool PassQueue::append(Pass *pass) { + mPasses.push_back(pass); + mPassSet.insert(pass); + return true; +} + +Module *PassQueue::run(Module *module, int *error) { + if (mPasses.empty()) { + return module; + } + + // A unique ptr to keep intermediate modules from leaking + std::unique_ptr<Module> tempModule; + + for (auto pass : mPasses) { + int intermediateError = 0; + Module* newModule = pass->run(module, &intermediateError); + // Some passes modify the input module in place, while others create a new + // module. Release memory only when it is necessary. + if (newModule != module) { + tempModule.reset(newModule); + } + module = newModule; + if (intermediateError) { + if (error) { + *error = intermediateError; + } + return nullptr; + } + if (!module || !module->resolveIds()) { + if (error) { + *error = -1; + } + return nullptr; + } + } + + if (tempModule == nullptr) { + return module; + } + + return tempModule.release(); +} + +std::vector<uint32_t> PassQueue::run(const std::vector<uint32_t> &spirvWords, + int *error) { + if (mPasses.empty()) { + return spirvWords; + } + + std::unique_ptr<InputWordStream> IS( + InputWordStream::Create(std::move(spirvWords))); + Module *module = Deserialize<Module>(*IS); + if (!module || !module->resolveIds()) { + return std::vector<uint32_t>(); + } + + return runAndSerialize(module, error); +} + +std::vector<uint32_t> PassQueue::runAndSerialize(Module *module, int *error) { + const int n = mPasses.size(); + if (n < 1) { + std::unique_ptr<OutputWordStream> OS(OutputWordStream::Create()); + module->Serialize(*OS); + return OS->getWords(); + } + + // A unique ptr to keep intermediate modules from leaking + std::unique_ptr<Module> tempModule; + + for (int i = 0; i < n - 1; i++) { + int intermediateError = 0; + Module *newModule = mPasses[i]->run(module, &intermediateError); + // Some passes modify the input module in place, while others create a new + // module. Release memory only when it is necessary. + if (newModule != module) { + tempModule.reset(newModule); + } + module = newModule; + if (intermediateError) { + if (error) { + *error = intermediateError; + } + return std::vector<uint32_t>(); + } + if (!module || !module->resolveIds()) { + if (error) { + *error = -1; + } + return std::vector<uint32_t>(); + } + } + return mPasses[n - 1]->runAndSerialize(module, error); +} + +} // namespace spirit +} // namespace android diff --git a/rsov/compiler/spirit/pass_queue.h b/rsov/compiler/spirit/pass_queue.h new file mode 100644 index 00000000..f00dcf7f --- /dev/null +++ b/rsov/compiler/spirit/pass_queue.h @@ -0,0 +1,77 @@ +/* + * Copyright 2017, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RSOV_COMPILER_SPIRIT_PASS_QUEUE_H +#define RSOV_COMPILER_SPIRIT_PASS_QUEUE_H + +#include "module.h" +#include "pass.h" +#include "stl_util.h" + +#include <stdint.h> + +#include <memory> +#include <set> +#include <vector> + +namespace android { +namespace spirit { + +// A FIFO of passes. Passes are appended to the end of the FIFO and run in the +// first-in first-out order. Once appended to a pass queue, Passes are owned by +// the queue. +class PassQueue { +public: + PassQueue() : mPassesDeleter(mPassSet) {} + + // Appends a pass to the end of the queue + bool append(Pass *pass); + + // Runs all passes in the queue in the first-in first-out order on a Module. + // Returns the result Module after all passes have run. + // If argument error is not null, sets the error code. On a successful run, + // error code is set to zero. + Module *run(Module *module, int *error = nullptr); + + // Deserialize the input vector of words into a Module, and runs all passes in + // the queue in the first-in first-out order on the Module. + // for a serialized Module. + // After all the passes have run, returns the words from the serialized result + // Module. + // If argument error is not null, sets the error code. On a successful run, + // error code is set to zero. + std::vector<uint32_t> run(const std::vector<uint32_t> &spirvWords, + int *error = nullptr); + + // Runs all passes in the queue in the first-in first-out order on a Module. + // After all the passes have run, serializes the result Module, and returns + // the words as a vector. + // If argument error is not null, sets the error code. On a successful run, + // error code is set to zero. + std::vector<uint32_t> runAndSerialize(Module *module, int *error = nullptr); + +private: + std::vector<Pass *> mPasses; + // Keep all passes in a set so that we can delete them on destruction without + // worrying about duplicates + std::set<Pass *> mPassSet; + ContainerDeleter<std::set<Pass *>> mPassesDeleter; +}; + +} // spirit +} // android + +#endif // RSOV_COMPILER_SPIRIT_PASS_QUEUE_H diff --git a/rsov/compiler/spirit/pass_queue_test.cpp b/rsov/compiler/spirit/pass_queue_test.cpp new file mode 100644 index 00000000..0e7c183f --- /dev/null +++ b/rsov/compiler/spirit/pass_queue_test.cpp @@ -0,0 +1,308 @@ +/* + * Copyright 2017, The Android Open Source Project + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pass_queue.h" + +#include "file_utils.h" +#include "spirit.h" +#include "test_utils.h" +#include "transformer.h" +#include "word_stream.h" +#include "gtest/gtest.h" + +#include <stdint.h> + +namespace android { +namespace spirit { + +namespace { + +class MulToAddTransformer : public Transformer { +public: + Instruction *transform(IMulInst *mul) override { + auto ret = new IAddInst(mul->mResultType, mul->mOperand1, mul->mOperand2); + ret->setId(mul->getId()); + return ret; + } +}; + +class AddToDivTransformer : public Transformer { +public: + Instruction *transform(IAddInst *add) override { + auto ret = new SDivInst(add->mResultType, add->mOperand1, add->mOperand2); + ret->setId(add->getId()); + return ret; + } +}; + +class AddMulAfterAddTransformer : public Transformer { +public: + Instruction *transform(IAddInst *add) override { + insert(add); + auto ret = new IMulInst(add->mResultType, add, add); + ret->setId(add->getId()); + return ret; + } +}; + +class Deleter : public Transformer { +public: + Instruction *transform(IMulInst *) override { return nullptr; } +}; + +class InPlaceModifyingPass : public Pass { +public: + Module *run(Module *m, int *error) override { + m->getFloatType(64); + if (error) { + *error = 0; + } + return m; + } +}; + +} // annonymous namespace + +class PassQueueTest : public ::testing::Test { +protected: + virtual void SetUp() { mWordsGreyscale = readWords("greyscale.spv"); } + + std::vector<uint32_t> mWordsGreyscale; + +private: + std::vector<uint32_t> readWords(const char *testFile) { + static const std::string testDataPath( + "frameworks/rs/rsov/compiler/spirit/test_data/"); + const std::string &fullPath = getAbsolutePath(testDataPath + testFile); + return readFile<uint32_t>(fullPath); + } +}; + +TEST_F(PassQueueTest, testMulToAdd) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + PassQueue passes; + passes.append(new MulToAddTransformer()); + auto m1 = passes.run(m.get()); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(2, countEntity<IAddInst>(m1)); + EXPECT_EQ(0, countEntity<IMulInst>(m1)); +} + +TEST_F(PassQueueTest, testInPlaceModifying) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); + + PassQueue passes; + passes.append(new InPlaceModifyingPass()); + auto m1 = passes.run(m.get()); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(1, countEntity<IAddInst>(m1)); + EXPECT_EQ(1, countEntity<IMulInst>(m1)); + EXPECT_EQ(2, countEntity<TypeFloatInst>(m1)); +} + +TEST_F(PassQueueTest, testDeletion) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m.get()); + + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + PassQueue passes; + passes.append(new Deleter()); + auto m1 = passes.run(m.get()); + + // One of the ids from the input module is missing now. + ASSERT_EQ(nullptr, m1); +} + +TEST_F(PassQueueTest, testMulToAddToDiv) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + PassQueue passes; + passes.append(new MulToAddTransformer()); + passes.append(new AddToDivTransformer()); + auto m1 = passes.run(m.get()); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(0, countEntity<IAddInst>(m1)); + EXPECT_EQ(0, countEntity<IMulInst>(m1)); + EXPECT_EQ(2, countEntity<SDivInst>(m1)); +} + +TEST_F(PassQueueTest, testAMix) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + EXPECT_EQ(0, countEntity<SDivInst>(m.get())); + EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); + + PassQueue passes; + passes.append(new MulToAddTransformer()); + passes.append(new AddToDivTransformer()); + passes.append(new InPlaceModifyingPass()); + + std::unique_ptr<Module> m1(passes.run(m.get())); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); + EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); + EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get())); +} + +TEST_F(PassQueueTest, testAnotherMix) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + EXPECT_EQ(0, countEntity<SDivInst>(m.get())); + EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); + + PassQueue passes; + passes.append(new InPlaceModifyingPass()); + passes.append(new MulToAddTransformer()); + passes.append(new AddToDivTransformer()); + auto outputWords = passes.runAndSerialize(m.get()); + + std::unique_ptr<InputWordStream> IS1(InputWordStream::Create(outputWords)); + std::unique_ptr<Module> m1(Deserialize<Module>(*IS1)); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); + EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); + EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get())); +} + +TEST_F(PassQueueTest, testMulToAddToDivFromWords) { + PassQueue passes; + passes.append(new MulToAddTransformer()); + passes.append(new AddToDivTransformer()); + auto outputWords = passes.run(std::move(mWordsGreyscale)); + + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(outputWords)); + std::unique_ptr<Module> m1(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); + EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); +} + +TEST_F(PassQueueTest, testMulToAddToDivToWords) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + PassQueue passes; + passes.append(new MulToAddTransformer()); + passes.append(new AddToDivTransformer()); + auto outputWords = passes.runAndSerialize(m.get()); + + std::unique_ptr<InputWordStream> IS1(InputWordStream::Create(outputWords)); + std::unique_ptr<Module> m1(Deserialize<Module>(*IS1)); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(0, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); + EXPECT_EQ(2, countEntity<SDivInst>(m1.get())); +} + +TEST_F(PassQueueTest, testAddMulAfterAdd) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m); + + EXPECT_EQ(1, countEntity<IAddInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + constexpr int kNumMulToAdd = 100; + + PassQueue passes; + for (int i = 0; i < kNumMulToAdd; i++) { + passes.append(new AddMulAfterAddTransformer()); + } + auto outputWords = passes.runAndSerialize(m.get()); + + std::unique_ptr<InputWordStream> IS1(InputWordStream::Create(outputWords)); + std::unique_ptr<Module> m1(Deserialize<Module>(*IS1)); + + ASSERT_NE(nullptr, m1); + + ASSERT_TRUE(m1->resolveIds()); + + EXPECT_EQ(1, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(1 + kNumMulToAdd, countEntity<IMulInst>(m1.get())); +} + +} // namespace spirit +} // namespace android diff --git a/rsov/compiler/spirit/transformer.cpp b/rsov/compiler/spirit/transformer.cpp index 1a3620c4..4d2ab1f4 100644 --- a/rsov/compiler/spirit/transformer.cpp +++ b/rsov/compiler/spirit/transformer.cpp @@ -21,19 +21,46 @@ namespace android { namespace spirit { -Module *Transformer::applyTo(Module *m) { - // TODO fix Module::accept() to have the header serialization code there - m->SerializeHeader(*mStream); - m->accept(this); - std::unique_ptr<InputWordStream> IS( - InputWordStream::Create(mStream->getWords())); +Module *Transformer::run(Module *module, int *error) { + auto words = runAndSerialize(module, error); + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); return Deserialize<Module>(*IS); } -std::vector<uint32_t> Transformer::transformSerialize(Module *m) { +std::vector<uint32_t> Transformer::runAndSerialize(Module *m, int *error) { + mModule = m; + + // Since contents in the decoration or global section may change, transform + // and serialize the function definitions first. + mVisit = 0; + mShouldRecord = false; + mStream = mStreamFunctions.get(); + m->accept(this); + + // Record in the annotation section any new annotations added + m->consolidateAnnotations(); + + // After the functions are transformed, serialize the other sections to + // capture any changes made during the function transformation, and append + // the new words from function serialization. + + mVisit = 1; + mShouldRecord = true; + mStream = mStreamFinal.get(); + + // TODO fix Module::accept() to have the header serialization code there m->SerializeHeader(*mStream); m->accept(this); - return mStream->getWords(); + + auto output = mStream->getWords(); + auto functions = mStreamFunctions->getWords(); + output.insert(output.end(), functions.begin(), functions.end()); + + if (error) { + *error = 0; + } + + return output; } void Transformer::insert(Instruction *inst) { diff --git a/rsov/compiler/spirit/transformer.h b/rsov/compiler/spirit/transformer.h index be990c0e..e2293f75 100644 --- a/rsov/compiler/spirit/transformer.h +++ b/rsov/compiler/spirit/transformer.h @@ -20,29 +20,52 @@ #include <vector> #include "instructions.h" +#include "pass.h" #include "visitor.h" #include "word_stream.h" namespace android { namespace spirit { -class Transformer : public DoNothingVisitor { +// Transformer is the base class for a transformation that transforms a Module. +// An instance of a derived class can be added to a PassQueue and applied to a +// Module. +class Transformer : public Pass, public DoNothingVisitor { public: - Transformer() : mStream(WordStream::Create()) {} + Transformer() + : mStreamFunctions(WordStream::Create()), + mStreamFinal(WordStream::Create()) {} virtual ~Transformer() {} - Module *applyTo(Module *m); - std::vector<uint32_t> transformSerialize(Module *m); + Module *run(Module *m, int *error = nullptr) override; - // Inserts a new instruction before the current instruction + std::vector<uint32_t> runAndSerialize(Module *module, + int *error = nullptr) override; + + // Returns the module being transformed + Module *getModule() const { return mModule; } + + // Inserts a new instruction before the current instruction. + // Call this from a transform() method in a derived class. void insert(Instruction *); + void visit(FunctionDefinition *fdef) override { + mShouldRecord = (mVisit == 0); + DoNothingVisitor::visit(fdef); + } + + // Transforms the current instruction into a new instruction as specified by + // the return value. If returns nullptr, deletes the current instruction. + // Override this in a derived class for desired behavior. #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \ virtual Instruction *transform(INST_CLASS *inst) { \ return static_cast<Instruction *>(inst); \ } \ virtual void visit(INST_CLASS *inst) { \ + if (!mShouldRecord) { \ + return; \ + } \ if (Instruction *transformed = transform(inst)) { \ transformed->Serialize(*mStream); \ } \ @@ -51,7 +74,12 @@ public: #undef HANDLE_INSTRUCTION private: - std::unique_ptr<WordStream> mStream; + Module *mModule; + int mVisit; + bool mShouldRecord; + std::unique_ptr<WordStream> mStreamFunctions; + std::unique_ptr<WordStream> mStreamFinal; + WordStream *mStream; }; } // namespace spirit diff --git a/rsov/compiler/spirit/transformer_test.cpp b/rsov/compiler/spirit/transformer_test.cpp index 8fb422c5..e5f39a60 100644 --- a/rsov/compiler/spirit/transformer_test.cpp +++ b/rsov/compiler/spirit/transformer_test.cpp @@ -16,14 +16,14 @@ #include "transformer.h" -#include <stdint.h> - #include "file_utils.h" #include "spirit.h" #include "test_utils.h" #include "word_stream.h" #include "gtest/gtest.h" +#include <stdint.h> + namespace android { namespace spirit { @@ -43,6 +43,22 @@ public: Instruction *transform(IMulInst *) override { return nullptr; } }; +class NewDataTypeTransformer : public Transformer { +public: + Instruction *transform(IMulInst *mul) override { + insert(mul); + auto *DoubleTy = getModule()->getFloatType(64); + ConstantInst *ConstDouble2 = getModule()->getConstant(DoubleTy, 2.0); + auto ret = new IAddInst(DoubleTy, mul, ConstDouble2); + + IdResult id = ret->getId(); + ret->setId(mul->getId()); + mul->setId(id); + + return ret; + } +}; + } // annonymous namespace class TransformerTest : public ::testing::Test { @@ -70,7 +86,7 @@ TEST_F(TransformerTest, testMulToAdd) { EXPECT_EQ(1, countEntity<IMulInst>(m.get())); MulToAddTransformer trans; - std::unique_ptr<Module> m1(trans.applyTo(m.get())); + std::unique_ptr<Module> m1(trans.run(m.get())); ASSERT_NE(nullptr, m1); @@ -89,7 +105,7 @@ TEST_F(TransformerTest, testDeletion) { EXPECT_EQ(1, countEntity<IMulInst>(m.get())); Deleter trans; - std::unique_ptr<Module> m1(trans.applyTo(m.get())); + std::unique_ptr<Module> m1(trans.run(m.get())); ASSERT_NE(nullptr, m1.get()); @@ -97,5 +113,26 @@ TEST_F(TransformerTest, testDeletion) { EXPECT_EQ(0, countEntity<IMulInst>(m1.get())); } +TEST_F(TransformerTest, testAddInstructionUsingNewDataType) { + std::unique_ptr<InputWordStream> IS(InputWordStream::Create(mWordsGreyscale)); + std::unique_ptr<Module> m(Deserialize<Module>(*IS)); + + ASSERT_NE(nullptr, m.get()); + + EXPECT_EQ(5, countEntity<ConstantInst>(m.get())); + EXPECT_EQ(1, countEntity<TypeFloatInst>(m.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m.get())); + + NewDataTypeTransformer trans; + std::unique_ptr<Module> m1(trans.run(m.get())); + + ASSERT_NE(nullptr, m1.get()); + + EXPECT_EQ(6, countEntity<ConstantInst>(m.get())); + EXPECT_EQ(2, countEntity<TypeFloatInst>(m1.get())); + EXPECT_EQ(2, countEntity<IAddInst>(m1.get())); + EXPECT_EQ(1, countEntity<IMulInst>(m1.get())); +} + } // namespace spirit } // namespace android diff --git a/rsov/compiler/spirit/word_stream.h b/rsov/compiler/spirit/word_stream.h index bbbf4ef6..0c740def 100644 --- a/rsov/compiler/spirit/word_stream.h +++ b/rsov/compiler/spirit/word_stream.h @@ -17,13 +17,14 @@ #ifndef WORD_STREAM_H #define WORD_STREAM_H +#include "core_defs.h" +#include "types_generated.h" + #include <stdint.h> +#include <string> #include <vector> -#include "core_defs.h" -#include "types_generated.h" - namespace android { namespace spirit { @@ -48,7 +49,7 @@ public: virtual InputWordStream &operator>>(uint32_t *RHS) = 0; virtual InputWordStream &operator>>(LiteralContextDependentNumber *num) = 0; - virtual InputWordStream &operator>>(const char **str) = 0; + virtual InputWordStream &operator>>(std::string *str) = 0; InputWordStream &operator>>(int32_t *RHS) { return *this >> (uint32_t *)RHS; } @@ -98,7 +99,7 @@ public: virtual OutputWordStream &operator<<(const uint32_t RHS) = 0; virtual OutputWordStream & operator<<(const LiteralContextDependentNumber &RHS) = 0; - virtual OutputWordStream &operator<<(const char *str) = 0; + virtual OutputWordStream &operator<<(const std::string &str) = 0; OutputWordStream &operator<<(const int32_t RHS) { return *this << (uint32_t)RHS; diff --git a/rsov/compiler/spirit/word_stream_impl.cpp b/rsov/compiler/spirit/word_stream_impl.cpp index 002e6d33..33749e7d 100644 --- a/rsov/compiler/spirit/word_stream_impl.cpp +++ b/rsov/compiler/spirit/word_stream_impl.cpp @@ -27,10 +27,10 @@ WordStreamImpl::WordStreamImpl(const std::vector<uint32_t> &words) WordStreamImpl::WordStreamImpl(std::vector<uint32_t> &&words) : mWords(words), mIter(mWords.begin()) {} -WordStreamImpl &WordStreamImpl::operator<<(const char *str) { - const size_t len = strlen(str); - const uint32_t *begin = (uint32_t *)str; - const uint32_t *end = ((uint32_t *)str) + (len / 4); +WordStreamImpl &WordStreamImpl::operator<<(const std::string &str) { + const size_t len = str.length(); + const uint32_t *begin = (uint32_t *)str.c_str(); + const uint32_t *end = begin + (len / 4); mWords.insert(mWords.end(), begin, end); uint32_t lastWord = *end; @@ -47,10 +47,11 @@ WordStreamImpl &WordStreamImpl::operator<<(const char *str) { return *this; } -WordStreamImpl &WordStreamImpl::operator>>(const char **str) { - *str = (const char *)&*mIter; - while (*mIter++ & 0xFF000000) - ; +WordStreamImpl &WordStreamImpl::operator>>(std::string *str) { + const char *s = (const char *)&*mIter; + str->assign(s); + while (*mIter++ & 0xFF000000) { + } return *this; } diff --git a/rsov/compiler/spirit/word_stream_impl.h b/rsov/compiler/spirit/word_stream_impl.h index ad624caa..365e3d50 100644 --- a/rsov/compiler/spirit/word_stream_impl.h +++ b/rsov/compiler/spirit/word_stream_impl.h @@ -19,6 +19,8 @@ #include "word_stream.h" +#include <string> + namespace android { namespace spirit { @@ -44,7 +46,7 @@ public: return *this >> (uint32_t *)(&RHS->intValue); } - WordStreamImpl &operator>>(const char **str) override; + WordStreamImpl &operator>>(std::string *str) override; std::vector<uint32_t> getWords() override { return mWords; } @@ -61,7 +63,7 @@ public: // double, etc. return *this << (uint32_t)(RHS.intValue); } - WordStreamImpl &operator<<(const char *str) override; + WordStreamImpl &operator<<(const std::string &str) override; private: std::vector<uint32_t> mWords; diff --git a/rsov/compiler/spirit/word_stream_test.cpp b/rsov/compiler/spirit/word_stream_test.cpp index b5049071..2735e15e 100644 --- a/rsov/compiler/spirit/word_stream_test.cpp +++ b/rsov/compiler/spirit/word_stream_test.cpp @@ -14,11 +14,12 @@ * limitations under the License. */ -#include <vector> - #include "word_stream.h" + #include "gtest/gtest.h" +#include <vector> + namespace android { namespace spirit { @@ -50,9 +51,9 @@ TEST(WordStreamTest, testStringInput1) { std::vector<uint32_t> words((uint32_t *)bytes, (uint32_t *)(bytes + sizeof(bytes))); std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); - const char *s; + std::string s; *IS >> &s; - EXPECT_STREQ("ABCDEFG", s); + EXPECT_STREQ("ABCDEFG", s.c_str()); } TEST(WordStreamTest, testStringInput2) { @@ -61,9 +62,9 @@ TEST(WordStreamTest, testStringInput2) { std::vector<uint32_t> words((uint32_t *)bytes, (uint32_t *)(bytes + sizeof(bytes))); std::unique_ptr<InputWordStream> IS(InputWordStream::Create(words)); - const char *s; + std::string s; *IS >> &s; - EXPECT_STREQ("GLSL.std.450", s); + EXPECT_STREQ("GLSL.std.450", s.c_str()); } } // namespace spirit |