aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Pass/AnalysisManager.h16
-rw-r--r--mlir/include/mlir/Pass/PassManager.h19
-rw-r--r--mlir/lib/Pass/IRPrinting.cpp19
-rw-r--r--mlir/lib/Pass/Pass.cpp73
-rw-r--r--mlir/lib/Pass/PassManagerOptions.cpp2
5 files changed, 65 insertions, 64 deletions
diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
index de428da5abd4..ec6b7696ce60 100644
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -9,7 +9,7 @@
#ifndef MLIR_PASS_ANALYSISMANAGER_H
#define MLIR_PASS_ANALYSISMANAGER_H
-#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Operation.h"
#include "mlir/Pass/PassInstrumentation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseMap.h"
@@ -177,8 +177,8 @@ private:
bool wasInserted;
std::tie(it, wasInserted) = analyses.try_emplace(id);
- // If we don't have a cached analysis for this function, compute it directly
- // and add it to the cache.
+ // If we don't have a cached analysis for this operation, compute it
+ // directly and add it to the cache.
if (wasInserted) {
if (pi)
pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
@@ -321,14 +321,14 @@ private:
friend class ModuleAnalysisManager;
};
-/// An analysis manager class specifically for the top-level module operation.
-/// This class contains the memory allocations for all nested analysis managers,
-/// and provides an anchor point. This is necessary because AnalysisManager is
+/// An analysis manager class specifically for the top-level operation. This
+/// class contains the memory allocations for all nested analysis managers, and
+/// provides an anchor point. This is necessary because AnalysisManager is
/// designed to be a thin wrapper around an existing analysis map instance.
class ModuleAnalysisManager {
public:
- ModuleAnalysisManager(ModuleOp module, PassInstrumentor *passInstrumentor)
- : analyses(module), passInstrumentor(passInstrumentor) {}
+ ModuleAnalysisManager(Operation *op, PassInstrumentor *passInstrumentor)
+ : analyses(op), passInstrumentor(passInstrumentor) {}
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index eb21359d6211..5e9c9a790d29 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -28,7 +28,6 @@ namespace mlir {
class AnalysisManager;
class Identifier;
class MLIRContext;
-class ModuleOp;
class Operation;
class Pass;
class PassInstrumentation;
@@ -158,12 +157,20 @@ enum class PassDisplayMode {
/// The main pass manager and pipeline builder.
class PassManager : public OpPassManager {
public:
- PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit);
+ /// Create a new pass manager under the given context with a specific nesting
+ /// style. The created pass manager can schedule operations that match
+ /// `operationName`.
+ PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit,
+ StringRef operationName = "module");
+ PassManager(MLIRContext *ctx, StringRef operationName)
+ : PassManager(ctx, Nesting::Explicit, operationName) {}
~PassManager();
- /// Run the passes within this manager on the provided module.
+ /// Run the passes within this manager on the provided operation. The
+ /// specified operation must have the same name as the one provided the pass
+ /// manager on construction.
LLVM_NODISCARD
- LogicalResult run(ModuleOp module);
+ LogicalResult run(Operation *op);
/// Return an instance of the context.
MLIRContext *getContext() const { return context; }
@@ -318,11 +325,11 @@ private:
void dumpStatistics();
/// Run the pass manager with crash recover enabled.
- LogicalResult runWithCrashRecovery(ModuleOp module, AnalysisManager am);
+ LogicalResult runWithCrashRecovery(Operation *op, AnalysisManager am);
/// Run the given passes with crash recover enabled.
LogicalResult
runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
- ModuleOp module, AnalysisManager am);
+ Operation *op, AnalysisManager am);
/// Context this PassManager was initialized with.
MLIRContext *context;
diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp
index 2f6c3a2a5af4..b27b39dd322d 100644
--- a/mlir/lib/Pass/IRPrinting.cpp
+++ b/mlir/lib/Pass/IRPrinting.cpp
@@ -7,7 +7,6 @@
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
@@ -97,14 +96,10 @@ private:
static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
OpPrintingFlags flags) {
- // Check to see if we are printing the top-level module.
- auto module = dyn_cast<ModuleOp>(op);
- if (module && !op->getBlock())
- return module.print(out << "\n", flags);
-
// Otherwise, check to see if we are not printing at module scope.
if (!printModuleScope)
- return op->print(out << "\n", flags.useLocalScope());
+ return op->print(out << "\n",
+ op->getBlock() ? flags.useLocalScope() : flags);
// Otherwise, we are printing at module scope.
out << " ('" << op->getName() << "' operation";
@@ -113,17 +108,11 @@ static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
out << ": @" << symbolName.getValue();
out << ")\n";
- // Find the top-level module operation.
+ // Find the top-level operation.
auto *topLevelOp = op;
while (auto *parentOp = topLevelOp->getParentOp())
topLevelOp = parentOp;
-
- // Check to see if the top-level operation is actually a module in the case of
- // invalid-ir.
- if (auto module = dyn_cast<ModuleOp>(topLevelOp))
- module.print(out, flags);
- else
- topLevelOp->print(out, flags);
+ topLevelOp->print(out, flags);
}
/// Instrumentation hooks.
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 813b7a8db509..056da035a5b5 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -12,7 +12,6 @@
#include "mlir/Pass/Pass.h"
#include "PassDetail.h"
-#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Verifier.h"
@@ -528,9 +527,9 @@ void OpToOpPassAdaptor::runOnOperationAsyncImpl(bool verifyPasses) {
asyncExecutors.assign(llvm::hardware_concurrency().compute_thread_count(),
mgrs);
- // Run a prepass over the module to collect the operations to execute over.
- // This ensures that an analysis manager exists for each operation, as well as
- // providing a queue of operations to execute over.
+ // Run a prepass over the operation to collect the nested operations to
+ // execute over. This ensures that an analysis manager exists for each
+ // operation, as well as providing a queue of operations to execute over.
std::vector<std::pair<Operation *, AnalysisManager>> opAMPairs;
for (auto &region : getOperation()->getRegions()) {
for (auto &block : region) {
@@ -614,7 +613,7 @@ namespace {
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
RecoveryReproducerContext(MutableArrayRef<std::unique_ptr<Pass>> passes,
- ModuleOp module, StringRef filename,
+ Operation *op, StringRef filename,
bool disableThreads, bool verifyPasses);
~RecoveryReproducerContext();
@@ -631,8 +630,8 @@ private:
/// The textual description of the currently executing pipeline.
std::string pipeline;
- /// The MLIR module representing the IR before the crash.
- OwningModuleRef module;
+ /// The MLIR operation representing the IR before the crash.
+ Operation *preCrashOperation;
/// The filename to use when generating the reproducer.
StringRef filename;
@@ -658,9 +657,9 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
RecoveryReproducerContext::reproducerSet;
RecoveryReproducerContext::RecoveryReproducerContext(
- MutableArrayRef<std::unique_ptr<Pass>> passes, ModuleOp module,
+ MutableArrayRef<std::unique_ptr<Pass>> passes, Operation *op,
StringRef filename, bool disableThreads, bool verifyPasses)
- : module(module.clone()), filename(filename),
+ : preCrashOperation(op->clone()), filename(filename),
disableThreads(disableThreads), verifyPasses(verifyPasses) {
// Grab the textual pipeline being executed..
{
@@ -677,6 +676,9 @@ RecoveryReproducerContext::RecoveryReproducerContext(
}
RecoveryReproducerContext::~RecoveryReproducerContext() {
+ // Erase the cloned preCrash IR that we cached.
+ preCrashOperation->erase();
+
llvm::sys::SmartScopedLock<true> producerLock(*reproducerMutex);
reproducerSet->remove(this);
if (reproducerSet->empty())
@@ -700,7 +702,7 @@ LogicalResult RecoveryReproducerContext::generate(std::string &error) {
<< "\n";
// Output the .mlir module.
- module->print(outputOS);
+ preCrashOperation->print(outputOS);
outputFile->keep();
return success();
}
@@ -722,11 +724,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
}
/// Run the pass manager with crash recover enabled.
-LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
+LogicalResult PassManager::runWithCrashRecovery(Operation *op,
AnalysisManager am) {
// If this isn't a local producer, run all of the passes in recovery mode.
if (!localReproducer)
- return runWithCrashRecovery(impl->passes, module, am);
+ return runWithCrashRecovery(impl->passes, op, am);
// Split the passes within adaptors to ensure that each pass can be run in
// isolation.
@@ -735,7 +737,7 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
// If this is a local producer, run each of the passes individually.
MutableArrayRef<std::unique_ptr<Pass>> passes = impl->passes;
for (std::unique_ptr<Pass> &pass : passes)
- if (failed(runWithCrashRecovery(pass, module, am)))
+ if (failed(runWithCrashRecovery(pass, op, am)))
return failure();
return success();
}
@@ -743,8 +745,8 @@ LogicalResult PassManager::runWithCrashRecovery(ModuleOp module,
/// Run the given passes with crash recover enabled.
LogicalResult
PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
- ModuleOp module, AnalysisManager am) {
- RecoveryReproducerContext context(passes, module, *crashReproducerFileName,
+ Operation *op, AnalysisManager am) {
+ RecoveryReproducerContext context(passes, op, *crashReproducerFileName,
!getContext()->isMultithreadingEnabled(),
verifyPasses);
@@ -753,7 +755,7 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
llvm::CrashRecoveryContext recoveryContext;
recoveryContext.RunSafelyOnThread([&] {
for (std::unique_ptr<Pass> &pass : passes)
- if (failed(OpToOpPassAdaptor::run(pass.get(), module, am, verifyPasses)))
+ if (failed(OpToOpPassAdaptor::run(pass.get(), op, am, verifyPasses)))
return;
passManagerResult = success();
});
@@ -762,8 +764,8 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
std::string error;
if (failed(context.generate(error)))
- return module.emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
- return module.emitError()
+ return op->emitError("<MLIR-PassManager-Crash-Reproducer>: ") << error;
+ return op->emitError()
<< "A failure has been detected while processing the MLIR module, a "
"reproducer has been generated in '"
<< *crashReproducerFileName << "'";
@@ -773,18 +775,21 @@ PassManager::runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
// PassManager
//===----------------------------------------------------------------------===//
-PassManager::PassManager(MLIRContext *ctx, Nesting nesting)
- : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
- nesting),
- context(ctx), passTiming(false), localReproducer(false),
- verifyPasses(true) {}
+PassManager::PassManager(MLIRContext *ctx, Nesting nesting,
+ StringRef operationName)
+ : OpPassManager(Identifier::get(operationName, ctx), nesting), context(ctx),
+ passTiming(false), localReproducer(false), verifyPasses(true) {}
PassManager::~PassManager() {}
void PassManager::enableVerifier(bool enabled) { verifyPasses = enabled; }
-/// Run the passes within this manager on the provided module.
-LogicalResult PassManager::run(ModuleOp module) {
+/// Run the passes within this manager on the provided operation.
+LogicalResult PassManager::run(Operation *op) {
+ MLIRContext *context = getContext();
+ assert(op->getName().getIdentifier() == getOpName(*context) &&
+ "operation has a different name than the PassManager");
+
// Before running, make sure to coalesce any adjacent pass adaptors in the
// pipeline.
getImpl().coalesceAdjacentAdaptorPasses();
@@ -792,23 +797,23 @@ LogicalResult PassManager::run(ModuleOp module) {
// Register all dialects for the current pipeline.
DialectRegistry dependentDialects;
getDependentDialects(dependentDialects);
- dependentDialects.loadAll(module.getContext());
+ dependentDialects.loadAll(context);
- // Construct an analysis manager for the pipeline.
- ModuleAnalysisManager am(module, instrumentor.get());
+ // Construct a top level analysis manager for the pipeline.
+ ModuleAnalysisManager am(op, instrumentor.get());
// Notify the context that we start running a pipeline for book keeping.
- module.getContext()->enterMultiThreadedExecution();
+ context->enterMultiThreadedExecution();
// If reproducer generation is enabled, run the pass manager with crash
// handling enabled.
- LogicalResult result = crashReproducerFileName
- ? runWithCrashRecovery(module, am)
- : OpToOpPassAdaptor::runPipeline(
- getPasses(), module, am, verifyPasses);
+ LogicalResult result =
+ crashReproducerFileName
+ ? runWithCrashRecovery(op, am)
+ : OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses);
// Notify the context that the run is done.
- module.getContext()->exitMultiThreadedExecution();
+ context->exitMultiThreadedExecution();
// Dump all of the pass statistics if necessary.
if (passStatisticsMode)
diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp
index b00f992eceb9..a581ce070fc4 100644
--- a/mlir/lib/Pass/PassManagerOptions.cpp
+++ b/mlir/lib/Pass/PassManagerOptions.cpp
@@ -50,7 +50,7 @@ struct PassManagerOptions {
llvm::cl::opt<bool> printModuleScope{
"print-ir-module-scope",
llvm::cl::desc("When printing IR for print-ir-[before|after]{-all} "
- "always print the top-level module operation"),
+ "always print the top-level operation"),
llvm::cl::init(false)};
/// Add an IR printing instrumentation if enabled by any 'print-ir' flags.