diff options
author | Lei Zhang <antiagainst@google.com> | 2019-08-09 19:03:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2019-08-09 19:15:08 -0700 |
commit | 9e88516e6d43c3ee63ba266c45594bcfe4610df3 (patch) | |
tree | 0b916df70f36a17fd3920d40a98b0e920545ca0d | |
parent | bf62fcec003636338386f5246103b90a9580181c (diff) | |
download | tensorflow-9e88516e6d43c3ee63ba266c45594bcfe4610df3.tar.gz |
NFC: Refactoring PatternSymbolResolver into SymbolInfoMap
In declarative rewrite rules, a symbol can be bound to op arguments or
results in the source pattern, and it can be bound to op results in the
result pattern. This means given a symbol in the pattern, it can stand
for different things: op operand, op attribute, single op result,
op result pack. We need a better way to model this complexity so that
we can handle each symbol according to the specific kind it corresponds to.
Created SymbolInfo class for maintaining the information regarding a
symbol. Also created a companion SymbolInfoMap class for a map of
such symbols, providing insertion and querying depending on use cases.
PiperOrigin-RevId: 262675515
-rw-r--r-- | third_party/mlir/include/mlir/TableGen/Pattern.h | 182 | ||||
-rw-r--r-- | third_party/mlir/lib/TableGen/Pattern.cpp | 219 | ||||
-rw-r--r-- | third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp | 338 |
3 files changed, 429 insertions, 310 deletions
diff --git a/third_party/mlir/include/mlir/TableGen/Pattern.h b/third_party/mlir/include/mlir/TableGen/Pattern.h index 0e7fa44b6e7..efe6494d391 100644 --- a/third_party/mlir/include/mlir/TableGen/Pattern.h +++ b/third_party/mlir/include/mlir/TableGen/Pattern.h @@ -180,6 +180,154 @@ private: const llvm::DagInit *node; // nullptr means null DagNode }; +// A class for maintaining information for symbols bound in patterns and +// provides methods for resolving them according to specific use cases. +// +// Symbols can be bound to +// +// * Op arguments and op results in the source pattern and +// * Op results in result patterns. +// +// Symbols can be referenced in result patterns and additional constraints to +// the pattern. +// +// For example, in +// +// ``` +// def : Pattern< +// (SrcOp:$results1 $arg0, %arg1), +// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>; +// ``` +// +// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to +// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build +// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`. +// +// If a symbol binds to a multi-result op and it does not have the `__N` +// suffix, the symbol is expanded to represent all results generated by the +// multi-result op. If the symbol has a `__N` suffix, then it will expand to +// only the N-th *static* result as declared in ODS, and that can still +// corresponds to multiple *dynamic* values if the N-th *static* result is +// variadic. +// +// This class keeps track of such symbols and resolves them into their bound +// values in a suitable way. +class SymbolInfoMap { +public: + explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {} + + // Class for information regarding a symbol. + class SymbolInfo { + public: + // Returns a string for defining a variable named as `name` to store the + // value bound by this symbol. 
+ std::string getVarDecl(StringRef name) const; + + private: + // Allow SymbolInfoMap to access private methods. + friend class SymbolInfoMap; + + // What kind of entity this symbol represents: + // * Attr: op attribute + // * Operand: op operand + // * Result: op result + // * Value: a value not attached to an op (e.g., from NativeCodeCall) + enum class Kind : uint8_t { Attr, Operand, Result, Value }; + + // Creates a SymbolInfo instance. `index` is only used for `Attr` and + // `Operand` so should be negative for `Result` and `Value` kind. + SymbolInfo(const Operator *op, Kind kind, Optional<int> index); + + // Static methods for creating SymbolInfo. + static SymbolInfo getAttr(const Operator *op, int index) { + return SymbolInfo(op, Kind::Attr, index); + } + static SymbolInfo getOperand(const Operator *op, int index) { + return SymbolInfo(op, Kind::Operand, index); + } + static SymbolInfo getResult(const Operator *op) { + return SymbolInfo(op, Kind::Result, llvm::None); + } + static SymbolInfo getValue() { + return SymbolInfo(nullptr, Kind::Value, llvm::None); + } + + // Returns the number of static values this symbol corresponds to. + // A static value is an operand/result declared in ODS. Normally a symbol + // only represents one static value, but symbols bound to op results can + // represent more than one if the op is a multi-result op. + int getStaticValueCount() const; + + // Returns a string containing the C++ expression for referencing this + // symbol as a value (if this symbol represents one static value) or a value + // range (if this symbol represents multiple static values). `name` is the + // name of the C++ variable that this symbol bounds to. `index` should only + // be used for indexing results. 
+ std::string getValueAndRangeUse(StringRef name, int index) const; + + const Operator *op; // The op where the bound entity belongs + Kind kind; // The kind of the bound entity + // The argument index (for `Attr` and `Operand` only) + Optional<int> argIndex; + }; + + using BaseT = llvm::StringMap<SymbolInfo>; + + // Iterators for accessing all symbols. + using iterator = BaseT::iterator; + iterator begin() { return symbolInfoMap.begin(); } + iterator end() { return symbolInfoMap.end(); } + + // Const iterators for accessing all symbols. + using const_iterator = BaseT::const_iterator; + const_iterator begin() const { return symbolInfoMap.begin(); } + const_iterator end() const { return symbolInfoMap.end(); } + + // Binds the given `symbol` to the `argIndex`-th argument to the given `op`. + // Returns false if `symbol` is already bound. + bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex); + + // Binds the given `symbol` to the results the given `op`. Returns false if + // `symbol` is already bound. + bool bindOpResult(StringRef symbol, const Operator &op); + + // Registers the given `symbol` as bound to a value. Returns false if `symbol` + // is already bound. + bool bindValue(StringRef symbol); + + // Returns true if the given `symbol` is bound. + bool contains(StringRef symbol) const; + + // Returns an interator to the information of the given symbol named as `key`. + const_iterator find(StringRef key) const; + + // Returns the number of static values of the given `symbol` corresponds to. + // A static value is a operand/result declared in ODS. Normally a symbol only + // represents one static value, but symbols bound to op results can represent + // more than one if the op is a multi-result op. 
+ int getStaticValueCount(StringRef symbol) const; + + // Returns a string containing the C++ expression for referencing this + // symbol as a value (if this symbol represents one static value) or a value + // range (if this symbol represents multiple static values). + std::string getValueAndRangeUse(StringRef symbol) const; + + // Splits the given `symbol` into a value pack name and an index. Returns the + // value pack name and writes the index to `index` on sucess. Returns `symbol` + // itself if it does not contain an index. + // + // We can use `name__N` to access the `N`-th value in the value pack bound to + // `name`. `name` is typically the results of an multi-result op. + static StringRef getValuePackName(StringRef symbol, int *index = nullptr); + +private: + llvm::StringMap<SymbolInfo> symbolInfoMap; + + // Pattern instantiation location. This is intended to be used as parameter + // to PrintFatalError() to report errors. + ArrayRef<llvm::SMLoc> loc; +}; + // Wrapper class providing helper methods for accessing MLIR Pattern defined // in TableGen. This class should closely reflect what is defined as class // `Pattern` in TableGen. This class contains maps so it is not intended to be @@ -198,24 +346,11 @@ public: // Returns the DAG tree root node of the `index`-th result pattern. DagNode getResultPattern(unsigned index) const; - // Checks whether an argument or op with the given `name` is bound in - // source pattern. Prints fatal error if not; does nothing otherwise. - void ensureBoundInSourcePattern(StringRef name) const; - - // Returns a reference to all the bound arguments in the source pattern. - llvm::StringMap<Argument> &getSourcePatternBoundArgs(); + // Collects all symbols bound in the source pattern into `infoMap`. 
+ void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap); - // The returned map contains pointers to the operators inside the - // `RecordOperatorMap` passed-in when constructing this pattern; callers - // should guarantee the lifetime of the returned map does not exceed that - // of the `RecordOperatorMap`. - using SymbolOperatorMap = llvm::StringMap<const Operator *>; - - // Returns a reference to all the bound ops in the source pattern. - SymbolOperatorMap &getSourcePatternBoundOps(); - - // Returns a reference to all the bound ops in the result patterns. - SymbolOperatorMap &getResultPatternBoundOps(); + // Collects all symbols bound in result patterns into `infoMap`. + void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap); // Returns the op that the root node of the source pattern matches. const Operator &getSourceRootOp(); @@ -238,8 +373,8 @@ public: private: // Recursively collects all bound symbols inside the DAG tree rooted - // at `tree` and updates the given `symOpMap`. - void collectBoundSymbols(DagNode tree, SymbolOperatorMap &symOpMap, + // at `tree` and updates the given `infoMap`. + void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern); // The TableGen definition of this pattern. @@ -249,15 +384,6 @@ private: // TODO(antiagainst): we need a proper context manager, like MLIRContext, // for managing the lifetime of shared entities. RecordOperatorMap *recordOpMap; - - // All source pattern bound op arguments. - llvm::StringMap<Argument> srcBoundArguments; - - // All source pattern bound ops. - SymbolOperatorMap srcBoundOps; - - // All result pattern bound ops. 
- SymbolOperatorMap resBoundOps; }; } // end namespace tblgen diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp index fa37d22cc5e..51e4c3b376b 100644 --- a/third_party/mlir/lib/TableGen/Pattern.cpp +++ b/third_party/mlir/lib/TableGen/Pattern.cpp @@ -31,6 +31,10 @@ using namespace mlir; using llvm::formatv; using mlir::tblgen::Operator; +//===----------------------------------------------------------------------===// +// DagLeaf +//===----------------------------------------------------------------------===// + bool tblgen::DagLeaf::isUnspecified() const { return dyn_cast_or_null<llvm::UnsetInit>(def); } @@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const { return false; } +//===----------------------------------------------------------------------===// +// DagNode +//===----------------------------------------------------------------------===// + bool tblgen::DagNode::isNativeCodeCall() const { if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator())) return defInit->getDef()->isSubClassOf("NativeCodeCall"); @@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const { return dagOpDef->getName() == "replaceWithValue"; } -tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) - : def(*def), recordOpMap(mapper) { - collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true); - for (int i = 0, e = getNumResultPatterns(); i < e; ++i) - collectBoundSymbols(getResultPattern(i), resBoundOps, - /*isSrcPattern=*/false); +//===----------------------------------------------------------------------===// +// SymbolInfoMap +//===----------------------------------------------------------------------===// + +StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol, + int *index) { + StringRef name, indexStr; + int idx = -1; + std::tie(name, indexStr) = symbol.rsplit("__"); + + if (indexStr.consumeInteger(10, idx)) { + // The 
second part is not an index; we return the whole symbol as-is. + return symbol; + } + if (index) { + *index = idx; + } + return name; +} + +tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, + SymbolInfo::Kind kind, + Optional<int> index) + : op(op), kind(kind), argIndex(index) {} + +int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const { + switch (kind) { + case Kind::Attr: + case Kind::Operand: + case Kind::Value: + return 1; + case Kind::Result: + return op->getNumResults(); + } +} + +std::string +tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const { + switch (kind) { + case Kind::Attr: { + auto type = + op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType(); + return formatv("{0} {1};\n", type, name); + } + case Kind::Operand: + case Kind::Value: { + return formatv("Value *{0};\n", name); + } + case Kind::Result: { + // Use the op itself for the results. + return formatv("{0} {1};\n", op->getQualCppClassName(), name); + } + } +} + +std::string +tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name, + int index) const { + switch (kind) { + case Kind::Attr: + case Kind::Operand: { + assert(index < 0 && "only allowed for symbol bound to result"); + return name; + } + case Kind::Result: { + // TODO(b/133341698): The following is incorrect for variadic results. We + // should use getODSResults(). + if (index >= 0) { + return formatv("{0}.getOperation()->getResult({1})", name, index); + } + + // If referencing multiple results, compose a comma-separated list. 
+ SmallVector<std::string, 4> values; + for (int i = 0, e = op->getNumResults(); i < e; ++i) { + values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i)); + } + return llvm::join(values, ", "); + } + case Kind::Value: { + assert(index < 0 && "only allowed for symbol bound to result"); + assert(op == nullptr); + return name; + } + } +} + +bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op, + int argIndex) { + StringRef name = getValuePackName(symbol); + if (name != symbol) { + auto error = formatv( + "symbol '{0}' with trailing index cannot bind to op argument", symbol); + PrintFatalError(loc, error); + } + + auto symInfo = op.getArg(argIndex).is<NamedAttribute *>() + ? SymbolInfo::getAttr(&op, argIndex) + : SymbolInfo::getOperand(&op, argIndex); + + return symbolInfoMap.insert({symbol, symInfo}).second; +} + +bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) { + StringRef name = getValuePackName(symbol); + return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second; +} + +bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) { + return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second; +} + +bool tblgen::SymbolInfoMap::contains(StringRef symbol) const { + return find(symbol) != symbolInfoMap.end(); +} + +tblgen::SymbolInfoMap::const_iterator +tblgen::SymbolInfoMap::find(StringRef key) const { + StringRef name = getValuePackName(key); + return symbolInfoMap.find(name); +} + +int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const { + StringRef name = getValuePackName(symbol); + if (name != symbol) { + // If there is a trailing index inside symbol, it references just one + // static value. + return 1; + } + // Otherwise, find how many it represents by querying the symbol's info. 
+ return find(name)->getValue().getStaticValueCount(); } +std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const { + int index = -1; + StringRef name = getValuePackName(symbol, &index); + + auto it = symbolInfoMap.find(name); + if (it == symbolInfoMap.end()) { + auto error = formatv("referencing unbound symbol '{0}'", symbol); + PrintFatalError(loc, error); + } + + return it->getValue().getValueAndRangeUse(name, index); +} + +//===----------------------------------------------------------------------===// +// Pattern +//==----------------------------------------------------------------------===// + +tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper) + : def(*def), recordOpMap(mapper) {} + tblgen::DagNode tblgen::Pattern::getSourcePattern() const { return tblgen::DagNode(def.getValueAsDag("sourcePattern")); } @@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const { return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index))); } -void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const { - if (srcBoundArguments.find(name) == srcBoundArguments.end() && - srcBoundOps.find(name) == srcBoundOps.end()) - PrintFatalError(def.getLoc(), - Twine("referencing unbound variable '") + name + "'"); +void tblgen::Pattern::collectSourcePatternBoundSymbols( + tblgen::SymbolInfoMap &infoMap) { + collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true); } -llvm::StringMap<tblgen::Argument> & -tblgen::Pattern::getSourcePatternBoundArgs() { - return srcBoundArguments; -} - -llvm::StringMap<const tblgen::Operator *> & -tblgen::Pattern::getSourcePatternBoundOps() { - return srcBoundOps; -} - -llvm::StringMap<const tblgen::Operator *> & -tblgen::Pattern::getResultPatternBoundOps() { - return resBoundOps; +void tblgen::Pattern::collectResultPatternBoundSymbols( + tblgen::SymbolInfoMap &infoMap) { + for (int i = 0, e = getNumResultPatterns(); i < e; ++i) { + 
auto pattern = getResultPattern(i); + collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false); + } } const tblgen::Operator &tblgen::Pattern::getSourceRootOp() { @@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const { return result; } -void tblgen::Pattern::collectBoundSymbols(DagNode tree, - SymbolOperatorMap &symOpMap, +void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap, bool isSrcPattern) { auto treeName = tree.getSymbol(); if (!tree.isOperation()) { @@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree, auto numTreeArgs = tree.getNumArgs(); if (numOpArgs != numTreeArgs) { - PrintFatalError(def.getLoc(), - formatv("op '{0}' argument number mismatch: " - "{1} in pattern vs. {2} in definition", - op.getOperationName(), numTreeArgs, numOpArgs)); + auto err = formatv("op '{0}' argument number mismatch: " + "{1} in pattern vs. {2} in definition", + op.getOperationName(), numTreeArgs, numOpArgs); + PrintFatalError(def.getLoc(), err); } // The name attached to the DAG node's operator is for representing the // results generated from this op. It should be remembered as bound results. - if (!treeName.empty()) - symOpMap.try_emplace(treeName, &op); + if (!treeName.empty()) { + if (!infoMap.bindOpResult(treeName, op)) + PrintFatalError(def.getLoc(), + formatv("symbol '{0}' bound more than once", treeName)); + } for (int i = 0; i != numTreeArgs; ++i) { if (auto treeArg = tree.getArgAsNestedDag(i)) { // This DAG node argument is a DAG node itself. Go inside recursively. - collectBoundSymbols(treeArg, symOpMap, isSrcPattern); + collectBoundSymbols(treeArg, infoMap, isSrcPattern); } else if (isSrcPattern) { // We can only bind symbols to op arguments in source pattern. Those // symbols are referenced in result patterns. 
auto treeArgName = tree.getArgName(i); - if (!treeArgName.empty()) - srcBoundArguments.try_emplace(treeArgName, op.getArg(i)); + if (!treeArgName.empty()) { + if (!infoMap.bindOpArgument(treeArgName, op, i)) { + auto err = formatv("symbol '{0}' bound more than once", treeArgName); + PrintFatalError(def.getLoc(), err); + } + } } } } diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp index 7a170c701c1..3487eda545f 100644 --- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -51,166 +51,6 @@ template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> { }; } // end namespace llvm -// Gets the dynamic value pack's name by removing the index suffix from -// `symbol`. Returns `symbol` itself if it does not contain an index. -// -// We can use `name__<index>` to access the `<index>`-th value in the dynamic -// value pack bound to `name`. `name` is typically the results of an -// multi-result op. -static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) { - StringRef name, indexStr; - unsigned idx = 0; - std::tie(name, indexStr) = symbol.rsplit("__"); - if (indexStr.consumeInteger(10, idx)) { - // The second part is not an index. - return symbol; - } - if (index) - *index = idx; - return name; -} - -// Formats all values from a dynamic value pack `symbol` according to the given -// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol` -// and `{1}` as a placeholder for the value index, which will be offsetted by -// `offset`. The `symbol` value pack has a total of `count` values. -// -// This extracts one value from the pack if `symbol` contains an index, -// otherwise it extracts all values sequentially and returns them as a -// comma-separated list. 
-static std::string formatValuePack(const char *fmt, StringRef symbol, - unsigned count, unsigned offset) { - auto getNthValue = [fmt, offset](StringRef results, - unsigned index) -> std::string { - return formatv(fmt, results, index + offset); - }; - - unsigned index = 0; - StringRef name = getValuePackName(symbol, &index); - if (name != symbol) { - // The symbol contains an index. - return getNthValue(name, index); - } - - // The symbol does not contain an index. Treat the symbol as a whole. - SmallVector<std::string, 4> values; - values.reserve(count); - for (unsigned i = 0; i < count; ++i) - values.emplace_back(getNthValue(symbol, i)); - return llvm::join(values, ", "); -} - -//===----------------------------------------------------------------------===// -// PatternSymbolResolver -//===----------------------------------------------------------------------===// - -namespace { -// A class for resolving symbols bound in patterns. -// -// Symbols can be bound to op arguments and ops in the source pattern and ops -// in result patterns. For example, in -// -// ``` -// def : Pattern<(SrcOp:$op1 $arg0, %arg1), -// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>; -// ``` -// -// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`. -// `$op2` is bound to `ResOp1`. -// -// If a symbol binds to a multi-result op and it does not have the `__N` -// suffix, the symbol is expanded to the whole value pack generated by the -// multi-result op. If the symbol has a `__N` suffix, then it will expand to -// only the N-th result. -// -// This class keeps track of such symbols and translates them into their bound -// values. -// -// Note that we also generate local variables for unnamed DAG nodes, like -// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the -// generated local variable will be implicitly named. Those implicit names are -// not tracked in this class. 
-class PatternSymbolResolver { -public: - PatternSymbolResolver(const StringMap<Argument> &srcArgs, - const StringMap<const Operator *> &srcOperations); - - // Marks the given `symbol` as bound to a value pack with `numValues` and - // returns true on success. Returns false if the `symbol` is already bound. - bool add(StringRef symbol, int numValues); - - // Queries the substitution for the given `symbol`. Returns empty string if - // symbol not found. If the symbol represents a value pack, returns all the - // values separated via comma. - std::string query(StringRef symbol) const; - - // Returns how many static values the given `symbol` correspond to. Returns a - // negative value if the given symbol is not bound. - // - // Normally a symbol would correspond to just one value; for symbols bound to - // multi-result ops, it can be more than one. - int getValueCount(StringRef symbol) const; - -private: - // Symbols bound to arguments in source pattern. - const StringMap<Argument> &sourceArguments; - // Symbols bound to ops (for their results) in source pattern. - const StringMap<const Operator *> &sourceOps; - // Symbols bound to ops (for their results) in result patterns. 
- // Key: symbol; value: number of values inside the pack - StringMap<int> resultOps; -}; -} // end anonymous namespace - -PatternSymbolResolver::PatternSymbolResolver( - const StringMap<Argument> &srcArgs, - const StringMap<const Operator *> &srcOperations) - : sourceArguments(srcArgs), sourceOps(srcOperations) {} - -bool PatternSymbolResolver::add(StringRef symbol, int numValues) { - StringRef name = getValuePackName(symbol); - return resultOps.try_emplace(name, numValues).second; -} - -std::string PatternSymbolResolver::query(StringRef symbol) const { - StringRef name = getValuePackName(symbol); - // Handle symbols bound to generated ops - auto resOpIt = resultOps.find(name); - if (resOpIt != resultOps.end()) - return formatValuePack("{0}.getOperation()->getResult({1})", symbol, - resOpIt->second, /*offset=*/0); - - // Handle symbols bound to matched op arguments - auto srcArgIt = sourceArguments.find(symbol); - if (srcArgIt != sourceArguments.end()) - return symbol; - - // Handle symbols bound to matched op results - auto srcOpIt = sourceOps.find(name); - if (srcOpIt != sourceOps.end()) - return formatValuePack("{0}->getResult({1})", symbol, - srcOpIt->second->getNumResults(), /*offset=*/0); - return {}; -} - -int PatternSymbolResolver::getValueCount(StringRef symbol) const { - StringRef name = getValuePackName(symbol); - // Handle symbols bound to generated ops - auto resOpIt = resultOps.find(name); - if (resOpIt != resultOps.end()) - return name == symbol ? resOpIt->second : 1; - - // Handle symbols bound to matched op arguments - if (sourceArguments.count(symbol)) - return 1; - - // Handle symbols bound to matched op results - auto srcOpIt = sourceOps.find(name); - if (srcOpIt != sourceOps.end()) - return name == symbol ? 
srcOpIt->second->getNumResults() : 1; - return -1; -} - //===----------------------------------------------------------------------===// // PatternEmitter //===----------------------------------------------------------------------===// @@ -286,17 +126,13 @@ private: // Collects all of the operations within the given dag tree. void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops); - // Returns a unique name for a value of the given `op`. - std::string getUniqueValueName(const Operator *op); + // Returns a unique symbol for a local variable of the given `op`. + std::string getUniqueSymbol(const Operator *op); //===--------------------------------------------------------------------===// // Symbol utilities //===--------------------------------------------------------------------===// - // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol - // is already bound. - void addSymbol(StringRef symbol, int numValues); - // Gets the substitution for `symbol`. Aborts if `symbol` is not bound. std::string resolveSymbol(StringRef symbol); @@ -308,13 +144,19 @@ private: // prototypes used. This is intended to be used as a whole to // PrintFatalError() on errors. ArrayRef<llvm::SMLoc> loc; - // Op's TableGen Record to wrapper object + + // Op's TableGen Record to wrapper object. RecordOperatorMap *opMap; - // Handy wrapper for pattern being emitted + + // Handy wrapper for pattern being emitted. Pattern pattern; - PatternSymbolResolver symbolResolver; - // The next unused ID for newly created values + + // Map for all bound symbols' info. + SymbolInfoMap symbolInfoMap; + + // The next unused ID for newly created values. unsigned nextValueId; + raw_ostream &os; // Format contexts containing placeholder substitutations. 
@@ -328,9 +170,7 @@ private: PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper, raw_ostream &os) : loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper), - symbolResolver(pattern.getSourcePatternBoundArgs(), - pattern.getSourcePatternBoundOps()), - nextValueId(0), os(os) { + symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) { fmtCtx.withBuilder("rewriter"); } @@ -354,13 +194,14 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { } int indent = 4 + 2 * depth; + os.indent(indent) << formatv( + "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n", + depth, op.getQualCppClassName()); // Skip the operand matching at depth 0 as the pattern rewriter already does. if (depth != 0) { // Skip if there is no defining operation (e.g., arguments to function). - os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth); - os.indent(indent) << formatv( - "if (!isa<{1}>(op{0})) return matchFailure();\n", depth, - op.getQualCppClassName()); + os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n", + depth); } if (tree.getNumArgs() != op.getNumArgs()) { PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in " @@ -372,7 +213,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { // If the operand's name is set, set to that variable. 
auto name = tree.getSymbol(); if (!name.empty()) - os.indent(indent) << formatv("{0} = op{1};\n", name, depth); + os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth); for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto opArg = op.getArg(i); @@ -381,7 +222,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) { if (DagNode argTree = tree.getArgAsNestedDag(i)) { os.indent(indent) << "{\n"; os.indent(indent + 2) - << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n", + << formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n", depth + 1, depth, i); emitOpMatch(argTree, depth + 1); os.indent(indent + 2) @@ -569,21 +410,17 @@ void PatternEmitter::emit(StringRef rewriteName) { PatternRewriter &rewriter) const override { )"; + // Register all symbols bound in the source pattern. + pattern.collectSourcePatternBoundSymbols(symbolInfoMap); + os.indent(4) << "// Variables for capturing values and attributes used for " "creating ops\n"; - // Create local variables for storing the arguments bound to symbols. - for (const auto &arg : pattern.getSourcePatternBoundArgs()) { - auto fieldName = arg.first(); - if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) { - os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(), - fieldName); - } else { - os.indent(4) << "Value *" << fieldName << ";\n"; - } - } - // Create local variables for storing the ops bound to symbols. - for (const auto &result : pattern.getSourcePatternBoundOps()) { - os.indent(4) << formatv("Operation *{0};\n", result.getKey()); + // Create local variables for storing the arguments and results bound + // to symbols. + for (const auto &symbolInfoPair : symbolInfoMap) { + StringRef symbol = symbolInfoPair.getKey(); + auto &info = symbolInfoPair.getValue(); + os.indent(4) << info.getVarDecl(symbol); } // TODO(jpienaar): capture ops with consistent numbering so that it can be // reused for fused loc. 
@@ -609,20 +446,22 @@ void PatternEmitter::emitRewriteLogic() { int numResultPatterns = pattern.getNumResultPatterns(); // First register all symbols bound to ops generated in result patterns. - for (const auto &boundOp : pattern.getResultPatternBoundOps()) { - addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults()); - } + pattern.collectResultPatternBoundSymbols(symbolInfoMap); // Only the last N static values generated are used to replace the matched // root N-result op. We need to calculate the starting index (of the results // of the matched op) each result pattern is to replace. SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults); - int replStartIndex = -1; + // If we don't need to replace any value at all, set the replacement starting + // index as the number of result patterns so we skip all of them when trying + // to replace the matched op's results. + int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1; for (int i = numResultPatterns - 1; i >= 0; --i) { auto numValues = getNodeValueCount(pattern.getResultPattern(i)); offsets[i] = offsets[i + 1] - numValues; if (offsets[i] == 0) { - replStartIndex = i; + if (replStartIndex == -1) + replStartIndex = i; } else if (offsets[i] < 0 && offsets[i + 1] > 0) { auto error = formatv( "cannot use the same multi-result op '{0}' to generate both " @@ -652,31 +491,36 @@ void PatternEmitter::emitRewriteLogic() { // Emit the final replaceOp() statement os.indent(4) << "rewriter.replaceOp(op0, {"; - interleave( - ArrayRef<std::string>(resultValues).drop_front(replStartIndex), - [&](const std::string &name) { os << name; }, [&]() { os << ", "; }); + interleaveComma( + ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os, + [&](const std::string &symbol) { os << resolveSymbol(symbol); }); os << "});\n"; } -std::string PatternEmitter::getUniqueValueName(const Operator *op) { - return formatv("v{0}{1}", op->getCppClassName(), nextValueId++); +std::string 
PatternEmitter::getUniqueSymbol(const Operator *op) { + return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++); } std::string PatternEmitter::handleResultPattern(DagNode resultTree, int resultIndex, int depth) { - if (resultTree.isNativeCodeCall()) - return handleReplaceWithNativeCodeCall(resultTree); + if (resultTree.isNativeCodeCall()) { + auto symbol = handleReplaceWithNativeCodeCall(resultTree); + symbolInfoMap.bindValue(symbol); + return symbol; + } - if (resultTree.isReplaceWithValue()) + if (resultTree.isReplaceWithValue()) { return handleReplaceWithValue(resultTree); + } - // Create the op and get the local variable for it. - auto results = handleOpCreation(resultTree, resultIndex, depth); - // We need to get all the values out of this local variable if we've created a - // multi-result op. - const auto &numResults = pattern.getDialectOp(resultTree).getNumResults(); - return formatValuePack("{0}.getOperation()->getResult({1})", results, - numResults, /*offset=*/0); + // Normal op creation. + auto symbol = handleOpCreation(resultTree, resultIndex, depth); + if (resultTree.getSymbol().empty()) { + // This is an op not explicitly bound to a symbol in the rewrite rule. + // Register the auto-generated symbol for it. 
+ symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree)); + } + return symbol; } std::string PatternEmitter::handleReplaceWithValue(DagNode tree) { @@ -709,7 +553,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) { std::string val = std::to_string(enumCase.getValue()); return handleConstantAttr(enumCase, val); } - pattern.ensureBoundInSourcePattern(argName); if (leaf.isUnspecified() || leaf.isOperandMatcher()) { return argName; } @@ -734,27 +577,23 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) { attrs[5], attrs[6], attrs[7]); } -void PatternEmitter::addSymbol(StringRef symbol, int numValues) { - if (!symbolResolver.add(symbol, numValues)) - PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol)); -} - std::string PatternEmitter::resolveSymbol(StringRef symbol) { - auto subst = symbolResolver.query(symbol); - if (subst.empty()) + auto subst = symbolInfoMap.getValueAndRangeUse(symbol); + if (subst.empty()) { PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol)); + } return subst; } int PatternEmitter::getNodeValueCount(DagNode node) { if (node.isOperation()) { - // First to see whether this op is bound and we just want a specific result - // of it with `__N` suffix in symbol. - int count = symbolResolver.getValueCount(node.getSymbol()); - if (count >= 0) - return count; - - // No symbol. Then we are using all the results. + // If the op is bound to a symbol in the rewrite rule, query its result + // count from the symbol info map. + auto symbol = node.getSymbol(); + if (!symbol.empty()) { + return symbolInfoMap.getStaticValueCount(symbol); + } + // Otherwise this is an unbound op; we will use all its results. 
return pattern.getDialectOp(node).getNumResults(); } // TODO(antiagainst): This considers all NativeCodeCall as returning one @@ -799,10 +638,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // Use the specified name for this op if available. Generate one otherwise. std::string resultValue = tree.getSymbol(); if (resultValue.empty()) - resultValue = getUniqueValueName(&resultOp); + resultValue = getUniqueSymbol(&resultOp); // Strip the index to get the name for the value pack. This will be used to // name the local variable for the op. - StringRef valuePackName = getValuePackName(resultValue); + StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue); // Then we build the new op corresponding to this DAG node. @@ -826,20 +665,25 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, // here. // We need to specify the types for all results. - auto resultTypes = - formatValuePack("op0->getResult({1})->getType()", valuePackName, - resultOp.getNumResults(), resultIndex); + SmallVector<std::string, 4> resultTypes; + int numResults = resultOp.getNumResults(); + resultTypes.reserve(numResults); + for (int i = 0; i < numResults; ++i) { + resultTypes.push_back( + formatv("op0->getResult({0})->getType()", resultIndex + i)); + } os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc", valuePackName, resultOp.getQualCppClassName()) - << (resultTypes.empty() ? "" : ", ") << resultTypes; + << (resultTypes.empty() ? "" : ", ") + << llvm::join(resultTypes, ", "); } // Create the builder call for the result. // Add operands. - int i = 0; - for (int e = resultOp.getNumOperands(); i < e; ++i) { - const auto &operand = resultOp.getOperand(i); + int argIndex = 0; + for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) { + const auto &operand = resultOp.getOperand(argIndex); // Start each operand on its own line. 
(os << ",\n").indent(6); @@ -847,11 +691,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, if (!operand.name.empty()) os << "/*" << operand.name << "=*/"; - if (tree.isNestedDagArg(i)) { - os << childNodeNames[i]; + if (tree.isNestedDagArg(argIndex)) { + os << childNodeNames[argIndex]; } else { - DagLeaf leaf = tree.getArgAsLeaf(i); - auto symbol = resolveSymbol(tree.getArgName(i)); + DagLeaf leaf = tree.getArgAsLeaf(argIndex); + auto symbol = resolveSymbol(tree.getArgName(argIndex)); if (leaf.isNativeCodeCall()) { os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol)); } else { @@ -862,26 +706,26 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, } // Add attributes. - for (int e = tree.getNumArgs(); i != e; ++i) { + for (; argIndex != numOpArgs; ++argIndex) { // Start each attribute on its own line. (os << ",\n").indent(6); // The argument in the op definition. - auto opArgName = resultOp.getArgName(i); - if (auto subTree = tree.getArgAsNestedDag(i)) { + auto opArgName = resultOp.getArgName(argIndex); + if (auto subTree = tree.getArgAsNestedDag(argIndex)) { if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); os << formatv("/*{0}=*/{1}", opArgName, handleReplaceWithNativeCodeCall(subTree)); } else { - auto leaf = tree.getArgAsLeaf(i); + auto leaf = tree.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. - auto patArgName = tree.getArgName(i); + auto patArgName = tree.getArgName(argIndex); if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) { // TODO(jpienaar): Refactor out into map to avoid recomputing these. 
- auto argument = resultOp.getArg(i); + auto argument = resultOp.getArg(argIndex); if (!argument.is<NamedAttribute *>()) - PrintFatalError(loc, Twine("expected attribute ") + Twine(i)); + PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex)); if (!patArgName.empty()) os << "/*" << patArgName << "=*/"; } else { |