aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
authorRahul Joshi <jurahul@google.com>2020-12-03 16:18:11 -0800
committerRahul Joshi <jurahul@google.com>2020-12-04 09:05:53 -0800
commit245233423e466979e11b39cbed676903892d07f8 (patch)
tree5f534c4512a3ad23661111f99c968d0b2479f882 /mlir
parent7f6f9f4cf966c78a315d15d6e913c43cfa45c47c (diff)
downloadllvm-project-245233423e466979e11b39cbed676903892d07f8.tar.gz
[MLIR] Generate inferReturnTypes declaration using InferTypeOpInterface trait.
- Instead of hardcoding the parameters and return types of 'inferReturnTypes', use the InferTypeOpInterface trait to generate the method declaration. - Fix InferTypeOfInterface to use fully qualified type for inferReturnTypes results. Differential Revision: https://reviews.llvm.org/D92585
Diffstat (limited to 'mlir')
-rw-r--r--mlir/include/mlir/Interfaces/InferTypeOpInterface.td2
-rw-r--r--mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp65
2 files changed, 40 insertions, 27 deletions
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index ed62e9015bde..9de087b1b4ca 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -36,7 +36,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
which an Operation would be created (e.g., as used in Operation::create)
and the regions of the op.
}],
- /*retTy=*/"LogicalResult",
+ /*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"inferReturnTypes",
/*args=*/(ins "::mlir::MLIRContext *":$context,
"::llvm::Optional<::mlir::Location>":$location,
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index c96fde648eb2..ccfb13fa3436 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -290,11 +290,16 @@ private:
// Generates the traits used by the object.
void genTraits();
- // Generate the OpInterface methods.
+ // Generate the OpInterface methods for all interfaces.
void genOpInterfaceMethods();
- // Generate op interface method.
- void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait);
+ // Generate op interface methods for the given interface.
+ void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait);
+
+ // Generate op interface method for the given interface method. If
+ // 'declaration' is true, generates a declaration, else a definition.
+ OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method,
+ bool declaration = true);
// Generate the side effect interface methods.
void genSideEffectInterfaceMethods();
@@ -1588,7 +1593,7 @@ void OpEmitter::genFolderDecls() {
}
}
-void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
+void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) {
auto interface = opTrait->getOpInterface();
// Get the set of methods that should always be declared.
@@ -1606,23 +1611,29 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) {
if (method.getDefaultImplementation() &&
!alwaysDeclaredMethods.count(method.getName()))
continue;
-
- SmallVector<OpMethodParameter, 4> paramList;
- for (const InterfaceMethod::Argument &arg : method.getArguments())
- paramList.emplace_back(arg.type, arg.name);
-
- auto properties = method.isStatic() ? OpMethod::MP_StaticDeclaration
- : OpMethod::MP_Declaration;
- opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
- properties, std::move(paramList));
+ genOpInterfaceMethod(method);
}
}
+OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method,
+ bool declaration) {
+ SmallVector<OpMethodParameter, 4> paramList;
+ for (const InterfaceMethod::Argument &arg : method.getArguments())
+ paramList.emplace_back(arg.type, arg.name);
+
+ auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None;
+ if (declaration)
+ properties =
+ static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration);
+ return opClass.addMethodAndPrune(method.getReturnType(), method.getName(),
+ properties, std::move(paramList));
+}
+
void OpEmitter::genOpInterfaceMethods() {
for (const auto &trait : op.getTraits()) {
if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait))
if (opTrait->shouldDeclareMethods())
- genOpInterfaceMethod(opTrait);
+ genOpInterfaceMethods(opTrait);
}
}
@@ -1727,18 +1738,20 @@ void OpEmitter::genSideEffectInterfaceMethods() {
void OpEmitter::genTypeInterfaceMethods() {
if (!op.allResultTypesKnown())
return;
-
- SmallVector<OpMethodParameter, 4> paramList;
- paramList.emplace_back("::mlir::MLIRContext *", "context");
- paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location");
- paramList.emplace_back("::mlir::ValueRange", "operands");
- paramList.emplace_back("::mlir::DictionaryAttr", "attributes");
- paramList.emplace_back("::mlir::RegionRange", "regions");
- paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&",
- "inferredReturnTypes");
- auto *method =
- opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes",
- OpMethod::MP_Static, std::move(paramList));
+ // Generate 'inferReturnTypes' method declaration using the interface method
+ // declared in 'InferTypeOpInterface' op interface.
+ const auto *trait = dyn_cast<InterfaceOpTrait>(
+ op.getTrait("::mlir::InferTypeOpInterface::Trait"));
+ auto interface = trait->getOpInterface();
+ OpMethod *method = [&]() -> OpMethod * {
+ for (const InterfaceMethod &interfaceMethod : interface.getMethods()) {
+ if (interfaceMethod.getName() == "inferReturnTypes") {
+ return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false);
+ }
+ }
+ assert(0 && "unable to find inferReturnTypes interface method");
+ return nullptr;
+ }();
auto &body = method->body();
body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n";