aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Pass/Pass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Pass/Pass.cpp')
-rw-r--r--mlir/lib/Pass/Pass.cpp73
1 files changed, 39 insertions, 34 deletions
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 813b7a8db509..056da035a5b5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -12,7 +12,6 @@
#include "mlir/Pass/Pass.h"
#include "PassDetail.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
@@ -528,9 +527,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
mgrs);
- // Run a prepass over the module to collect the operations to execute over.
- // This ensures that an analysis manager exists for each operation, as well as
- // providing a queue of operations to execute over.
+ // Run a prepass over the operation to collect the nested operations to
+ // execute over. This ensures that an analysis manager exists for each
+ // operation, as well as providing a queue of operations to execute over.
std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
@@ -614,7 +613,7 @@ namespace {
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
- ModuleOp module, StringRef filename,
+ Operation *op, StringRef filename,
bool disableThreads, bool verifyPasses);
~RecoveryReproducerContext();
@@ -631,8 +630,8 @@ private:
/// The textual description of the currently executing pipeline.
std::string pipeline;
- /// The MLIR module representing the IR before the crash.
- OwningModuleRef module;
+ /// The MLIR operation representing the IR before the crash.
+ Operation *preCrashOperation;
/// The filename to use when generating the reproducer.
StringRef filename;
@@ -658,9 +657,9 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
RecoveryReproducerContext::reproducerSet;
RecoveryReproducerContext::RecoveryReproducerContext(
- MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module,
+ MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
StringRef filename, bool disableThreads, bool verifyPasses)
- : module(module.clone()), filename(filename),
+ : preCrashOperation(op->clone()), filename(filename),
disableThreads(disableThreads), verifyPasses(verifyPasses) {
// Grab the textual pipeline being executed..
{
@@ -677,6 +676,9 @@ RecoveryReproducerContext::RecoveryReproducerContext(
}
RecoveryReproducerContext::~RecoveryReproducerContext() {
+ // Erase the cloned preCrash IR that we cached.
+ preCrashOperation->erase();
+
llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
reproducerSet->remove(this);
if (reproducerSet->empty())
@@ -700,7 +702,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) {
<< "\n";
// Output the .mlir module.
- module->print(outputOS);
+ preCrashOperation->print(outputOS);
outputFile->keep();
return success();
}
@@ -722,11 +724,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
}
/// Run the pass manager with crash recover enabled.
-LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
+LogicalResult PassManager::runWithCrashRecovery(Operation *op,
AnalysisManager am) {
// If this isn't a local producer, run all of the passes in recovery mode.
if (!localReproducer)
- return runWithCrashRecovery(impl->passes, module, am);
+ return runWithCrashRecovery(impl->passes, op, am);
// Split the passes within adaptors to ensure that each pass can be run in
// isolation.
@@ -735,7 +737,7 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
// If this is a local producer, run each of the passes individually.
MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
for (std::unique_ptr<Pass> &pass : passes)
- if (failed(runWithCrashRecovery(pass, module, am)))
+ if (failed(runWithCrashRecovery(pass, op, am)))
return failure();
return success();
}
@@ -743,8 +745,8 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
/// Run the given passes with crash recover enabled.
LogicalResult
PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
- ModuleOp module, AnalysisManager am) {
- RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
+ Operation *op, AnalysisManager am) {
+ RecoveryReproducerContext context(passes, op, *crashReproducerFileName,
!getContext()->isMultithreadingEnabled(),
verifyPasses);
@@ -753,7 +755,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
llvm::CrashRecoveryContext recoveryContext;
recoveryContext.RunSafelyOnThread([&] {
for (std::unique_ptr<Pass> &pass : passes)
- if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses)))
+ if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses)))
return;
passManagerResult = success();
});
@@ -762,8 +764,8 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
std::string error;
if (failed(context.generate(error)))
- return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
- return module.emitError()
+ return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
+ return op->emitError()
<< "A failure has been detected while processing the MLIR module, a "
"reproducer has been generated in '"
<< *crashReproducerFileName << "'";
@@ -773,18 +775,21 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
// PassManager
//===----------------------------------------------------------------------===//
-PassManager::PassManager(MLIRContext *ctx, Nesting nesting)
- : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
- nesting),
- context(ctx), passTiming(false), localReproducer(false),
- verifyPasses(true) {}
+PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
+ StringRef operationName)
+ : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+ passTiming(false), localReproducer(false), verifyPasses(true) {}
PassManager::~PassManager() {}
void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
-/// Run the passes within this manager on the provided module.
-LogicalResult PassManager::run(ModuleOp module) {
+/// Run the passes within this manager on the provided operation.
+LogicalResult PassManager::run(Operation *op) {
+ MLIRContext *context = getContext();
+ assert(op->getName().getIdentifier() == getOpName(*context) &&
+ "operation has a different name than the PassManager");
+
// Before running, make sure to coalesce any adjacent pass adaptors in the
// pipeline.
getImpl().coalesceAdjacentAdaptorPasses();
@@ -792,23 +797,23 @@ LogicalResult PassManager::run(ModuleOp module) {
// Register all dialects for the current pipeline.
DialectRegistry dependentDialects;
getDependentDialects(dependentDialects);
- dependentDialects.loadAll(module.getContext());
+ dependentDialects.loadAll(context);
- // Construct an analysis manager for the pipeline.
- ModuleAnalysisManager am(module, instrumentor.get());
+ // Construct a top level analysis manager for the pipeline.
+ ModuleAnalysisManager am(op, instrumentor.get());
// Notify the context that we start running a pipeline for book keeping.
- module.getContext()->enterMultiThreadedExecution();
+ context->enterMultiThreadedExecution();
// If reproducer generation is enabled, run the pass manager with crash
// handling enabled.
- LogicalResult result = crashReproducerFileName
- ? runWithCrashRecovery(module, am)
- : OpToOpPassAdaptor::runPipeline(
- getPasses(), module, am, verifyPasses);
+ LogicalResult result =
+ crashReproducerFileName
+ ? runWithCrashRecovery(op, am)
+ : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses);
// Notify the context that the run is done.
- module.getContext()->exitMultiThreadedExecution();
+ context->exitMultiThreadedExecution();
// Dump all of the pass statistics if necessary.
if (passStatisticsMode)