diff options
Diffstat (limited to 'mlir/lib/Pass/Pass.cpp')
-rw-r--r-- | mlir/lib/Pass/Pass.cpp | 73 |
1 files changed, 39 insertions, 34 deletions
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 813b7a8db509..056da035a5b5 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -12,7 +12,6 @@ #include "mlir/Pass/Pass.h" #include "PassDetail.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" @@ -528,9 +527,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) { asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(), mgrs); - // Run a prepass over the module to collect the operations to execute over. - // This ensures that an analysis manager exists for each operation, as well as - // providing a queue of operations to execute over. + // Run a prepass over the operation to collect the nested operations to + // execute over. This ensures that an analysis manager exists for each + // operation, as well as providing a queue of operations to execute over. std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs; for (auto ®ion : getOperation()->getRegions()) { for (auto &block : region) { @@ -614,7 +613,7 @@ namespace { /// reproducers when a signal is raised, such as a segfault. struct RecoveryReproducerContext { RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes, - ModuleOp module, StringRef filename, + Operation *op, StringRef filename, bool disableThreads, bool verifyPasses); ~RecoveryReproducerContext(); @@ -631,8 +630,8 @@ private: /// The textual description of the currently executing pipeline. std::string pipeline; - /// The MLIR module representing the IR before the crash. - OwningModuleRef module; + /// The MLIR operation representing the IR before the crash. + Operation *preCrashOperation; /// The filename to use when generating the reproducer. StringRef filename; @@ -658,9 +657,9 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>> RecoveryReproducerContext::reproducerSet; RecoveryReproducerContext::RecoveryReproducerContext( - MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module, + MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op, StringRef filename, bool disableThreads, bool verifyPasses) - : module(module.clone()), filename(filename), + : preCrashOperation(op->clone()), filename(filename), disableThreads(disableThreads), verifyPasses(verifyPasses) { // Grab the textual pipeline being executed.. { @@ -677,6 +676,9 @@ RecoveryReproducerContext::RecoveryReproducerContext( } RecoveryReproducerContext::~RecoveryReproducerContext() { + // Erase the cloned preCrash IR that we cached. + preCrashOperation->erase(); + llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex); reproducerSet->remove(this); if (reproducerSet->empty()) @@ -700,7 +702,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) { << "\n"; // Output the .mlir module. - module->print(outputOS); + preCrashOperation->print(outputOS); outputFile->keep(); return success(); } @@ -722,11 +724,11 @@ void RecoveryReproducerContext::registerSignalHandler() { } /// Run the pass manager with crash recover enabled. -LogicalResult PassManager::runWithCrashRecovery(ModuleOp module, +LogicalResult PassManager::runWithCrashRecovery(Operation *op, AnalysisManager am) { // If this isn't a local producer, run all of the passes in recovery mode. if (!localReproducer) - return runWithCrashRecovery(impl->passes, module, am); + return runWithCrashRecovery(impl->passes, op, am); // Split the passes within adaptors to ensure that each pass can be run in // isolation. @@ -735,7 +737,7 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module, // If this is a local producer, run each of the passes individually. MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes; for (std::unique_ptr<Pass> &pass : passes) - if (failed(runWithCrashRecovery(pass, module, am))) + if (failed(runWithCrashRecovery(pass, op, am))) return failure(); return success(); } @@ -743,8 +745,8 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module, /// Run the given passes with crash recover enabled. LogicalResult PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, - ModuleOp module, AnalysisManager am) { - RecoveryReproducerContext context(passes, module, *crashReproducerFileName, + Operation *op, AnalysisManager am) { + RecoveryReproducerContext context(passes, op, *crashReproducerFileName, !getContext()->isMultithreadingEnabled(), verifyPasses); @@ -753,7 +755,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, llvm::CrashRecoveryContext recoveryContext; recoveryContext.RunSafelyOnThread([&] { for (std::unique_ptr<Pass> &pass : passes) - if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses))) + if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses))) return; passManagerResult = success(); }); @@ -762,8 +764,8 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, std::string error; if (failed(context.generate(error))) - return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error; - return module.emitError() + return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error; + return op->emitError() << "A failure has been detected while processing the MLIR module, a " "reproducer has been generated in '" << *crashReproducerFileName << "'"; @@ -773,18 +775,21 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes, // PassManager //===----------------------------------------------------------------------===// -PassManager::PassManager(MLIRContext *ctx, Nesting nesting) - : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), - nesting), - context(ctx), passTiming(false), localReproducer(false), - verifyPasses(true) {} +PassManager::PassManager(MLIRContext *ctx, Nesting nesting, + StringRef operationName) + : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx), + passTiming(false), localReproducer(false), verifyPasses(true) {} PassManager::~PassManager() {} void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; } -/// Run the passes within this manager on the provided module. -LogicalResult PassManager::run(ModuleOp module) { +/// Run the passes within this manager on the provided operation. +LogicalResult PassManager::run(Operation *op) { + MLIRContext *context = getContext(); + assert(op->getName().getIdentifier() == getOpName(*context) && + "operation has a different name than the PassManager"); + // Before running, make sure to coalesce any adjacent pass adaptors in the // pipeline. getImpl().coalesceAdjacentAdaptorPasses(); @@ -792,23 +797,23 @@ LogicalResult PassManager::run(ModuleOp module) { // Register all dialects for the current pipeline. DialectRegistry dependentDialects; getDependentDialects(dependentDialects); - dependentDialects.loadAll(module.getContext()); + dependentDialects.loadAll(context); - // Construct an analysis manager for the pipeline. - ModuleAnalysisManager am(module, instrumentor.get()); + // Construct a top level analysis manager for the pipeline. + ModuleAnalysisManager am(op, instrumentor.get()); // Notify the context that we start running a pipeline for book keeping. - module.getContext()->enterMultiThreadedExecution(); + context->enterMultiThreadedExecution(); // If reproducer generation is enabled, run the pass manager with crash // handling enabled. - LogicalResult result = crashReproducerFileName - ? runWithCrashRecovery(module, am) - : OpToOpPassAdaptor::runPipeline( - getPasses(), module, am, verifyPasses); + LogicalResult result = + crashReproducerFileName + ? runWithCrashRecovery(op, am) + : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses); // Notify the context that the run is done. - module.getContext()->exitMultiThreadedExecution(); + context->exitMultiThreadedExecution(); // Dump all of the pass statistics if necessary. if (passStatisticsMode) |