aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLei Zhang <antiagainst@google.com>2019-08-09 19:03:58 -0700
committerTensorFlower Gardener <gardener@tensorflow.org>2019-08-09 19:15:08 -0700
commit9e88516e6d43c3ee63ba266c45594bcfe4610df3 (patch)
tree0b916df70f36a17fd3920d40a98b0e920545ca0d
parentbf62fcec003636338386f5246103b90a9580181c (diff)
downloadtensorflow-9e88516e6d43c3ee63ba266c45594bcfe4610df3.tar.gz
NFC: Refactoring PatternSymbolResolver into SymbolInfoMap
In declarative rewrite rules, a symbol can be bound to op arguments or results in the source pattern, and it can be bound to op results in the result pattern. This means given a symbol in the pattern, it can stands for different things: op operand, op attribute, single op result, op result pack. We need a better way to model this complexity so that we can handle according to the specific kind a symbol corresponds to. Created SymbolInfo class for maintaining the information regarding a symbol. Also created a companion SymbolInfoMap class for a map of such symbols, providing insertion and querying depending on use cases. PiperOrigin-RevId: 262675515
-rw-r--r--third_party/mlir/include/mlir/TableGen/Pattern.h182
-rw-r--r--third_party/mlir/lib/TableGen/Pattern.cpp219
-rw-r--r--third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp338
3 files changed, 429 insertions, 310 deletions
diff --git a/third_party/mlir/include/mlir/TableGen/Pattern.h b/third_party/mlir/include/mlir/TableGen/Pattern.h
index 0e7fa44b6e7..efe6494d391 100644
--- a/third_party/mlir/include/mlir/TableGen/Pattern.h
+++ b/third_party/mlir/include/mlir/TableGen/Pattern.h
@@ -180,6 +180,154 @@ private:
const llvm::DagInit *node; // nullptr means null DagNode
};
+// A class for maintaining information for symbols bound in patterns and
+// provides methods for resolving them according to specific use cases.
+//
+// Symbols can be bound to
+//
+// * Op arguments and op results in the source pattern and
+// * Op results in result patterns.
+//
+// Symbols can be referenced in result patterns and additional constraints to
+// the pattern.
+//
+// For example, in
+//
+// ```
+// def : Pattern<
+// (SrcOp:$results1 $arg0, %arg1),
+// [(ResOp1:$results2), (ResOp2 $results2 (ResOp3 $arg0, $arg1))]>;
+// ```
+//
+// `$argN` is bound to the `SrcOp`'s N-th argument. `$results1` is bound to
+// `SrcOp`. `$results2` is bound to `ResOp1`. $result2 is referenced to build
+// `ResOp2`. `$arg0` and `$arg1` are referenced to build `ResOp3`.
+//
+// If a symbol binds to a multi-result op and it does not have the `__N`
+// suffix, the symbol is expanded to represent all results generated by the
+// multi-result op. If the symbol has a `__N` suffix, then it will expand to
+// only the N-th *static* result as declared in ODS, and that can still
+// corresponds to multiple *dynamic* values if the N-th *static* result is
+// variadic.
+//
+// This class keeps track of such symbols and resolves them into their bound
+// values in a suitable way.
+class SymbolInfoMap {
+public:
+ explicit SymbolInfoMap(ArrayRef<llvm::SMLoc> loc) : loc(loc) {}
+
+ // Class for information regarding a symbol.
+ class SymbolInfo {
+ public:
+ // Returns a string for defining a variable named as `name` to store the
+ // value bound by this symbol.
+ std::string getVarDecl(StringRef name) const;
+
+ private:
+ // Allow SymbolInfoMap to access private methods.
+ friend class SymbolInfoMap;
+
+ // What kind of entity this symbol represents:
+ // * Attr: op attribute
+ // * Operand: op operand
+ // * Result: op result
+ // * Value: a value not attached to an op (e.g., from NativeCodeCall)
+ enum class Kind : uint8_t { Attr, Operand, Result, Value };
+
+ // Creates a SymbolInfo instance. `index` is only used for `Attr` and
+ // `Operand` so should be negative for `Result` and `Value` kind.
+ SymbolInfo(const Operator *op, Kind kind, Optional<int> index);
+
+ // Static methods for creating SymbolInfo.
+ static SymbolInfo getAttr(const Operator *op, int index) {
+ return SymbolInfo(op, Kind::Attr, index);
+ }
+ static SymbolInfo getOperand(const Operator *op, int index) {
+ return SymbolInfo(op, Kind::Operand, index);
+ }
+ static SymbolInfo getResult(const Operator *op) {
+ return SymbolInfo(op, Kind::Result, llvm::None);
+ }
+ static SymbolInfo getValue() {
+ return SymbolInfo(nullptr, Kind::Value, llvm::None);
+ }
+
+ // Returns the number of static values this symbol corresponds to.
+ // A static value is an operand/result declared in ODS. Normally a symbol
+ // only represents one static value, but symbols bound to op results can
+ // represent more than one if the op is a multi-result op.
+ int getStaticValueCount() const;
+
+ // Returns a string containing the C++ expression for referencing this
+ // symbol as a value (if this symbol represents one static value) or a value
+ // range (if this symbol represents multiple static values). `name` is the
+ // name of the C++ variable that this symbol bounds to. `index` should only
+ // be used for indexing results.
+ std::string getValueAndRangeUse(StringRef name, int index) const;
+
+ const Operator *op; // The op where the bound entity belongs
+ Kind kind; // The kind of the bound entity
+ // The argument index (for `Attr` and `Operand` only)
+ Optional<int> argIndex;
+ };
+
+ using BaseT = llvm::StringMap<SymbolInfo>;
+
+ // Iterators for accessing all symbols.
+ using iterator = BaseT::iterator;
+ iterator begin() { return symbolInfoMap.begin(); }
+ iterator end() { return symbolInfoMap.end(); }
+
+ // Const iterators for accessing all symbols.
+ using const_iterator = BaseT::const_iterator;
+ const_iterator begin() const { return symbolInfoMap.begin(); }
+ const_iterator end() const { return symbolInfoMap.end(); }
+
+ // Binds the given `symbol` to the `argIndex`-th argument to the given `op`.
+ // Returns false if `symbol` is already bound.
+ bool bindOpArgument(StringRef symbol, const Operator &op, int argIndex);
+
+ // Binds the given `symbol` to the results the given `op`. Returns false if
+ // `symbol` is already bound.
+ bool bindOpResult(StringRef symbol, const Operator &op);
+
+ // Registers the given `symbol` as bound to a value. Returns false if `symbol`
+ // is already bound.
+ bool bindValue(StringRef symbol);
+
+ // Returns true if the given `symbol` is bound.
+ bool contains(StringRef symbol) const;
+
+ // Returns an interator to the information of the given symbol named as `key`.
+ const_iterator find(StringRef key) const;
+
+ // Returns the number of static values of the given `symbol` corresponds to.
+ // A static value is a operand/result declared in ODS. Normally a symbol only
+ // represents one static value, but symbols bound to op results can represent
+ // more than one if the op is a multi-result op.
+ int getStaticValueCount(StringRef symbol) const;
+
+ // Returns a string containing the C++ expression for referencing this
+ // symbol as a value (if this symbol represents one static value) or a value
+ // range (if this symbol represents multiple static values).
+ std::string getValueAndRangeUse(StringRef symbol) const;
+
+ // Splits the given `symbol` into a value pack name and an index. Returns the
+ // value pack name and writes the index to `index` on sucess. Returns `symbol`
+ // itself if it does not contain an index.
+ //
+ // We can use `name__N` to access the `N`-th value in the value pack bound to
+ // `name`. `name` is typically the results of an multi-result op.
+ static StringRef getValuePackName(StringRef symbol, int *index = nullptr);
+
+private:
+ llvm::StringMap<SymbolInfo> symbolInfoMap;
+
+ // Pattern instantiation location. This is intended to be used as parameter
+ // to PrintFatalError() to report errors.
+ ArrayRef<llvm::SMLoc> loc;
+};
+
// Wrapper class providing helper methods for accessing MLIR Pattern defined
// in TableGen. This class should closely reflect what is defined as class
// `Pattern` in TableGen. This class contains maps so it is not intended to be
@@ -198,24 +346,11 @@ public:
// Returns the DAG tree root node of the `index`-th result pattern.
DagNode getResultPattern(unsigned index) const;
- // Checks whether an argument or op with the given `name` is bound in
- // source pattern. Prints fatal error if not; does nothing otherwise.
- void ensureBoundInSourcePattern(StringRef name) const;
-
- // Returns a reference to all the bound arguments in the source pattern.
- llvm::StringMap<Argument> &getSourcePatternBoundArgs();
+ // Collects all symbols bound in the source pattern into `infoMap`.
+ void collectSourcePatternBoundSymbols(SymbolInfoMap &infoMap);
- // The returned map contains pointers to the operators inside the
- // `RecordOperatorMap` passed-in when constructing this pattern; callers
- // should guarantee the lifetime of the returned map does not exceed that
- // of the `RecordOperatorMap`.
- using SymbolOperatorMap = llvm::StringMap<const Operator *>;
-
- // Returns a reference to all the bound ops in the source pattern.
- SymbolOperatorMap &getSourcePatternBoundOps();
-
- // Returns a reference to all the bound ops in the result patterns.
- SymbolOperatorMap &getResultPatternBoundOps();
+ // Collects all symbols bound in result patterns into `infoMap`.
+ void collectResultPatternBoundSymbols(SymbolInfoMap &infoMap);
// Returns the op that the root node of the source pattern matches.
const Operator &getSourceRootOp();
@@ -238,8 +373,8 @@ public:
private:
// Recursively collects all bound symbols inside the DAG tree rooted
- // at `tree` and updates the given `symOpMap`.
- void collectBoundSymbols(DagNode tree, SymbolOperatorMap &symOpMap,
+ // at `tree` and updates the given `infoMap`.
+ void collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern);
// The TableGen definition of this pattern.
@@ -249,15 +384,6 @@ private:
// TODO(antiagainst): we need a proper context manager, like MLIRContext,
// for managing the lifetime of shared entities.
RecordOperatorMap *recordOpMap;
-
- // All source pattern bound op arguments.
- llvm::StringMap<Argument> srcBoundArguments;
-
- // All source pattern bound ops.
- SymbolOperatorMap srcBoundOps;
-
- // All result pattern bound ops.
- SymbolOperatorMap resBoundOps;
};
} // end namespace tblgen
diff --git a/third_party/mlir/lib/TableGen/Pattern.cpp b/third_party/mlir/lib/TableGen/Pattern.cpp
index fa37d22cc5e..51e4c3b376b 100644
--- a/third_party/mlir/lib/TableGen/Pattern.cpp
+++ b/third_party/mlir/lib/TableGen/Pattern.cpp
@@ -31,6 +31,10 @@ using namespace mlir;
using llvm::formatv;
using mlir::tblgen::Operator;
+//===----------------------------------------------------------------------===//
+// DagLeaf
+//===----------------------------------------------------------------------===//
+
bool tblgen::DagLeaf::isUnspecified() const {
return dyn_cast_or_null<llvm::UnsetInit>(def);
}
@@ -88,6 +92,10 @@ bool tblgen::DagLeaf::isSubClassOf(StringRef superclass) const {
return false;
}
+//===----------------------------------------------------------------------===//
+// DagNode
+//===----------------------------------------------------------------------===//
+
bool tblgen::DagNode::isNativeCodeCall() const {
if (auto *defInit = dyn_cast_or_null<llvm::DefInit>(node->getOperator()))
return defInit->getDef()->isSubClassOf("NativeCodeCall");
@@ -151,14 +159,158 @@ bool tblgen::DagNode::isReplaceWithValue() const {
return dagOpDef->getName() == "replaceWithValue";
}
-tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
- : def(*def), recordOpMap(mapper) {
- collectBoundSymbols(getSourcePattern(), srcBoundOps, /*isSrcPattern=*/true);
- for (int i = 0, e = getNumResultPatterns(); i < e; ++i)
- collectBoundSymbols(getResultPattern(i), resBoundOps,
- /*isSrcPattern=*/false);
+//===----------------------------------------------------------------------===//
+// SymbolInfoMap
+//===----------------------------------------------------------------------===//
+
+StringRef tblgen::SymbolInfoMap::getValuePackName(StringRef symbol,
+ int *index) {
+ StringRef name, indexStr;
+ int idx = -1;
+ std::tie(name, indexStr) = symbol.rsplit("__");
+
+ if (indexStr.consumeInteger(10, idx)) {
+ // The second part is not an index; we return the whole symbol as-is.
+ return symbol;
+ }
+ if (index) {
+ *index = idx;
+ }
+ return name;
+}
+
+tblgen::SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op,
+ SymbolInfo::Kind kind,
+ Optional<int> index)
+ : op(op), kind(kind), argIndex(index) {}
+
+int tblgen::SymbolInfoMap::SymbolInfo::getStaticValueCount() const {
+ switch (kind) {
+ case Kind::Attr:
+ case Kind::Operand:
+ case Kind::Value:
+ return 1;
+ case Kind::Result:
+ return op->getNumResults();
+ }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getVarDecl(StringRef name) const {
+ switch (kind) {
+ case Kind::Attr: {
+ auto type =
+ op->getArg(*argIndex).get<NamedAttribute *>()->attr.getStorageType();
+ return formatv("{0} {1};\n", type, name);
+ }
+ case Kind::Operand:
+ case Kind::Value: {
+ return formatv("Value *{0};\n", name);
+ }
+ case Kind::Result: {
+ // Use the op itself for the results.
+ return formatv("{0} {1};\n", op->getQualCppClassName(), name);
+ }
+ }
+}
+
+std::string
+tblgen::SymbolInfoMap::SymbolInfo::getValueAndRangeUse(StringRef name,
+ int index) const {
+ switch (kind) {
+ case Kind::Attr:
+ case Kind::Operand: {
+ assert(index < 0 && "only allowed for symbol bound to result");
+ return name;
+ }
+ case Kind::Result: {
+ // TODO(b/133341698): The following is incorrect for variadic results. We
+ // should use getODSResults().
+ if (index >= 0) {
+ return formatv("{0}.getOperation()->getResult({1})", name, index);
+ }
+
+ // If referencing multiple results, compose a comma-separated list.
+ SmallVector<std::string, 4> values;
+ for (int i = 0, e = op->getNumResults(); i < e; ++i) {
+ values.push_back(formatv("{0}.getOperation()->getResult({1})", name, i));
+ }
+ return llvm::join(values, ", ");
+ }
+ case Kind::Value: {
+ assert(index < 0 && "only allowed for symbol bound to result");
+ assert(op == nullptr);
+ return name;
+ }
+ }
+}
+
+bool tblgen::SymbolInfoMap::bindOpArgument(StringRef symbol, const Operator &op,
+ int argIndex) {
+ StringRef name = getValuePackName(symbol);
+ if (name != symbol) {
+ auto error = formatv(
+ "symbol '{0}' with trailing index cannot bind to op argument", symbol);
+ PrintFatalError(loc, error);
+ }
+
+ auto symInfo = op.getArg(argIndex).is<NamedAttribute *>()
+ ? SymbolInfo::getAttr(&op, argIndex)
+ : SymbolInfo::getOperand(&op, argIndex);
+
+ return symbolInfoMap.insert({symbol, symInfo}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindOpResult(StringRef symbol, const Operator &op) {
+ StringRef name = getValuePackName(symbol);
+ return symbolInfoMap.insert({name, SymbolInfo::getResult(&op)}).second;
+}
+
+bool tblgen::SymbolInfoMap::bindValue(StringRef symbol) {
+ return symbolInfoMap.insert({symbol, SymbolInfo::getValue()}).second;
+}
+
+bool tblgen::SymbolInfoMap::contains(StringRef symbol) const {
+ return find(symbol) != symbolInfoMap.end();
+}
+
+tblgen::SymbolInfoMap::const_iterator
+tblgen::SymbolInfoMap::find(StringRef key) const {
+ StringRef name = getValuePackName(key);
+ return symbolInfoMap.find(name);
+}
+
+int tblgen::SymbolInfoMap::getStaticValueCount(StringRef symbol) const {
+ StringRef name = getValuePackName(symbol);
+ if (name != symbol) {
+ // If there is a trailing index inside symbol, it references just one
+ // static value.
+ return 1;
+ }
+ // Otherwise, find how many it represents by querying the symbol's info.
+ return find(name)->getValue().getStaticValueCount();
}
+std::string tblgen::SymbolInfoMap::getValueAndRangeUse(StringRef symbol) const {
+ int index = -1;
+ StringRef name = getValuePackName(symbol, &index);
+
+ auto it = symbolInfoMap.find(name);
+ if (it == symbolInfoMap.end()) {
+ auto error = formatv("referencing unbound symbol '{0}'", symbol);
+ PrintFatalError(loc, error);
+ }
+
+ return it->getValue().getValueAndRangeUse(name, index);
+}
+
+//===----------------------------------------------------------------------===//
+// Pattern
+//==----------------------------------------------------------------------===//
+
+tblgen::Pattern::Pattern(const llvm::Record *def, RecordOperatorMap *mapper)
+ : def(*def), recordOpMap(mapper) {}
+
tblgen::DagNode tblgen::Pattern::getSourcePattern() const {
return tblgen::DagNode(def.getValueAsDag("sourcePattern"));
}
@@ -173,26 +325,17 @@ tblgen::DagNode tblgen::Pattern::getResultPattern(unsigned index) const {
return tblgen::DagNode(cast<llvm::DagInit>(results->getElement(index)));
}
-void tblgen::Pattern::ensureBoundInSourcePattern(llvm::StringRef name) const {
- if (srcBoundArguments.find(name) == srcBoundArguments.end() &&
- srcBoundOps.find(name) == srcBoundOps.end())
- PrintFatalError(def.getLoc(),
- Twine("referencing unbound variable '") + name + "'");
+void tblgen::Pattern::collectSourcePatternBoundSymbols(
+ tblgen::SymbolInfoMap &infoMap) {
+ collectBoundSymbols(getSourcePattern(), infoMap, /*isSrcPattern=*/true);
}
-llvm::StringMap<tblgen::Argument> &
-tblgen::Pattern::getSourcePatternBoundArgs() {
- return srcBoundArguments;
-}
-
-llvm::StringMap<const tblgen::Operator *> &
-tblgen::Pattern::getSourcePatternBoundOps() {
- return srcBoundOps;
-}
-
-llvm::StringMap<const tblgen::Operator *> &
-tblgen::Pattern::getResultPatternBoundOps() {
- return resBoundOps;
+void tblgen::Pattern::collectResultPatternBoundSymbols(
+ tblgen::SymbolInfoMap &infoMap) {
+ for (int i = 0, e = getNumResultPatterns(); i < e; ++i) {
+ auto pattern = getResultPattern(i);
+ collectBoundSymbols(pattern, infoMap, /*isSrcPattern=*/false);
+ }
}
const tblgen::Operator &tblgen::Pattern::getSourceRootOp() {
@@ -251,8 +394,7 @@ tblgen::Pattern::getLocation() const {
return result;
}
-void tblgen::Pattern::collectBoundSymbols(DagNode tree,
- SymbolOperatorMap &symOpMap,
+void tblgen::Pattern::collectBoundSymbols(DagNode tree, SymbolInfoMap &infoMap,
bool isSrcPattern) {
auto treeName = tree.getSymbol();
if (!tree.isOperation()) {
@@ -270,27 +412,34 @@ void tblgen::Pattern::collectBoundSymbols(DagNode tree,
auto numTreeArgs = tree.getNumArgs();
if (numOpArgs != numTreeArgs) {
- PrintFatalError(def.getLoc(),
- formatv("op '{0}' argument number mismatch: "
- "{1} in pattern vs. {2} in definition",
- op.getOperationName(), numTreeArgs, numOpArgs));
+ auto err = formatv("op '{0}' argument number mismatch: "
+ "{1} in pattern vs. {2} in definition",
+ op.getOperationName(), numTreeArgs, numOpArgs);
+ PrintFatalError(def.getLoc(), err);
}
// The name attached to the DAG node's operator is for representing the
// results generated from this op. It should be remembered as bound results.
- if (!treeName.empty())
- symOpMap.try_emplace(treeName, &op);
+ if (!treeName.empty()) {
+ if (!infoMap.bindOpResult(treeName, op))
+ PrintFatalError(def.getLoc(),
+ formatv("symbol '{0}' bound more than once", treeName));
+ }
for (int i = 0; i != numTreeArgs; ++i) {
if (auto treeArg = tree.getArgAsNestedDag(i)) {
// This DAG node argument is a DAG node itself. Go inside recursively.
- collectBoundSymbols(treeArg, symOpMap, isSrcPattern);
+ collectBoundSymbols(treeArg, infoMap, isSrcPattern);
} else if (isSrcPattern) {
// We can only bind symbols to op arguments in source pattern. Those
// symbols are referenced in result patterns.
auto treeArgName = tree.getArgName(i);
- if (!treeArgName.empty())
- srcBoundArguments.try_emplace(treeArgName, op.getArg(i));
+ if (!treeArgName.empty()) {
+ if (!infoMap.bindOpArgument(treeArgName, op, i)) {
+ auto err = formatv("symbol '{0}' bound more than once", treeArgName);
+ PrintFatalError(def.getLoc(), err);
+ }
+ }
}
}
}
diff --git a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 7a170c701c1..3487eda545f 100644
--- a/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/third_party/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -51,166 +51,6 @@ template <> struct format_provider<mlir::tblgen::Pattern::IdentifierLine> {
};
} // end namespace llvm
-// Gets the dynamic value pack's name by removing the index suffix from
-// `symbol`. Returns `symbol` itself if it does not contain an index.
-//
-// We can use `name__<index>` to access the `<index>`-th value in the dynamic
-// value pack bound to `name`. `name` is typically the results of an
-// multi-result op.
-static StringRef getValuePackName(StringRef symbol, unsigned *index = nullptr) {
- StringRef name, indexStr;
- unsigned idx = 0;
- std::tie(name, indexStr) = symbol.rsplit("__");
- if (indexStr.consumeInteger(10, idx)) {
- // The second part is not an index.
- return symbol;
- }
- if (index)
- *index = idx;
- return name;
-}
-
-// Formats all values from a dynamic value pack `symbol` according to the given
-// `fmt` string. The `fmt` string should use `{0}` as a placeholder for `symbol`
-// and `{1}` as a placeholder for the value index, which will be offsetted by
-// `offset`. The `symbol` value pack has a total of `count` values.
-//
-// This extracts one value from the pack if `symbol` contains an index,
-// otherwise it extracts all values sequentially and returns them as a
-// comma-separated list.
-static std::string formatValuePack(const char *fmt, StringRef symbol,
- unsigned count, unsigned offset) {
- auto getNthValue = [fmt, offset](StringRef results,
- unsigned index) -> std::string {
- return formatv(fmt, results, index + offset);
- };
-
- unsigned index = 0;
- StringRef name = getValuePackName(symbol, &index);
- if (name != symbol) {
- // The symbol contains an index.
- return getNthValue(name, index);
- }
-
- // The symbol does not contain an index. Treat the symbol as a whole.
- SmallVector<std::string, 4> values;
- values.reserve(count);
- for (unsigned i = 0; i < count; ++i)
- values.emplace_back(getNthValue(symbol, i));
- return llvm::join(values, ", ");
-}
-
-//===----------------------------------------------------------------------===//
-// PatternSymbolResolver
-//===----------------------------------------------------------------------===//
-
-namespace {
-// A class for resolving symbols bound in patterns.
-//
-// Symbols can be bound to op arguments and ops in the source pattern and ops
-// in result patterns. For example, in
-//
-// ```
-// def : Pattern<(SrcOp:$op1 $arg0, %arg1),
-// [(ResOp1:$op2), (ResOp2 $op2 (ResOp3))]>;
-// ```
-//
-// `$argN` is bound to the `SrcOp`'s N-th argument. `$op1` is bound to `SrcOp`.
-// `$op2` is bound to `ResOp1`.
-//
-// If a symbol binds to a multi-result op and it does not have the `__N`
-// suffix, the symbol is expanded to the whole value pack generated by the
-// multi-result op. If the symbol has a `__N` suffix, then it will expand to
-// only the N-th result.
-//
-// This class keeps track of such symbols and translates them into their bound
-// values.
-//
-// Note that we also generate local variables for unnamed DAG nodes, like
-// `(ResOp3)` in the above. Since we don't bind a symbol to the op, the
-// generated local variable will be implicitly named. Those implicit names are
-// not tracked in this class.
-class PatternSymbolResolver {
-public:
- PatternSymbolResolver(const StringMap<Argument> &srcArgs,
- const StringMap<const Operator *> &srcOperations);
-
- // Marks the given `symbol` as bound to a value pack with `numValues` and
- // returns true on success. Returns false if the `symbol` is already bound.
- bool add(StringRef symbol, int numValues);
-
- // Queries the substitution for the given `symbol`. Returns empty string if
- // symbol not found. If the symbol represents a value pack, returns all the
- // values separated via comma.
- std::string query(StringRef symbol) const;
-
- // Returns how many static values the given `symbol` correspond to. Returns a
- // negative value if the given symbol is not bound.
- //
- // Normally a symbol would correspond to just one value; for symbols bound to
- // multi-result ops, it can be more than one.
- int getValueCount(StringRef symbol) const;
-
-private:
- // Symbols bound to arguments in source pattern.
- const StringMap<Argument> &sourceArguments;
- // Symbols bound to ops (for their results) in source pattern.
- const StringMap<const Operator *> &sourceOps;
- // Symbols bound to ops (for their results) in result patterns.
- // Key: symbol; value: number of values inside the pack
- StringMap<int> resultOps;
-};
-} // end anonymous namespace
-
-PatternSymbolResolver::PatternSymbolResolver(
- const StringMap<Argument> &srcArgs,
- const StringMap<const Operator *> &srcOperations)
- : sourceArguments(srcArgs), sourceOps(srcOperations) {}
-
-bool PatternSymbolResolver::add(StringRef symbol, int numValues) {
- StringRef name = getValuePackName(symbol);
- return resultOps.try_emplace(name, numValues).second;
-}
-
-std::string PatternSymbolResolver::query(StringRef symbol) const {
- StringRef name = getValuePackName(symbol);
- // Handle symbols bound to generated ops
- auto resOpIt = resultOps.find(name);
- if (resOpIt != resultOps.end())
- return formatValuePack("{0}.getOperation()->getResult({1})", symbol,
- resOpIt->second, /*offset=*/0);
-
- // Handle symbols bound to matched op arguments
- auto srcArgIt = sourceArguments.find(symbol);
- if (srcArgIt != sourceArguments.end())
- return symbol;
-
- // Handle symbols bound to matched op results
- auto srcOpIt = sourceOps.find(name);
- if (srcOpIt != sourceOps.end())
- return formatValuePack("{0}->getResult({1})", symbol,
- srcOpIt->second->getNumResults(), /*offset=*/0);
- return {};
-}
-
-int PatternSymbolResolver::getValueCount(StringRef symbol) const {
- StringRef name = getValuePackName(symbol);
- // Handle symbols bound to generated ops
- auto resOpIt = resultOps.find(name);
- if (resOpIt != resultOps.end())
- return name == symbol ? resOpIt->second : 1;
-
- // Handle symbols bound to matched op arguments
- if (sourceArguments.count(symbol))
- return 1;
-
- // Handle symbols bound to matched op results
- auto srcOpIt = sourceOps.find(name);
- if (srcOpIt != sourceOps.end())
- return name == symbol ? srcOpIt->second->getNumResults() : 1;
- return -1;
-}
-
//===----------------------------------------------------------------------===//
// PatternEmitter
//===----------------------------------------------------------------------===//
@@ -286,17 +126,13 @@ private:
// Collects all of the operations within the given dag tree.
void collectOps(DagNode tree, llvm::SmallPtrSetImpl<const Operator *> &ops);
- // Returns a unique name for a value of the given `op`.
- std::string getUniqueValueName(const Operator *op);
+ // Returns a unique symbol for a local variable of the given `op`.
+ std::string getUniqueSymbol(const Operator *op);
//===--------------------------------------------------------------------===//
// Symbol utilities
//===--------------------------------------------------------------------===//
- // Marks the symbol attached to DagNode `node` as bound. Aborts if the symbol
- // is already bound.
- void addSymbol(StringRef symbol, int numValues);
-
// Gets the substitution for `symbol`. Aborts if `symbol` is not bound.
std::string resolveSymbol(StringRef symbol);
@@ -308,13 +144,19 @@ private:
// prototypes used. This is intended to be used as a whole to
// PrintFatalError() on errors.
ArrayRef<llvm::SMLoc> loc;
- // Op's TableGen Record to wrapper object
+
+ // Op's TableGen Record to wrapper object.
RecordOperatorMap *opMap;
- // Handy wrapper for pattern being emitted
+
+ // Handy wrapper for pattern being emitted.
Pattern pattern;
- PatternSymbolResolver symbolResolver;
- // The next unused ID for newly created values
+
+ // Map for all bound symbols' info.
+ SymbolInfoMap symbolInfoMap;
+
+ // The next unused ID for newly created values.
unsigned nextValueId;
+
raw_ostream &os;
// Format contexts containing placeholder substitutations.
@@ -328,9 +170,7 @@ private:
PatternEmitter::PatternEmitter(Record *pat, RecordOperatorMap *mapper,
raw_ostream &os)
: loc(pat->getLoc()), opMap(mapper), pattern(pat, mapper),
- symbolResolver(pattern.getSourcePatternBoundArgs(),
- pattern.getSourcePatternBoundOps()),
- nextValueId(0), os(os) {
+ symbolInfoMap(pat->getLoc()), nextValueId(0), os(os) {
fmtCtx.withBuilder("rewriter");
}
@@ -354,13 +194,14 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
}
int indent = 4 + 2 * depth;
+ os.indent(indent) << formatv(
+ "auto castedOp{0} = dyn_cast_or_null<{1}>(op{0}); (void)castedOp{0};\n",
+ depth, op.getQualCppClassName());
// Skip the operand matching at depth 0 as the pattern rewriter already does.
if (depth != 0) {
// Skip if there is no defining operation (e.g., arguments to function).
- os.indent(indent) << formatv("if (!op{0}) return matchFailure();\n", depth);
- os.indent(indent) << formatv(
- "if (!isa<{1}>(op{0})) return matchFailure();\n", depth,
- op.getQualCppClassName());
+ os.indent(indent) << formatv("if (!castedOp{0}) return matchFailure();\n",
+ depth);
}
if (tree.getNumArgs() != op.getNumArgs()) {
PrintFatalError(loc, formatv("op '{0}' argument number mismatch: {1} in "
@@ -372,7 +213,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
// If the operand's name is set, set to that variable.
auto name = tree.getSymbol();
if (!name.empty())
- os.indent(indent) << formatv("{0} = op{1};\n", name, depth);
+ os.indent(indent) << formatv("{0} = castedOp{1};\n", name, depth);
for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
auto opArg = op.getArg(i);
@@ -381,7 +222,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, int depth) {
if (DagNode argTree = tree.getArgAsNestedDag(i)) {
os.indent(indent) << "{\n";
os.indent(indent + 2)
- << formatv("auto op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
+ << formatv("auto *op{0} = op{1}->getOperand({2})->getDefiningOp();\n",
depth + 1, depth, i);
emitOpMatch(argTree, depth + 1);
os.indent(indent + 2)
@@ -569,21 +410,17 @@ void PatternEmitter::emit(StringRef rewriteName) {
PatternRewriter &rewriter) const override {
)";
+ // Register all symbols bound in the source pattern.
+ pattern.collectSourcePatternBoundSymbols(symbolInfoMap);
+
os.indent(4) << "// Variables for capturing values and attributes used for "
"creating ops\n";
- // Create local variables for storing the arguments bound to symbols.
- for (const auto &arg : pattern.getSourcePatternBoundArgs()) {
- auto fieldName = arg.first();
- if (auto namedAttr = arg.second.dyn_cast<NamedAttribute *>()) {
- os.indent(4) << formatv("{0} {1};\n", namedAttr->attr.getStorageType(),
- fieldName);
- } else {
- os.indent(4) << "Value *" << fieldName << ";\n";
- }
- }
- // Create local variables for storing the ops bound to symbols.
- for (const auto &result : pattern.getSourcePatternBoundOps()) {
- os.indent(4) << formatv("Operation *{0};\n", result.getKey());
+ // Create local variables for storing the arguments and results bound
+ // to symbols.
+ for (const auto &symbolInfoPair : symbolInfoMap) {
+ StringRef symbol = symbolInfoPair.getKey();
+ auto &info = symbolInfoPair.getValue();
+ os.indent(4) << info.getVarDecl(symbol);
}
// TODO(jpienaar): capture ops with consistent numbering so that it can be
// reused for fused loc.
@@ -609,20 +446,22 @@ void PatternEmitter::emitRewriteLogic() {
int numResultPatterns = pattern.getNumResultPatterns();
// First register all symbols bound to ops generated in result patterns.
- for (const auto &boundOp : pattern.getResultPatternBoundOps()) {
- addSymbol(boundOp.getKey(), boundOp.getValue()->getNumResults());
- }
+ pattern.collectResultPatternBoundSymbols(symbolInfoMap);
// Only the last N static values generated are used to replace the matched
// root N-result op. We need to calculate the starting index (of the results
// of the matched op) each result pattern is to replace.
SmallVector<int, 4> offsets(numResultPatterns + 1, numExpectedResults);
- int replStartIndex = -1;
+ // If we don't need to replace any value at all, set the replacement starting
+ // index as the number of result patterns so we skip all of them when trying
+ // to replace the matched op's results.
+ int replStartIndex = numExpectedResults == 0 ? numResultPatterns : -1;
for (int i = numResultPatterns - 1; i >= 0; --i) {
auto numValues = getNodeValueCount(pattern.getResultPattern(i));
offsets[i] = offsets[i + 1] - numValues;
if (offsets[i] == 0) {
- replStartIndex = i;
+ if (replStartIndex == -1)
+ replStartIndex = i;
} else if (offsets[i] < 0 && offsets[i + 1] > 0) {
auto error = formatv(
"cannot use the same multi-result op '{0}' to generate both "
@@ -652,31 +491,36 @@ void PatternEmitter::emitRewriteLogic() {
// Emit the final replaceOp() statement
os.indent(4) << "rewriter.replaceOp(op0, {";
- interleave(
- ArrayRef<std::string>(resultValues).drop_front(replStartIndex),
- [&](const std::string &name) { os << name; }, [&]() { os << ", "; });
+ interleaveComma(
+ ArrayRef<std::string>(resultValues).drop_front(replStartIndex), os,
+ [&](const std::string &symbol) { os << resolveSymbol(symbol); });
os << "});\n";
}
-std::string PatternEmitter::getUniqueValueName(const Operator *op) {
- return formatv("v{0}{1}", op->getCppClassName(), nextValueId++);
+std::string PatternEmitter::getUniqueSymbol(const Operator *op) {
+ return formatv("tblgen_{0}_{1}", op->getCppClassName(), nextValueId++);
}
std::string PatternEmitter::handleResultPattern(DagNode resultTree,
int resultIndex, int depth) {
- if (resultTree.isNativeCodeCall())
- return handleReplaceWithNativeCodeCall(resultTree);
+ if (resultTree.isNativeCodeCall()) {
+ auto symbol = handleReplaceWithNativeCodeCall(resultTree);
+ symbolInfoMap.bindValue(symbol);
+ return symbol;
+ }
- if (resultTree.isReplaceWithValue())
+ if (resultTree.isReplaceWithValue()) {
return handleReplaceWithValue(resultTree);
+ }
- // Create the op and get the local variable for it.
- auto results = handleOpCreation(resultTree, resultIndex, depth);
- // We need to get all the values out of this local variable if we've created a
- // multi-result op.
- const auto &numResults = pattern.getDialectOp(resultTree).getNumResults();
- return formatValuePack("{0}.getOperation()->getResult({1})", results,
- numResults, /*offset=*/0);
+ // Normal op creation.
+ auto symbol = handleOpCreation(resultTree, resultIndex, depth);
+ if (resultTree.getSymbol().empty()) {
+ // This is an op not explicitly bound to a symbol in the rewrite rule.
+ // Register the auto-generated symbol for it.
+ symbolInfoMap.bindOpResult(symbol, pattern.getDialectOp(resultTree));
+ }
+ return symbol;
}
std::string PatternEmitter::handleReplaceWithValue(DagNode tree) {
@@ -709,7 +553,6 @@ std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef argName) {
std::string val = std::to_string(enumCase.getValue());
return handleConstantAttr(enumCase, val);
}
- pattern.ensureBoundInSourcePattern(argName);
if (leaf.isUnspecified() || leaf.isOperandMatcher()) {
return argName;
}
@@ -734,27 +577,23 @@ std::string PatternEmitter::handleReplaceWithNativeCodeCall(DagNode tree) {
attrs[5], attrs[6], attrs[7]);
}
-void PatternEmitter::addSymbol(StringRef symbol, int numValues) {
- if (!symbolResolver.add(symbol, numValues))
- PrintFatalError(loc, formatv("symbol '{0}' bound more than once", symbol));
-}
-
std::string PatternEmitter::resolveSymbol(StringRef symbol) {
- auto subst = symbolResolver.query(symbol);
- if (subst.empty())
+ auto subst = symbolInfoMap.getValueAndRangeUse(symbol);
+ if (subst.empty()) {
PrintFatalError(loc, formatv("referencing unbound symbol '{0}'", symbol));
+ }
return subst;
}
int PatternEmitter::getNodeValueCount(DagNode node) {
if (node.isOperation()) {
- // First to see whether this op is bound and we just want a specific result
- // of it with `__N` suffix in symbol.
- int count = symbolResolver.getValueCount(node.getSymbol());
- if (count >= 0)
- return count;
-
- // No symbol. Then we are using all the results.
+ // If the op is bound to a symbol in the rewrite rule, query its result
+ // count from the symbol info map.
+ auto symbol = node.getSymbol();
+ if (!symbol.empty()) {
+ return symbolInfoMap.getStaticValueCount(symbol);
+ }
+ // Otherwise this is an unbound op; we will use all its results.
return pattern.getDialectOp(node).getNumResults();
}
// TODO(antiagainst): This considers all NativeCodeCall as returning one
@@ -799,10 +638,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// Use the specified name for this op if available. Generate one otherwise.
std::string resultValue = tree.getSymbol();
if (resultValue.empty())
- resultValue = getUniqueValueName(&resultOp);
+ resultValue = getUniqueSymbol(&resultOp);
// Strip the index to get the name for the value pack. This will be used to
// name the local variable for the op.
- StringRef valuePackName = getValuePackName(resultValue);
+ StringRef valuePackName = SymbolInfoMap::getValuePackName(resultValue);
// Then we build the new op corresponding to this DAG node.
@@ -826,20 +665,25 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
// here.
// We need to specify the types for all results.
- auto resultTypes =
- formatValuePack("op0->getResult({1})->getType()", valuePackName,
- resultOp.getNumResults(), resultIndex);
+ SmallVector<std::string, 4> resultTypes;
+ int numResults = resultOp.getNumResults();
+ resultTypes.reserve(numResults);
+ for (int i = 0; i < numResults; ++i) {
+ resultTypes.push_back(
+ formatv("op0->getResult({0})->getType()", resultIndex + i));
+ }
os.indent(4) << formatv("auto {0} = rewriter.create<{1}>(loc",
valuePackName, resultOp.getQualCppClassName())
- << (resultTypes.empty() ? "" : ", ") << resultTypes;
+ << (resultTypes.empty() ? "" : ", ")
+ << llvm::join(resultTypes, ", ");
}
// Create the builder call for the result.
// Add operands.
- int i = 0;
- for (int e = resultOp.getNumOperands(); i < e; ++i) {
- const auto &operand = resultOp.getOperand(i);
+ int argIndex = 0;
+ for (int e = resultOp.getNumOperands(); argIndex < e; ++argIndex) {
+ const auto &operand = resultOp.getOperand(argIndex);
// Start each operand on its own line.
(os << ",\n").indent(6);
@@ -847,11 +691,11 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
if (!operand.name.empty())
os << "/*" << operand.name << "=*/";
- if (tree.isNestedDagArg(i)) {
- os << childNodeNames[i];
+ if (tree.isNestedDagArg(argIndex)) {
+ os << childNodeNames[argIndex];
} else {
- DagLeaf leaf = tree.getArgAsLeaf(i);
- auto symbol = resolveSymbol(tree.getArgName(i));
+ DagLeaf leaf = tree.getArgAsLeaf(argIndex);
+ auto symbol = resolveSymbol(tree.getArgName(argIndex));
if (leaf.isNativeCodeCall()) {
os << tgfmt(leaf.getNativeCodeTemplate(), &fmtCtx.withSelf(symbol));
} else {
@@ -862,26 +706,26 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
}
// Add attributes.
- for (int e = tree.getNumArgs(); i != e; ++i) {
+ for (; argIndex != numOpArgs; ++argIndex) {
// Start each attribute on its own line.
(os << ",\n").indent(6);
// The argument in the op definition.
- auto opArgName = resultOp.getArgName(i);
- if (auto subTree = tree.getArgAsNestedDag(i)) {
+ auto opArgName = resultOp.getArgName(argIndex);
+ if (auto subTree = tree.getArgAsNestedDag(argIndex)) {
if (!subTree.isNativeCodeCall())
PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node "
"for creating attribute");
os << formatv("/*{0}=*/{1}", opArgName,
handleReplaceWithNativeCodeCall(subTree));
} else {
- auto leaf = tree.getArgAsLeaf(i);
+ auto leaf = tree.getArgAsLeaf(argIndex);
// The argument in the result DAG pattern.
- auto patArgName = tree.getArgName(i);
+ auto patArgName = tree.getArgName(argIndex);
if (leaf.isConstantAttr() || leaf.isEnumAttrCase()) {
// TODO(jpienaar): Refactor out into map to avoid recomputing these.
- auto argument = resultOp.getArg(i);
+ auto argument = resultOp.getArg(argIndex);
if (!argument.is<NamedAttribute *>())
- PrintFatalError(loc, Twine("expected attribute ") + Twine(i));
+ PrintFatalError(loc, Twine("expected attribute ") + Twine(argIndex));
if (!patArgName.empty())
os << "/*" << patArgName << "=*/";
} else {