diff options
author | Rahul Joshi <jurahul@google.com> | 2020-12-03 16:18:11 -0800 |
---|---|---|
committer | Rahul Joshi <jurahul@google.com> | 2020-12-04 09:05:53 -0800 |
commit | 245233423e466979e11b39cbed676903892d07f8 (patch) | |
tree | 5f534c4512a3ad23661111f99c968d0b2479f882 /mlir | |
parent | 7f6f9f4cf966c78a315d15d6e913c43cfa45c47c (diff) | |
download | llvm-project-245233423e466979e11b39cbed676903892d07f8.tar.gz |
[MLIR] Generate inferReturnTypes declaration using InferTypeOpInterface trait.
- Instead of hardcoding the parameters and return types of 'inferReturnTypes', use the
InferTypeOpInterface trait to generate the method declaration.
- Fix InferTypeOfInterface to use fully qualified type for inferReturnTypes results.
Differential Revision: https://reviews.llvm.org/D92585
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Interfaces/InferTypeOpInterface.td | 2 | ||||
-rw-r--r-- | mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 65 |
2 files changed, 40 insertions, 27 deletions
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td index ed62e9015bde..9de087b1b4ca 100644 --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -36,7 +36,7 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> { which an Operation would be created (e.g., as used in Operation::create) and the regions of the op. }], - /*retTy=*/"LogicalResult", + /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"inferReturnTypes", /*args=*/(ins "::mlir::MLIRContext *":$context, "::llvm::Optional<::mlir::Location>":$location, diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index c96fde648eb2..ccfb13fa3436 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -290,11 +290,16 @@ private: // Generates the traits used by the object. void genTraits(); - // Generate the OpInterface methods. + // Generate the OpInterface methods for all interfaces. void genOpInterfaceMethods(); - // Generate op interface method. - void genOpInterfaceMethod(const tblgen::InterfaceOpTrait *trait); + // Generate op interface methods for the given interface. + void genOpInterfaceMethods(const tblgen::InterfaceOpTrait *trait); + + // Generate op interface method for the given interface method. If + // 'declaration' is true, generates a declaration, else a definition. + OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, + bool declaration = true); // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); @@ -1588,7 +1593,7 @@ void OpEmitter::genFolderDecls() { } } -void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) { +void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceOpTrait *opTrait) { auto interface = opTrait->getOpInterface(); // Get the set of methods that should always be declared. @@ -1606,23 +1611,29 @@ void OpEmitter::genOpInterfaceMethod(const tblgen::InterfaceOpTrait *opTrait) { if (method.getDefaultImplementation() && !alwaysDeclaredMethods.count(method.getName())) continue; - - SmallVector<OpMethodParameter, 4> paramList; - for (const InterfaceMethod::Argument &arg : method.getArguments()) - paramList.emplace_back(arg.type, arg.name); - - auto properties = method.isStatic() ? OpMethod::MP_StaticDeclaration - : OpMethod::MP_Declaration; - opClass.addMethodAndPrune(method.getReturnType(), method.getName(), - properties, std::move(paramList)); + genOpInterfaceMethod(method); } } +OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, + bool declaration) { + SmallVector<OpMethodParameter, 4> paramList; + for (const InterfaceMethod::Argument &arg : method.getArguments()) + paramList.emplace_back(arg.type, arg.name); + + auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None; + if (declaration) + properties = + static_cast<OpMethod::Property>(properties | OpMethod::MP_Declaration); + return opClass.addMethodAndPrune(method.getReturnType(), method.getName(), + properties, std::move(paramList)); +} + void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { if (const auto *opTrait = dyn_cast<tblgen::InterfaceOpTrait>(&trait)) if (opTrait->shouldDeclareMethods()) - genOpInterfaceMethod(opTrait); + genOpInterfaceMethods(opTrait); } } @@ -1727,18 +1738,20 @@ void OpEmitter::genSideEffectInterfaceMethods() { void OpEmitter::genTypeInterfaceMethods() { if (!op.allResultTypesKnown()) return; - - SmallVector<OpMethodParameter, 4> paramList; - paramList.emplace_back("::mlir::MLIRContext *", "context"); - paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location"); - paramList.emplace_back("::mlir::ValueRange", "operands"); - paramList.emplace_back("::mlir::DictionaryAttr", "attributes"); - paramList.emplace_back("::mlir::RegionRange", "regions"); - paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&", - "inferredReturnTypes"); - auto *method = - opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes", - OpMethod::MP_Static, std::move(paramList)); + // Generate 'inferReturnTypes' method declaration using the interface method + // declared in 'InferTypeOpInterface' op interface. + const auto *trait = dyn_cast<InterfaceOpTrait>( + op.getTrait("::mlir::InferTypeOpInterface::Trait")); + auto interface = trait->getOpInterface(); + OpMethod *method = [&]() -> OpMethod * { + for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { + if (interfaceMethod.getName() == "inferReturnTypes") { + return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false); + } + } + assert(0 && "unable to find inferReturnTypes interface method"); + return nullptr; + }(); auto &body = method->body(); body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; |